Files
storyforge/tests/test_production_baseline.py

356 lines
15 KiB
Python

from __future__ import annotations
import importlib
import json
import os
import sqlite3
import subprocess
import sys
import tempfile
import unittest
from pathlib import Path
from typing import Any
from fastapi.testclient import TestClient
# Project root: two directory levels above this file (storyforge/tests/<file> -> storyforge/).
ROOT = Path(__file__).resolve().parents[1]
APP_ROOT = ROOT.joinpath("collector-service")
if str(APP_ROOT) not in sys.path:
    # Prepend (index 0) so the ``app`` package imported by the tests is
    # resolved from collector-service before any other sys.path entry.
    sys.path.insert(0, str(APP_ROOT))
class ProductionBaselineTests(unittest.TestCase):
    """Baseline production checks for the collector-service FastAPI app.

    Tests run against a SQLite database inside a per-class temporary
    directory. Workflow triggering, model calls and live-recorder config
    sync are replaced with in-process fakes in ``setUpClass``.
    """

    # Environment variables written by setUpClass. Their previous values are
    # saved and restored afterwards so the temp paths do not leak into other
    # tests running in the same interpreter.
    _MANAGED_ENV_KEYS = (
        "DATA_DIR",
        "DATABASE_PATH",
        "DOWNLOADS_DIR",
        "JOBS_DIR",
        "MODELS_DIR",
        "ORCHESTRATOR_SHARED_SECRET",
        "BOOTSTRAP_SUPERADMIN_USERNAME",
        "BOOTSTRAP_SUPERADMIN_PASSWORD",
    )

    # Tables wiped before every test. Dependent tables (e.g. job_events, which
    # is keyed by job_id) come before the tables they point at (e.g. jobs) --
    # presumably to satisfy foreign-key constraints; confirm against schema.
    _TABLES_TO_CLEAR = (
        "job_events",
        "tenant_usage_ledger",
        "tenant_quota_profiles",
        "auth_tokens",
        "publish_reviews",
        "live_recorder_bindings",
        "live_recorder_sources",
        "jobs",
        "content_sources",
        "assistant_knowledge_bases",
        "assistants",
        "knowledge_documents",
        "knowledge_bases",
        "projects",
        "accounts",
        "model_profiles",
    )

    # Seeded cost (in cents) per usage-ledger category for exhausted tenants.
    # Iteration order is the insertion order used by _seed_context.
    _SEED_USAGE_COST_CENTS = {
        "analysis": 6,
        "content_source_sync": 8,
        "review": 1,
        "copy": 3,
        "ai_video": 30,
        "real_cut": 20,
        "live_recorder": 2,
    }

    @classmethod
    def setUpClass(cls) -> None:
        """Point the app at a temp data dir, reload its modules and stub external calls."""
        cls.tempdir = tempfile.TemporaryDirectory()
        # Class cleanups run even when setUpClass raises part-way (unlike
        # tearDownClass), so neither the temp dir nor the env can leak.
        # They run LIFO: env is restored first, then the temp dir is removed.
        cls.addClassCleanup(cls.tempdir.cleanup)
        cls._saved_env = {key: os.environ.get(key) for key in cls._MANAGED_ENV_KEYS}
        cls.addClassCleanup(cls._restore_env)

        temp_root = Path(cls.tempdir.name)
        os.environ["DATA_DIR"] = str(temp_root / "data")
        os.environ["DATABASE_PATH"] = str(temp_root / "data" / "storyforge.db")
        os.environ["DOWNLOADS_DIR"] = str(temp_root / "downloads")
        os.environ["JOBS_DIR"] = str(temp_root / "jobs")
        os.environ["MODELS_DIR"] = str(temp_root / "models")
        os.environ["ORCHESTRATOR_SHARED_SECRET"] = "test-secret"
        os.environ.setdefault("BOOTSTRAP_SUPERADMIN_USERNAME", "")
        os.environ.setdefault("BOOTSTRAP_SUPERADMIN_PASSWORD", "")

        # Reload rather than plain import -- presumably so module-level
        # settings pick up the environment variables set just above.
        cls.db_module = importlib.reload(importlib.import_module("app.database"))
        cls.core = importlib.reload(importlib.import_module("app.core_main"))
        cls.app_main = importlib.reload(importlib.import_module("app.main"))

        async def fake_trigger(job_row: dict[str, Any]) -> dict[str, Any]:
            # Stand-in for the orchestrator trigger: record the request event
            # and mark the job queued without any outbound call.
            cls.core.append_job_event(
                job_row["id"],
                "workflow.trigger.requested",
                {"workflow_key": job_row.get("workflow_key", "")},
            )
            cls.core.update_job_state(
                job_row["id"],
                status="queued",
                provider_name="n8n",
                provider_task_id="",
                result={"n8n_trigger": {"requested": True, "mocked": True}},
            )
            return cls.core.db.fetch_one("SELECT * FROM jobs WHERE id = ?", (job_row["id"],))

        async def fake_call_model(*_args: object, **_kwargs: object) -> str:
            # Fixed model output so analysis/copy paths never hit a real LLM.
            return "mock content"

        cls.core.trigger_orchestrated_job = fake_trigger
        cls.core.call_model = fake_call_model
        cls.core.sync_live_recorder_remote_config = lambda: {"ok": True}
        cls.core.db.init_schema()
        # NOTE(review): the client is not entered as a context manager, so per
        # the Starlette/FastAPI testing docs lifespan/startup handlers do not
        # run -- confirm that is intended.
        cls.client = TestClient(cls.app_main.app)

    @classmethod
    def _restore_env(cls) -> None:
        """Reset the managed environment variables to their pre-setUpClass values."""
        for key, value in cls._saved_env.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value

    @classmethod
    def tearDownClass(cls) -> None:
        """Close the HTTP test client.

        The temp directory and environment are released by the class
        cleanups registered in setUpClass, which run after this method.
        """
        cls.client.close()

    def setUp(self) -> None:
        """Start every test from empty tables."""
        self._clear_tables()

    def _clear_tables(self) -> None:
        """Delete every row from each table listed in ``_TABLES_TO_CLEAR``."""
        for table in self._TABLES_TO_CLEAR:
            # Table names come from the fixed class tuple, never from input,
            # so interpolating them into the statement is safe.
            self.core.db.execute(f"DELETE FROM {table}")

    def _seed_context(self, tag: str, *, exhausted: bool = False) -> dict[str, Any]:
        """Insert a minimal tenant fixture whose ids are suffixed with ``tag``.

        Creates an approved ``super_admin`` account, a project, a system
        default model profile, a ``ready`` knowledge base, an assistant linked
        to that knowledge base, and a bearer token for the account.

        Args:
            tag: Suffix making every seeded id/username unique per test.
            exhausted: When true, also add a quota profile allowing 1 unit per
                quota column, plus one ledger row (quantity 1) for every
                category in ``_SEED_USAGE_COST_CENTS``.

        Returns:
            Dict with ``account_id``, ``project_id``, ``model_id``, ``kb_id``,
            ``assistant_id`` and ``token``.
        """
        now = self.db_module.utc_now()
        account_id = f"acct_{tag}"
        project_id = f"proj_{tag}"
        model_id = f"model_{tag}"
        kb_id = f"kb_{tag}"
        assistant_id = f"assistant_{tag}"
        token = f"token_{tag}"
        username = f"user_{tag}"
        self.core.db.execute(
            """
            INSERT INTO accounts (
                id, username, password_hash, password_salt, display_name, role, approval_status,
                approved_by, approved_at, preferred_analysis_model_id, created_at, updated_at
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                account_id,
                username,
                "hash",
                "salt",
                f"User {tag}",
                "super_admin",
                "approved",
                account_id,
                now,
                model_id,
                now,
                now,
            ),
        )
        self.core.db.execute(
            """
            INSERT INTO projects (id, user_id, name, description, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (project_id, account_id, f"Project {tag}", "", now, now),
        )
        self.core.db.execute(
            """
            INSERT INTO model_profiles (
                id, owner_account_id, name, provider, base_url, api_key, model_name,
                is_system, is_default, created_at, updated_at
            ) VALUES (?, NULL, ?, ?, ?, ?, ?, 1, 1, ?, ?)
            """,
            (model_id, "Default Model", "openai_compat", "http://127.0.0.1:8317/v1", "", "GLM-5", now, now),
        )
        self.core.db.execute(
            """
            INSERT INTO knowledge_bases (id, user_id, project_id, name, description, sync_status, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?, 'ready', ?, ?)
            """,
            (kb_id, account_id, project_id, f"KB {tag}", "", now, now),
        )
        self.core.db.execute(
            """
            INSERT INTO assistants (id, user_id, project_id, name, description, system_prompt, generation_goal, config_json, model_profile_id, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, '{}', ?, ?, ?)
            """,
            (
                assistant_id,
                account_id,
                project_id,
                f"Assistant {tag}",
                "",
                "你是文案助手。",
                "生成短视频文案。",
                model_id,
                now,
                now,
            ),
        )
        self.core.db.execute(
            "INSERT INTO assistant_knowledge_bases (assistant_id, knowledge_base_id) VALUES (?, ?)",
            (assistant_id, kb_id),
        )
        self.core.db.execute(
            "INSERT INTO auth_tokens (token, account_id, created_at) VALUES (?, ?, ?)",
            (token, account_id, now),
        )
        if exhausted:
            quota_id = f"quota_{tag}"
            self.core.db.execute(
                """
                INSERT INTO tenant_quota_profiles (
                    id, user_id, project_id, monthly_budget_cents, storage_limit_bytes, analysis_quota,
                    copy_quota, ai_video_quota, real_cut_quota, recorder_quota, enabled, config_json,
                    created_at, updated_at
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 1, '{}', ?, ?)
                """,
                (quota_id, account_id, project_id, 9999, 0, 1, 1, 1, 1, 1, now, now),
            )
            # One unit of usage per category, matching the per-column quota of 1.
            for category, cost_cents in self._SEED_USAGE_COST_CENTS.items():
                usage_id = f"usage_{tag}_{category}"
                self.core.db.execute(
                    """
                    INSERT INTO tenant_usage_ledger (
                        id, user_id, project_id, category, quantity, cost_cents, reference_type, reference_id, details_json, created_at
                    ) VALUES (?, ?, ?, ?, 1, ?, 'seed', ?, '{}', ?)
                    """,
                    (usage_id, account_id, project_id, category, cost_cents, usage_id, now),
                )
        return {
            "account_id": account_id,
            "project_id": project_id,
            "model_id": model_id,
            "kb_id": kb_id,
            "assistant_id": assistant_id,
            "token": token,
        }

    def test_database_uses_wal_and_busy_timeout(self) -> None:
        """Connections from the app's db layer use WAL and a busy timeout >= 1s."""
        conn = self.core.db.connect()
        try:
            journal_mode_row = conn.execute("PRAGMA journal_mode").fetchone()
            busy_timeout_row = conn.execute("PRAGMA busy_timeout").fetchone()
            # The row factory may yield dicts or tuple-like rows; handle both.
            journal_mode = journal_mode_row["journal_mode"] if isinstance(journal_mode_row, dict) else journal_mode_row[0]
            if isinstance(busy_timeout_row, dict):
                busy_timeout = int(busy_timeout_row.get("timeout") or busy_timeout_row.get("busy_timeout") or next(iter(busy_timeout_row.values())))
            else:
                busy_timeout = int(busy_timeout_row[0])
        finally:
            conn.close()
        self.assertEqual(str(journal_mode).lower(), "wal")
        self.assertGreaterEqual(busy_timeout, 1000)

    def test_quota_blocks_production_endpoints(self) -> None:
        """Every cost-bearing endpoint answers 403 once the tenant's quota is used up."""
        ctx = self._seed_context("quota", exhausted=True)
        headers = {"Authorization": f"Bearer {ctx['token']}"}
        # (method, path, JSON body) for each JSON endpoint expected to be blocked.
        blocked_requests = [
            (
                "POST",
                "/v2/explore/text",
                {
                    "title": "T",
                    "content": "C",
                    "project_id": ctx["project_id"],
                    "knowledge_base_id": ctx["kb_id"],
                    "assistant_id": ctx["assistant_id"],
                    "analysis_model_profile_id": ctx["model_id"],
                },
            ),
            (
                "POST",
                "/v2/explore/video-link",
                {
                    "video_url": "https://example.com/video",
                    "title": "V",
                    "project_id": ctx["project_id"],
                    "knowledge_base_id": ctx["kb_id"],
                    "assistant_id": ctx["assistant_id"],
                    "analysis_model_profile_id": ctx["model_id"],
                },
            ),
            ("POST", "/v2/pipelines/content-source-sync", {"project_id": ctx["project_id"]}),
            ("POST", "/v2/reviews", {"project_id": ctx["project_id"], "assistant_id": ctx["assistant_id"], "title": "Review"}),
            ("POST", "/v2/pipelines/real-cut", {"project_id": ctx["project_id"], "title": "Cut"}),
            ("POST", "/v2/pipelines/ai-video", {"project_id": ctx["project_id"], "title": "Video", "brief": "Brief"}),
            (
                "POST",
                f"/v2/assistants/{ctx['assistant_id']}/generate",
                {"brief": "Copy", "project_id": ctx["project_id"], "knowledge_base_ids": [ctx["kb_id"]]},
            ),
            (
                "POST",
                "/v2/live-recorder/sources",
                {"project_id": ctx["project_id"], "source_url": "https://example.com/live", "title": "Live"},
            ),
        ]
        for method, path, json_body in blocked_requests:
            with self.subTest(path=path):
                response = self.client.request(method, path, headers=headers, json=json_body)
                self.assertEqual(response.status_code, 403, response.text)
        # The upload endpoint takes multipart form data plus a file, so it is
        # exercised separately from the JSON endpoints above.
        upload_response = self.client.post(
            "/v2/explore/upload-video",
            headers=headers,
            data={
                "title": "Upload",
                "project_id": ctx["project_id"],
                "knowledge_base_id": ctx["kb_id"],
                "assistant_id": ctx["assistant_id"],
                "analysis_model_profile_id": ctx["model_id"],
            },
            files={"file": ("clip.mp4", b"clip-bytes", "video/mp4")},
        )
        self.assertEqual(upload_response.status_code, 403, upload_response.text)

    def test_successful_analysis_records_usage_and_retry_endpoints_work(self) -> None:
        """A text analysis books usage; single and bulk retry re-queue failed jobs."""
        ctx = self._seed_context("happy", exhausted=False)
        headers = {"Authorization": f"Bearer {ctx['token']}"}
        text_response = self.client.post(
            "/v2/explore/text",
            headers=headers,
            json={
                "title": "Hello",
                "content": "Hello StoryForge",
                "project_id": ctx["project_id"],
                "knowledge_base_id": ctx["kb_id"],
                "assistant_id": ctx["assistant_id"],
                "analysis_model_profile_id": ctx["model_id"],
            },
        )
        self.assertEqual(text_response.status_code, 200, text_response.text)
        text_job = text_response.json()
        usage_row = self.core.db.fetch_one(
            "SELECT * FROM tenant_usage_ledger WHERE user_id = ? AND project_id = ? AND category = ? ORDER BY created_at DESC LIMIT 1",
            (ctx["account_id"], ctx["project_id"], "analysis"),
        )
        self.assertIsNotNone(usage_row)
        # "queued" comes from the fake_trigger stub installed in setUpClass.
        self.assertEqual(text_job["status"], "queued")

        # Insert two failed jobs directly: the first is retried individually,
        # the second must be the only one picked up by the bulk retry.
        now = self.db_module.utc_now()
        failed_jobs = []
        for index in range(2):
            job_id = f"job_{index}_{ctx['project_id']}"
            self.core.db.execute(
                """
                INSERT INTO jobs (
                    id, user_id, project_id, parent_job_id, assistant_id, knowledge_base_id, content_source_id,
                    source_type, line_type, workflow_key, orchestrator, provider_name, provider_task_id,
                    source_url, title, language, status, transcript_text, style_summary, upload_status,
                    error, artifacts_json, result_json, analysis_model_profile_id, created_at, updated_at
                ) VALUES (?, ?, ?, '', ?, ?, '', ?, ?, ?, 'n8n', 'collector', '', '', ?, 'auto', 'failed', '', '', 'pending', ?, '{}', '{}', ?, ?, ?)
                """,
                (
                    job_id,
                    ctx["account_id"],
                    ctx["project_id"],
                    ctx["assistant_id"],
                    ctx["kb_id"],
                    "text",
                    "analysis",
                    "analysis_pipeline",
                    f"Failed {index}",
                    "boom",
                    ctx["model_id"],
                    now,
                    now,
                ),
            )
            failed_jobs.append(job_id)

        retry_response = self.client.post(f"/v2/explore/jobs/{failed_jobs[0]}/retry", headers=headers)
        self.assertEqual(retry_response.status_code, 200, retry_response.text)
        retry_payload = retry_response.json()
        self.assertEqual(retry_payload["status"], "queued")

        bulk_response = self.client.post(
            "/v2/admin/jobs/retry-failed",
            headers=headers,
            json={"project_id": ctx["project_id"], "limit": 10},
        )
        self.assertEqual(bulk_response.status_code, 200, bulk_response.text)
        bulk_payload = bulk_response.json()
        self.assertEqual(bulk_payload["count"], 1)
        self.assertEqual(len(bulk_payload["retried"]), 1)

        event_rows = self.core.db.fetch_all(
            "SELECT event_type FROM job_events WHERE job_id = ? ORDER BY created_at ASC",
            (failed_jobs[0],),
        )
        event_types = [row["event_type"] for row in event_rows]
        self.assertIn("job.retry.requested", event_types)
        self.assertIn("job.retry.queued", event_types)
        bulk_job = self.core.db.fetch_one("SELECT * FROM jobs WHERE id = ?", (failed_jobs[1],))
        self.assertEqual(bulk_job["status"], "queued")

    def test_backup_script_creates_consistent_snapshot(self) -> None:
        """The SQLite backup script writes a readable copy that contains seeded rows."""
        self._seed_context("backup", exhausted=False)
        backup_dir = Path(self.tempdir.name) / "backups"
        script = ROOT / "scripts" / "backup_storyforge_sqlite.sh"
        result = subprocess.run(
            ["bash", str(script)],
            check=False,
            text=True,
            capture_output=True,
            env={
                **os.environ,
                "DATABASE_PATH": str(self.core.db.path),
                "BACKUP_DIR": str(backup_dir),
            },
        )
        # Assert instead of check=True so a failure shows the script's stderr.
        self.assertEqual(result.returncode, 0, result.stderr)
        # The test relies on the script printing the backup path as its last stdout line.
        output_lines = result.stdout.strip().splitlines()
        self.assertTrue(output_lines, "backup script printed nothing; stderr:\n" + result.stderr)
        backup_path = Path(output_lines[-1])
        self.assertTrue(backup_path.exists(), result.stdout)
        # sqlite3's connection context manager only commits/rolls back and
        # never closes, so close explicitly to release the file handle.
        conn = sqlite3.connect(backup_path)
        try:
            account_count = conn.execute("SELECT COUNT(*) FROM accounts").fetchone()[0]
        finally:
            conn.close()
        self.assertGreaterEqual(int(account_count), 1)