speech-to-text-pipeline/src/jobs/job_transcribe_chk_modal.py
2025-09-11 21:26:01 +09:00

131 lines
5.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import json
import time
from typing import Dict, Any, Iterator, Optional, List
from pathlib import Path
import modal
from jobs.job_base import JobBase
class JobTranscribeChunkModal(JobBase):
    """Transcribe audio chunks with faster-whisper running as a Modal function.

    Reads the chunk manifest at ``self.status.chunk_manifest`` (a JSON object
    whose ``"chunks"`` list holds dicts with ``path``, ``chunk_id`` and
    optional ``abs_start`` / ``abs_end``). Each chunk's WAV bytes are sent to
    the deployed Modal function. One ``<stem>.transcript.json`` is saved per
    chunk under ``self.status.transcript_dir``, and the merged
    ``all.transcripts.json`` is written with word/segment timestamps shifted
    by each chunk's ``abs_start``.

    Environment variables (read in ``__init__``):
        MODAL_ASR_APP / MODAL_ASR_FUNC: deployed Modal app and function names.
        ASR_MODEL: forwarded as ``model_name``; unset/empty lets Modal pick.
        ASR_LANG: language code passed to the remote call (default "ja").
        ASR_TIMEOUT_SEC: stored as ``req_timeout``.
        ASR_SLEEP_SEC: pause after each remote call.
        ASR_MAX_RETRIES / ASR_RETRY_BACKOFF: retry count and multiplicative
            backoff factor for failed remote calls.
    """

    def __init__(self):
        super().__init__(name=self.__class__.__name__)
        self.description = "Transcribe Chunks via Modal (faster-whisper)"
        # Modal function to call (overridable via environment variables).
        self.modal_app = os.getenv("MODAL_ASR_APP", "whisper-transcribe-fw")
        self.modal_func = os.getenv("MODAL_ASR_FUNC", "transcribe_audio")
        # ASR settings.
        self.model_name = os.getenv("ASR_MODEL", None)  # empty -> Modal-side default
        self.lang = os.getenv("ASR_LANG", "ja")
        # Execution policy.
        # NOTE(review): req_timeout is not applied to the Modal call anywhere in
        # this class; the remote call blocks until the function itself finishes.
        self.req_timeout = int(os.getenv("ASR_TIMEOUT_SEC", "600"))
        self.sleep_between = float(os.getenv("ASR_SLEEP_SEC", "0.3"))
        self.max_retries = int(os.getenv("ASR_MAX_RETRIES", "2"))
        self.retry_backoff = float(os.getenv("ASR_RETRY_BACKOFF", "1.5"))  # multiplicative
        # Modal function handle, looked up lazily once per job instance.
        self._modal_fn = None

    # ---------- Paths ----------
    def _manifest_path(self) -> Path:
        """Path to the chunk manifest JSON taken from the job status."""
        return Path(self.status.chunk_manifest)

    def _out_dir(self) -> Path:
        """Directory that receives per-chunk and merged transcripts."""
        return Path(self.status.transcript_dir)

    # ---------- Modal call ----------
    def _modal_function(self):
        """Return the Modal function handle, resolving it only on first use."""
        if self._modal_fn is None:
            self._modal_fn = modal.Function.from_name(self.modal_app, self.modal_func)
        return self._modal_fn

    def _transcribe_with_modal(self, wav_path: str) -> Dict[str, Any]:
        """Call the Modal function directly, passing the WAV file's raw bytes.

        Any exception from the remote call is retried up to
        ``self.max_retries`` times. The first wait is 1s, and the wait is
        multiplied by ``self.retry_backoff`` after each retry.

        Args:
            wav_path: Local path of the chunk WAV file.

        Returns:
            The remote function's return value, expected to be a dict with
            ``"text"``, ``"segments"`` and ``"words"`` keys.

        Raises:
            The last exception from the remote call once retries run out, or
            RuntimeError if no attempt was made (``max_retries`` < 0).
        """
        fn = self._modal_function()
        data = Path(wav_path).read_bytes()
        kwargs: Dict[str, Any] = {"filename": Path(wav_path).name, "language": self.lang}
        if self.model_name:
            kwargs["model_name"] = self.model_name

        wait = 1.0
        last_err: Optional[Exception] = None
        for attempt in range(self.max_retries + 1):
            try:
                # .remote() returns synchronously with the function's result.
                return fn.remote(data, **kwargs)
            except Exception as e:  # any remote failure is worth a retry
                last_err = e
                if attempt < self.max_retries:
                    self.logger.warning(f"Modal retry ({attempt+1}/{self.max_retries}) for {wav_path}: {e}")
                    time.sleep(wait)
                    wait *= self.retry_backoff
        raise last_err if last_err else RuntimeError("Unknown Modal ASR error")

    # ---------- Helpers ----------
    @staticmethod
    def _write_json_atomic(path: Path, obj: Any) -> None:
        """Write ``obj`` as pretty UTF-8 JSON via a temp file and an atomic rename.

        A crash mid-write therefore never leaves a truncated file at ``path``.
        That matters here because an existing per-chunk file is treated as a
        finished transcript on resume.
        """
        tmp = path.with_name(path.name + ".tmp")
        tmp.write_text(json.dumps(obj, ensure_ascii=False, indent=2), encoding="utf-8")
        os.replace(tmp, path)

    @staticmethod
    def _shift_times(items: List[Dict[str, Any]], offset: float) -> List[Dict[str, Any]]:
        """Add ``offset`` seconds to each item's ``start``/``end`` in place.

        Keys that are missing or ``None`` are left untouched. Returns ``items``.
        """
        for item in items:
            for key in ("start", "end"):
                if item.get(key) is not None:
                    item[key] = float(item[key]) + offset
        return items

    def _transcribe_chunk(self, ch: Dict[str, Any], out_dir: Path) -> Dict[str, Any]:
        """Load or produce one chunk's transcript and convert it to absolute time."""
        wav = ch["path"]
        per_chunk_out = out_dir / f"{Path(wav).stem}.transcript.json"
        if per_chunk_out.exists():
            # Resume: reuse the transcript saved by an earlier run.
            res: Dict[str, Any] = json.loads(per_chunk_out.read_text(encoding="utf-8"))
        else:
            self.logger.info(f"ASR(modal): {wav}")
            res = self._transcribe_with_modal(wav)
            # Saved before the offset is applied, so the file holds
            # chunk-relative times and the shift below happens exactly once,
            # also on resume.
            self._write_json_atomic(per_chunk_out, res)
            time.sleep(self.sleep_between)

        # Shift by abs_start to get times relative to the whole recording.
        offset = float(ch.get("abs_start") or 0.0)
        words = self._shift_times(res.get("words") or [], offset)
        segments = self._shift_times(res.get("segments") or [], offset)
        return {
            "chunk_id": ch["chunk_id"],
            "abs_start": ch.get("abs_start", 0.0),
            "abs_end": ch.get("abs_end", 0.0),
            "text": res.get("text", ""),
            "segments": segments,
            "words": words,
        }

    # ---------- Main ----------
    def execute(self):
        """Transcribe every chunk in the manifest and write the merged transcript.

        The job is resumable. Chunks whose ``<stem>.transcript.json`` already
        exists are read back instead of being re-sent to Modal. The job counts
        as finished only once ``all.transcripts.json`` exists.

        Raises:
            FileNotFoundError: if the chunk manifest does not exist.
        """
        self.logger.info(f"{self.name} execute started")
        manifest_path = self._manifest_path()
        if not manifest_path.exists():
            raise FileNotFoundError(f"chunks manifest not found: {manifest_path}")

        out_dir = self._out_dir()
        all_out = out_dir / "all.transcripts.json"
        # Check the merged file, not the directory. The directory is created
        # before any chunk is processed, so checking it would report an
        # interrupted run as done and the per-chunk resume would never run.
        if all_out.exists():
            self.logger.info(f"Transcription already done: {out_dir}")
            return
        out_dir.mkdir(parents=True, exist_ok=True)

        meta = json.loads(manifest_path.read_text(encoding="utf-8"))
        chunks: List[Dict[str, Any]] = meta.get("chunks") or []
        if not chunks:
            self.logger.warning("No chunks found in manifest. Skipping.")
            return

        results = [self._transcribe_chunk(ch, out_dir) for ch in chunks]

        # Merged output (same shape as the OpenAI variant).
        self._write_json_atomic(all_out, {"transcripts": results})
        self.logger.info(f"Transcription merged: {all_out}")