139 lines
4.6 KiB
Python
139 lines
4.6 KiB
Python
from __future__ import annotations
|
|
|
|
import importlib.util
|
|
import os
|
|
import tempfile
|
|
import unittest
|
|
from pathlib import Path
|
|
|
|
from fastapi.testclient import TestClient
|
|
|
|
|
|
# Repository root: the parent of the directory that contains this test file.
ROOT = Path(__file__).resolve().parent.parent
# Source file of the Windows ASR HTTP service exercised by these tests.
APP_PATH = ROOT.joinpath("deploy", "storyforge-windows-asr-http", "app.py")
|
|
|
|
|
|
def load_windows_asr_app(app_path: Path | None = None):
    """Execute the Windows ASR HTTP app file and return the resulting module.

    The module is built from the file on every call and is never inserted
    into ``sys.modules``, so each caller gets a fresh copy with independent
    module-level state (handy for tests that monkeypatch module attributes).

    Args:
        app_path: File to load. Defaults to ``APP_PATH``.

    Returns:
        The freshly executed module object, named
        ``storyforge_windows_asr_app``.

    Raises:
        ImportError: If no import spec or loader can be created for the path
            (for example, an unrecognised file suffix).
    """
    path = APP_PATH if app_path is None else app_path
    spec = importlib.util.spec_from_file_location("storyforge_windows_asr_app", path)
    # Validate the spec *before* using it: module_from_spec(None) would fail
    # with an unhelpful AttributeError, and an assert is stripped under -O.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load Windows ASR app from {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
|
|
|
|
|
|
class WindowsAsrHttpTests(unittest.TestCase):
    """Tests for the StoryForge Windows ASR HTTP app loaded from ``APP_PATH``.

    Every test calls ``load_windows_asr_app()`` to execute a fresh copy of the
    app module, so attributes patched on one module object (``get_model``,
    ``app.state``) do not carry over to other tests.
    """

    # Environment variables that configure the app; cleared before every test
    # so results do not depend on the machine the tests run on.
    _WHISPER_ENV_KEYS = ("WHISPER_LANGUAGE", "WHISPER_DEVICE", "WHISPER_COMPUTE_TYPE")

    def setUp(self) -> None:
        # Snapshot the whole environment (not only the WHISPER_* keys) so any
        # variable changed during a test -- including while the app module is
        # being executed -- is restored in tearDown.
        self._saved_environ = dict(os.environ)
        for key in self._WHISPER_ENV_KEYS:
            os.environ.pop(key, None)

    def tearDown(self) -> None:
        os.environ.clear()
        os.environ.update(self._saved_environ)

    def test_defaults_prefer_auto_language_and_gpu_first_profiles(self) -> None:
        # setUp has removed every WHISPER_* override, so the app sees defaults.
        module = load_windows_asr_app()

        self.assertIsNone(module.resolve_language())
        self.assertEqual(module.describe_language_mode(), "auto")
        self.assertEqual(
            module.build_runtime_profiles(),
            [("cuda", "int8_float16"), ("cpu", "int8")],
        )

        # Not used as a context manager: Starlette's TestClient only runs the
        # app's lifespan/startup handlers when entered with `with`; presumably
        # they are skipped on purpose here, so close the client explicitly.
        client = TestClient(module.app)
        try:
            payload = client.get("/health").json()
        finally:
            client.close()

        self.assertEqual(payload["language"], "auto")
        self.assertEqual(payload["device"], "auto")
        self.assertEqual(payload["compute_type"], "auto")

    def test_explicit_runtime_overrides_are_respected(self) -> None:
        os.environ["WHISPER_LANGUAGE"] = "zh"
        os.environ["WHISPER_DEVICE"] = "cpu"
        os.environ["WHISPER_COMPUTE_TYPE"] = "int8"

        module = load_windows_asr_app()

        self.assertEqual(module.resolve_language(), "zh")
        self.assertEqual(module.describe_language_mode(), "zh")
        # An explicit device/compute type pins a single profile (no fallback).
        self.assertEqual(module.build_runtime_profiles(), [("cpu", "int8")])

    def test_windows_cuda_runtime_discovery_finds_nvidia_wheel_bins(self) -> None:
        module = load_windows_asr_app()
        # Relative bin directories to fake, in the order discovery must return them.
        cuda_bin_dirs = [
            "nvidia/cublas/bin",
            "nvidia/cuda_runtime/bin",
            "nvidia/cuda_nvrtc/bin",
            "nvidia/cudnn/bin",
        ]

        with tempfile.TemporaryDirectory() as tempdir:
            root = Path(tempdir)
            for rel in cuda_bin_dirs:
                (root / rel).mkdir(parents=True, exist_ok=True)

            dirs = module.find_windows_cuda_runtime_dirs(root)

            self.assertEqual(
                [path.as_posix() for path in dirs],
                [(root / rel).as_posix() for rel in cuda_bin_dirs],
            )

    def test_auto_runtime_falls_back_to_cpu_when_cuda_runtime_is_missing(self) -> None:
        # setUp has removed every WHISPER_* override, so the app runs in auto mode.
        module = load_windows_asr_app()

        class FakeInfo:
            language = "zh"
            language_probability = 0.88

        class FakeSegment:
            text = "中英混合 hello world"

        class BrokenGpuModel:
            def transcribe(self, *_args, **_kwargs):
                raise RuntimeError("Library cublas64_12.dll is not found or cannot be loaded")

        class CpuModel:
            def transcribe(self, *_args, **_kwargs):
                return iter([FakeSegment()]), FakeInfo()

        # get_model() hands out the broken GPU model first, then the CPU model.
        models = [BrokenGpuModel(), CpuModel()]
        cache_cleared: list[bool] = []

        def fake_get_model():
            return models.pop(0)

        # The app is expected to call get_model.cache_clear() before retrying
        # with the fallback profile; record that it happened.
        fake_get_model.cache_clear = lambda: cache_cleared.append(True)
        module.get_model = fake_get_model
        module.app.state.runtime_device = "cuda"
        module.app.state.runtime_compute_type = "int8_float16"

        # Use a plain file inside a TemporaryDirectory instead of an open
        # NamedTemporaryFile: on Windows an open NamedTemporaryFile cannot be
        # reopened by name, and transcribe_file only receives the path.
        with tempfile.TemporaryDirectory() as tempdir:
            audio_path = Path(tempdir) / "sample.wav"
            audio_path.write_bytes(b"")
            payload = module.transcribe_file(audio_path, started=0.0)

        self.assertTrue(payload["success"])
        self.assertEqual(payload["text"], "中英混合 hello world")
        self.assertEqual(payload["detected_language"], "zh")
        self.assertEqual(module.app.state.runtime_fallback_profile, ("cpu", "int8"))
        self.assertTrue(cache_cleared)
|