131 lines
5.2 KiB
Python
131 lines
5.2 KiB
Python
import os
|
||
import json
|
||
import time
|
||
from typing import Dict, Any, Iterator, Optional, List
|
||
from pathlib import Path
|
||
import modal
|
||
from jobs.job_base import JobBase
|
||
|
||
class JobTranscribeChunkModal(JobBase):
    """Transcribe audio chunks with faster-whisper running on Modal.

    Reads the chunk manifest at ``self.status.chunk_manifest``: a JSON object
    whose ``"chunks"`` list holds entries with ``"path"`` and ``"chunk_id"``,
    and optionally ``"abs_start"`` / ``"abs_end"``. Each chunk's WAV bytes are
    sent to a deployed Modal function. The job writes into
    ``self.status.transcript_dir``:

    * ``<wav stem>.transcript.json`` — the raw per-chunk result, with
      chunk-relative timestamps. It is reused on the next run, so an
      interrupted job resumes instead of re-transcribing.
    * ``all.transcripts.json`` — all chunks merged as
      ``{"transcripts": [...]}``, with segment/word timestamps shifted by each
      chunk's ``abs_start``. It is written last and serves as the completion
      marker.

    ``self.status``, ``self.logger`` and ``self.name`` are assumed to be
    provided by ``JobBase`` (not visible here).
    """

    # Name of the merged output file; its existence means the job finished.
    _MERGED_NAME = "all.transcripts.json"

    def __init__(self):
        super().__init__(name=self.__class__.__name__)
        self.description = "Transcribe Chunks via Modal (faster-whisper)"

        # Deployed Modal app / function to call (overridable via env vars).
        self.modal_app = os.getenv("MODAL_ASR_APP", "whisper-transcribe-fw")
        self.modal_func = os.getenv("MODAL_ASR_FUNC", "transcribe_audio")

        # ASR settings. When ASR_MODEL is unset or empty, no model_name is sent
        # and the Modal side uses its own default.
        self.model_name = os.getenv("ASR_MODEL", None)
        self.lang = os.getenv("ASR_LANG", "ja")

        # Execution policy.
        self.req_timeout = int(os.getenv("ASR_TIMEOUT_SEC", "600"))  # max seconds to wait per remote call
        self.sleep_between = float(os.getenv("ASR_SLEEP_SEC", "0.3"))  # pause after each remote call
        self.max_retries = int(os.getenv("ASR_MAX_RETRIES", "2"))  # retries after the first attempt
        self.retry_backoff = float(os.getenv("ASR_RETRY_BACKOFF", "1.5"))  # multiplier applied to the retry wait

    # ---------- Paths ----------

    def _manifest_path(self) -> Path:
        """Return the path of the chunk manifest JSON."""
        return Path(self.status.chunk_manifest)

    def _out_dir(self) -> Path:
        """Return the directory that receives per-chunk and merged transcripts."""
        return Path(self.status.transcript_dir)

    # ---------- File helpers ----------

    @staticmethod
    def _write_json_atomic(path: Path, obj: Any) -> None:
        """Write *obj* as pretty-printed UTF-8 JSON to *path* atomically.

        The data goes to a sibling temp file first, which then replaces
        *path*. A crash mid-write therefore never leaves a truncated file at
        *path*. This matters because the existence of these files is used to
        skip work on the next run.
        """
        tmp = path.with_name(path.name + ".tmp")
        tmp.write_text(json.dumps(obj, ensure_ascii=False, indent=2), encoding="utf-8")
        os.replace(tmp, path)

    @staticmethod
    def _shift_times(items: List[Dict[str, Any]], offset: float) -> None:
        """Add *offset* seconds, in place, to each item's "start"/"end".

        Keys that are missing or set to None are left untouched.
        """
        for item in items:
            for key in ("start", "end"):
                if item.get(key) is not None:
                    item[key] = float(item[key]) + offset

    # ---------- Modal calls ----------

    def _call_with_timeout(self, fn: Any, data: bytes, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Start one remote call and wait at most ``self.req_timeout`` seconds for it.

        Any failure while waiting (including a timeout) is re-raised to the
        caller. Before that, a best-effort cancel is issued so a stuck call
        does not keep running on Modal.
        """
        call = fn.spawn(data, **kwargs)
        try:
            return call.get(timeout=self.req_timeout)
        except Exception:
            try:
                call.cancel()
            except Exception as cancel_err:
                self.logger.warning(f"Failed to cancel Modal call: {cancel_err}")
            raise

    def _transcribe_with_modal(self, wav_path: str) -> Dict[str, Any]:
        """Call the deployed Modal function directly, passing the WAV as bytes.

        Failed attempts are retried up to ``self.max_retries`` times. The waits
        between attempts are 1s, then 1s * retry_backoff, and so on.

        Args:
            wav_path: Path of the WAV file to transcribe.

        Returns:
            The Modal function's return value. It is expected to be a dict
            like ``{"text", "segments", "words", ...}``; this is not
            validated here.

        Raises:
            The exception from the final failed attempt, or RuntimeError if
            no attempt ran (``max_retries`` < 0).
        """
        fn = modal.Function.from_name(self.modal_app, self.modal_func)

        src = Path(wav_path)
        data = src.read_bytes()
        kwargs: Dict[str, Any] = dict(filename=src.name, language=self.lang)
        if self.model_name:
            kwargs["model_name"] = self.model_name

        # Light retry with multiplicative backoff.
        wait = 1.0
        last_err: Optional[Exception] = None
        for attempt in range(self.max_retries + 1):
            try:
                return self._call_with_timeout(fn, data, kwargs)
            except Exception as e:
                last_err = e
                if attempt < self.max_retries:
                    self.logger.warning(f"Modal retry ({attempt+1}/{self.max_retries}) for {wav_path}: {e}")
                    time.sleep(wait)
                    wait *= self.retry_backoff
        raise last_err if last_err else RuntimeError("Unknown Modal ASR error")

    # ---------- Main ----------

    def _transcribe_chunk(self, ch: Dict[str, Any], out_dir: Path) -> Dict[str, Any]:
        """Build the merged-output entry for manifest chunk *ch*.

        Modal is called only when no saved per-chunk transcript exists.
        """
        wav = ch["path"]
        per_chunk_out = out_dir / f"{Path(wav).stem}.transcript.json"

        # Resume: reuse a previously saved raw result instead of calling Modal again.
        if per_chunk_out.exists():
            res: Dict[str, Any] = json.loads(per_chunk_out.read_text(encoding="utf-8"))
        else:
            self.logger.info(f"ASR(modal): {wav}")
            res = self._transcribe_with_modal(wav)
            # Saved before the shift below, so the file keeps chunk-relative times
            # and re-applying the offset on resume stays correct.
            self._write_json_atomic(per_chunk_out, res)
            time.sleep(self.sleep_between)

        # Shift chunk-relative timestamps onto the full-recording timeline.
        offset = float(ch.get("abs_start", 0.0))
        words = res.get("words") or []
        segments = res.get("segments") or []
        self._shift_times(words, offset)
        self._shift_times(segments, offset)

        return {
            "chunk_id": ch["chunk_id"],
            "abs_start": ch.get("abs_start", 0.0),
            "abs_end": ch.get("abs_end", 0.0),
            "text": res.get("text", ""),
            "segments": segments,
            "words": words,
        }

    def execute(self):
        """Transcribe every chunk in the manifest and write the merged result.

        Returns early if the merged file already exists, or if the manifest
        lists no chunks.

        Raises:
            FileNotFoundError: if the chunk manifest does not exist.
        """
        self.logger.info(f"{self.name} execute started")

        manifest_path = self._manifest_path()
        if not manifest_path.exists():
            raise FileNotFoundError(f"chunks manifest not found: {manifest_path}")

        out_dir = self._out_dir()
        all_out = out_dir / self._MERGED_NAME
        # The merged file (written last, atomically) is the completion marker.
        # Checking only that the directory exists would also skip runs that
        # crashed part-way. Those runs never produce the merged file and
        # would never get to use the per-chunk resume.
        if all_out.exists():
            self.logger.info(f"Transcription already done: {out_dir}")
            return

        meta = json.loads(manifest_path.read_text(encoding="utf-8"))
        chunks: List[Dict[str, Any]] = meta.get("chunks") or []
        if not chunks:
            self.logger.warning("No chunks found in manifest. Skipping.")
            return

        out_dir.mkdir(parents=True, exist_ok=True)

        results = [self._transcribe_chunk(ch, out_dir) for ch in chunks]

        # Merged output (same shape as the OpenAI version).
        self._write_json_atomic(all_out, {"transcripts": results})
        self.logger.info(f"Transcription merged: {all_out}")