Add the latest source

This commit is contained in:
parent c09b288ff6
commit da2e40f478

3 .gitignore vendored
@@ -1,5 +1,6 @@
 data/
-!.gitkeep
+!data/.gitkeep
+x_cookies.json
 
 # ---> Python
 # Byte-compiled / optimized / DLL files
76 docs/sns.md Normal file
@@ -0,0 +1,76 @@
## X

### Official API

* Posting tweets
  * Lets you post to X from your own service
* Retrieving and searching tweets
  * Lets you search for posts that contain specific keywords or hashtags (both operations are sketched below)
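Both operations are wrapped by the `APISNSX` helper added in this commit (`src/providers/sns/api_sns_x.py`). A minimal sketch of calling it, assuming the API credentials are already set as environment variables; the query string follows `examples/example_sns.py`:

```python
# Minimal sketch: post a tweet and search recent tweets via APISNSX.
# Assumes X_API_KEY / X_ACCESS_TOKEN / X_BEARER_TOKEN etc. are already set.
from providers.sns.api_sns_x import APISNSX

# Posting requires the app to have "Read and write" permission.
APISNSX.post("Hello from the pipeline")

# Search recent Japanese posts about OpenAI, excluding retweets.
result = APISNSX.search_recent_tweets(query="OpenAI lang:ja -is:retweet", max_results=10)
for tweet in result.get("data", []):
    print(tweet["id"], tweet["text"])
```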
Free plan limits

| Plan | FREE |
| ------------ | ---------- |
| Tweet posting | 1,500 posts/month |
| Monthly tweet retrieval | |

**Procedure**

* Reference articles
  * https://qiita.com/dl_from_scratch/items/75d3bb60fc2a93da9917
  * https://qiita.com/neru-dev/items/857cc27fd69411496388
  * https://zenn.dev/masuda1112/articles/2024-10-26-craete-post-by-python
* Issue API keys from your account
  * https://developer.twitter.com/ja
  * You must describe the intended use of the API
* Obtain the API key
* Set the permissions under "User authentication settings"
  * Read (default) -> Read and write

Sample statement of intended use

```txt
I plan to use the X API to collect and analyze public conversations (tweets and replies) related to AI, technology, and news.
The purpose is non-commercial research and educational use, such as understanding discussion trends and generating summary reports.
Data will not be shared with third parties and will only be stored temporarily for analysis.
All usage will comply with X’s Developer Policy and data protection requirements.

私は X API を利用して、AI、テクノロジー、ニュースに関連する公開の会話(ツイートやリプライ)を収集・分析する予定です。
目的は、議論の動向を理解したり要約レポートを作成したりするなど、非営利の研究や教育利用です。
データは分析のために一時的に保存するだけで、第三者と共有することはありません。
すべての利用は X の開発者ポリシーとデータ保護要件に従います。
```
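The examples in this commit load the issued keys from a `.env` file via `python-dotenv`. A minimal sketch of checking that everything needed is present; the variable names are the ones read by `src/providers/sns/api_sns_x.py`, and the values are placeholders you supply:

```python
# Minimal sketch: load the X API credentials from .env (as the examples do)
# and fail fast if any of them is missing.
import os
from dotenv import load_dotenv

load_dotenv(".env")

required = [
    "X_API_KEY",
    "X_API_KEY_SECRET",
    "X_ACCESS_TOKEN",
    "X_ACCESS_TOKEN_SECRET",
    "X_BEARER_TOKEN",  # read-only endpoints (search, lookups)
]
missing = [name for name in required if not os.getenv(name)]
if missing:
    raise RuntimeError(f"Missing X API credentials: {missing}")
```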
#### Retrieved fields

**tweet.fields**

* created_at: when the tweet was posted (UTC, ISO 8601)
* author_id: ID of the posting user (numeric string)
* conversation_id: ID that uniquely identifies the conversation thread
  (every tweet in the same conversation shares this ID)
* public_metrics: interaction counts
  (retweet_count, reply_count, like_count, quote_count, etc.)
* referenced_tweets: indicates whether this tweet is a reply, a quote tweet, or a retweet

**expansions**

Tells the API to return related objects (users, media, etc.) "expanded" in full rather than as IDs only.

* author_id (expands author_id)
  * The user information is returned in includes.users.

**user.fields**

Specifies the extra fields you want on the user object.

* username: screen name without the @ (e.g. jack)
* name: display name (e.g. Jack Dorsey)
* verified: whether the account is verified (True/False)
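These are exactly the parameters that `APISNSX.get_tweet` in this commit sends. A minimal sketch of the request, assuming `X_BEARER_TOKEN` is set; the tweet ID and the printed fields are only illustrative:

```python
# Minimal sketch of a single-tweet lookup using the fields described above.
# Endpoint and parameter names match src/providers/sns/api_sns_x.py.
import os
import requests

params = {
    "tweet.fields": "created_at,author_id,public_metrics,conversation_id,referenced_tweets",
    "expansions": "author_id",  # expand author_id into includes.users
    "user.fields": "username,name,verified",
}
headers = {"Authorization": f"Bearer {os.getenv('X_BEARER_TOKEN')}"}
tweet_id = "1234567890123456789"  # placeholder
resp = requests.get(f"https://api.twitter.com/2/tweets/{tweet_id}", headers=headers, params=params)
resp.raise_for_status()
data = resp.json()
print(data["data"]["created_at"], data["includes"]["users"][0]["username"])
```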
### Scraping
95 examples/example_csv.py Normal file
@@ -0,0 +1,95 @@
import sys
import os
import pandas as pd

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))

from dotenv import load_dotenv
load_dotenv(".env")

import lib.custom_logger as get_logger
logger = get_logger.get_logger(level=10)

from models.csv_rss_item import RSSItem
from lib.rss_reader_client import RSSReaderClient
from lib.csv_collector import CSVWriter, CSVReader, CSVEditMapper, CSVAnalyzer
from utils.translate_deepl import DeepLTranslateClient


def example_fetch():
    """Fetch the OpenAI news RSS feed, keeping only items newer than from_at."""
    url = "https://openai.com/news/rss.xml"
    items = RSSReaderClient.fetch(url, from_at="2025-09-12 21:00:00+09:00")
    logger.info(f"Fetched {len(items)} items")

# example_fetch()


def example_writer():
    """Fetch the feed and write it to the bronze layer as CSV."""
    url = "https://openai.com/news/rss.xml"
    items = RSSReaderClient.fetch(url)
    csv_data = RSSItem.to_csv_from_items(items)
    CSVWriter.write(
        records=csv_data,
        domain="tech/ai",
        layer="bronze",
        event="openai_news",
        is_year=True, is_month=True, part=1,
    )

example_writer()


def example_reader():
    """Read a bronze CSV, add a translated column, and write it to silver_work."""
    client = DeepLTranslateClient()
    file_path = "data/tech/ai/bronze/y=2025/m=09/openai_news_2025-09-15_part-001.csv"
    data = CSVReader.read(file_path)
    header_map = CSVReader.header_map(data[0])
    logger.info(f"header_map: {header_map}")
    mapper = CSVEditMapper(header_map=header_map)
    mapper.add_column("uid")
    mapper.add_column("title")
    mapper.add_column("link")
    mapper.add_column("summary")

    def call_back_text_ja(row_idx: int, row: list, header_map: dict) -> str:
        # Translate "title + summary" into Japanese for the new text_ja column.
        title = mapper.get_column_values("title", row)
        summary = mapper.get_column_values("summary", row)
        val = f"{title}\n\n{summary}"
        val_ja = client.translate(val, from_lang="en", to_lang="ja")
        return val_ja

    mapper.add_callback("text_ja", call_back_text_ja)
    mapper.add_column("published_at", key_name="published_parsed")
    edited_data = mapper.edit(data)
    edit_filename = "data/tech/ai/silver_work/y=2025/m=09/openai_news_2025-09-15_part-001_edit01.csv"
    CSVWriter.write_with_filename(
        records=edited_data,
        filename=edit_filename,
        is_update=False
    )

# example_reader()


def example_reader2():
    """Copy all existing columns and add a fixed created_at column."""
    file_path = "data/tech/ai/silver_work/y=2025/m=09/openai_news_2025-09-15_part-001_edit01.csv"
    data = CSVReader.read(file_path)
    header_map = CSVReader.header_map(data[0])
    logger.info(f"header_map: {header_map}")
    mapper = CSVEditMapper(header_map=header_map)
    mapper.auto_columns()
    mapper.add_value("created_at", value="2025-09-15 00:00:00+00:00")
    edited_data = mapper.edit(data)
    edit_filename = "data/tech/ai/silver_work/y=2025/m=09/openai_news_2025-09-15_part-001_edit02.csv"
    CSVWriter.write_with_filename(
        records=edited_data,
        filename=edit_filename,
        is_update=False
    )

# example_reader2()


def example_edit_period():
    """Split an edited CSV by year-month and write it to the silver layer."""
    file_path = "data/tech/ai/silver_work/y=2025/m=09/openai_news_2025-09-15_part-001_edit02.csv"
    data = CSVReader.read(file_path)
    CSVAnalyzer.write_separated_month(
        data,
        domain="tech/ai",
        layer="silver",
        event="openai_news",
    )

# example_edit_period()
29 examples/example_duckdb.py Normal file
@@ -0,0 +1,29 @@
import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))

import lib.custom_logger as get_logger
logger = get_logger.get_logger(level=10)

from providers.duck_db_provider import DuckDBProvider


def example_duckdb():
    logger.info("Starting example_duckdb function.")
    file_path = "data/tech/ai/bronze/y=2025/m=*/openai_news_*.csv"
    provider = DuckDBProvider()
    result = provider.max_value(
        file_path=file_path,
        column="published_parsed",
    )

    print("latest published_parsed:", result)


example_duckdb()


# con.execute(f"CREATE TABLE IF NOT EXISTS data AS SELECT * FROM read_csv_auto('{file_path}')")
# logger.info("Table 'data' created successfully.")
19 examples/example_pipeline.py Normal file
@@ -0,0 +1,19 @@
import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))

import lib.custom_logger as get_logger
logger = get_logger.get_logger(level=10)

from pipeline.pipeline_base import PipelineBase
from jobs.job_collect_rss_open_ai import JobCollectRSSOpenAI


def example_pipeline():
    pipeline = PipelineBase()
    logger.info("Pipeline initialized with context: %s", pipeline.context)
    # Here you can add jobs to the pipeline and run it
    # e.g., pipeline.add_job(SomeJob(context=pipeline.context))
    pipeline.add_job(JobCollectRSSOpenAI())
    pipeline.run()


example_pipeline()
37 examples/example_scraper.py Normal file
@@ -0,0 +1,37 @@
import sys
import os
import pandas as pd
import asyncio

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))

from dotenv import load_dotenv
load_dotenv(".env")

from lib.custom_logger import get_logger
logger = get_logger(level=10)

from providers.scraper.anthropic_scraper_provider import AnthropicScraperProvider
from models.csv_scrape_item import ScrapeItem
from lib.csv_collector import CSVWriter


def example_scraper():
    client = AnthropicScraperProvider()
    items = client.crawl_sync()
    logger.info(f"Scraped {len(items)} items")
    csv_data = ScrapeItem.to_csv_from_items(items)
    CSVWriter.write(
        records=csv_data,
        domain="tech/ai",
        layer="bronze",
        event="anthropic_news",
        is_year=True, is_month=True, part=1,
    )


# async def run():
#     async with httpx.AsyncClient() as client:


example_scraper()
29 examples/example_sns.py Normal file
@@ -0,0 +1,29 @@
# pip install requests requests-oauthlib
import sys
import os
import requests
from requests_oauthlib import OAuth1

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))


from dotenv import load_dotenv
load_dotenv(".env")

from lib.custom_logger import get_logger
logger = get_logger(level=10)

from providers.sns.api_sns_x import APISNSX


def example_get_tweet():
    items = APISNSX.search_recent_tweets(
        query="OpenAI lang:ja -is:retweet",
        max_results=10
    )
    logger.info(f"Found {len(items.get('data', []))} tweets")
    for tweet in items.get("data", []):
        logger.info(f"- {tweet['id']}: {tweet['text']}")


example_get_tweet()
93 examples/example_sns_scraper.py Normal file
@@ -0,0 +1,93 @@
# pip install requests requests-oauthlib
import sys
import os
import asyncio

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))


from dotenv import load_dotenv
load_dotenv(".env")

from lib.custom_logger import get_logger
logger = get_logger(level=10)


from providers.sns.x_sns_scraper import XScraper


async def first_time_login():
    bot = XScraper(storage_state="x_cookies.json", headless=False, slow_mo=50)
    await bot.start()
    await bot.login_manual()
    input("Press Enter once you have finished logging in...")
    ok = await bot.is_logged_in()
    print("Logged in?", ok)
    await bot.save_state()
    await bot.stop()

# asyncio.run(first_time_login())


async def run_headless():
    bot = XScraper(storage_state="x_cookies.json", headless=True)
    await bot.start()
    print("already logged in?", await bot.is_logged_in())
    # Put the actual work here (search / conversation collection, implemented in the next step)
    items = await bot.search_tweets("OpenAI lang:ja -is:retweet", 30)
    logger.info(f"Found {len(items)} tweets")
    for tweet in items:
        logger.info(f"- {tweet['id']}: {tweet['text']}")

    await bot.stop()

asyncio.run(run_headless())


# async def example_get_tweet_scraper():
#     bot = XScraper(storage_state="x_cookies.json", headless=False, slow_mo=100)
#     await bot.start()

#     # First run only: log in manually and save the cookies
#     # await bot.login_manual()
#     # await asyncio.sleep(240)  # wait 240 seconds for the manual login

#     # Collect via search
#     res = await bot.search_live("OpenAI lang:ja -is:retweet", scroll_secs=6)
#     print("search tweets:", len(res))
#     if res:
#         print(res[0])

#     await bot.stop()


# asyncio.run(example_get_tweet_scraper())


from pathlib import Path
from playwright.async_api import async_playwright, TimeoutError

STATE = "x_cookies.json"

async def save_state_once():
    async with async_playwright() as p:
        browser = await p.chromium.launch(headless=False, slow_mo=50)
        ctx = await browser.new_context()
        page = await ctx.new_page()
        await page.goto("https://x.com/login", wait_until="domcontentloaded")
        input("Press Enter after you finish logging in...")
        # Confirm login (the home timeline opens) before saving the state
        await page.goto("https://x.com/home", wait_until="domcontentloaded")
        await page.wait_for_selector('[aria-label="Account menu"]', timeout=15000)
        await ctx.storage_state(path=STATE)  # save the session state here
        await ctx.close(); await browser.close()

async def use_saved_state_headless():
    async with async_playwright() as p:
        browser = await p.chromium.launch(headless=True)
        ctx = await browser.new_context(storage_state=STATE)
        page = await ctx.new_page()
        await page.goto("https://x.com/home", wait_until="domcontentloaded")
        # If a login prompt appears here, the saved state is not working


# save_state_once()
# asyncio.run(save_state_once())
asyncio.run(use_saved_state_headless())
@@ -17,4 +17,13 @@ deepl
 
 # transformers
 # sentencepiece
 # torch
+
+# scraper
+httpx[http2]
+selectolax
+
+playwright==1.52.0
+
+# SNS(X)
+requests-oauthlib
0 src/jobs/__init__.py Normal file
24 src/jobs/job_base.py Normal file
@@ -0,0 +1,24 @@
from lib.custom_logger import get_logger
from pipe_context import PipeContext

class JobResult:
    """Result of a job execution."""
    def __init__(self, success: bool, message: str = "", data: any = None):
        self.success = success
        self.message = message
        self.data = data


class JobBase():

    """Base class for jobs."""
    def __init__(self, name="JobBase", context: PipeContext = None):
        self.logger = get_logger()
        self.context = context or PipeContext()
        self.name = name
        self.logger.info(f"{self.name} initialized")


    def execute(self) -> JobResult:
        """Run the job."""
        self.logger.info(f"{self.name} execute called")
        raise NotImplementedError("Subclasses must implement this method")
53 src/jobs/job_collect_rss_open_ai.py Normal file
@@ -0,0 +1,53 @@
from jobs.job_base import JobBase, JobResult
from lib.csv_collector.csv_writer import CSVWriter
from lib.rss_reader_client import RSSReaderClient, RSSItem
from providers.duck_db_provider import DuckDBProvider
from pipe_context import PipeContext


class JobCollectRSSOpenAI(JobBase):
    """Job that collects the OpenAI RSS feed."""

    def __init__(self, context: PipeContext = None):
        super().__init__(name=self.__class__.__name__, context=context)
        self.description = "Collect RSS feeds from OpenAI"
        self.domain = "tech/ai"
        self.layer = "bronze"
        self.event = "openai_news"
        self.is_year = True
        self.is_month = True

    def execute(self):
        try:
            self.logger.info(f"{self.name} started")
            # Collect the RSS feed
            url = "https://openai.com/news/rss.xml"

            # Get the latest published date already stored in the CSV files
            provider = DuckDBProvider()
            published_parsed_max = provider.max_value(
                file_path=f"data/{self.domain}/{self.layer}/y=*/m=*/{self.event}_*.csv",
                column="published_parsed",
            )

            items = RSSReaderClient.fetch(url, from_at=published_parsed_max)
            if not items or len(items) == 0:
                self.logger.info("No new items found.")
                return JobResult(success=True, message="No new items found.")

            # Convert the fetched items to CSV and save them
            csv_data = RSSItem.to_csv_from_items(items)
            filename = CSVWriter.write(
                records=csv_data,
                domain=self.domain,
                layer=self.layer,
                event=self.event,
                is_year=self.is_year,
                is_month=self.is_month,
            )
            self.context.set("output_filename", filename)
            self.logger.info(f"{self.name} finished")
            return JobResult(success=True, message="Job completed successfully")

        except Exception as e:
            self.logger.error(f"Error in {self.name} during setup: {e}")
            return JobResult(success=False, message=str(e))
12 src/lib/csv_collector/__init__.py Normal file
@@ -0,0 +1,12 @@
from .csv_writer import CSVWriter
from .csv_reader import CSVReader
from .csv_editor import CSVEditColumn, CSVEditMapper
from .csv_analyzer import CSVAnalyzer

__all__ = [
    "CSVWriter",
    "CSVReader",
    "CSVEditColumn",
    "CSVEditMapper",
    "CSVAnalyzer",
]
118 src/lib/csv_collector/csv_analyzer.py Normal file
@@ -0,0 +1,118 @@
import os
import pandas as pd
from zoneinfo import ZoneInfo
from typing import Union
from utils.types import DataLayer

from lib.custom_logger import get_logger
logger = get_logger()

from .csv_writer import CSVWriter
from .csv_reader import CSVReader

class CSVAnalyzer:

    @classmethod
    def _separate_month_to_df(
            cls,
            header: list,
            data_rows: list,
            date_key: str = "published_at",
            tz: str | None = None) -> pd.DataFrame | None:

        if not data_rows:
            return None

        df = pd.DataFrame(data_rows, columns=header)
        # Normalize the date column (convert to datetime, then to the requested timezone)
        df[date_key] = pd.to_datetime(df[date_key], errors="coerce", utc=True)
        if tz:
            df[date_key] = df[date_key].dt.tz_convert(ZoneInfo(tz))
        # Add a year-month column used for grouping
        df["year_month"] = df[date_key].dt.to_period("M")
        return df

    @classmethod
    def separate_month_to_dict(
            cls,
            header: list,
            data_rows: list,
            date_key: str = "published_at",
            tz: str | None = None) -> dict[str, list[dict]] | None:
        """
        Split the data by year-month (list of lists -> dict of lists of dicts).
        """
        df = cls._separate_month_to_df(header, data_rows, date_key, tz)
        if df is None:
            return None

        return {
            str(ym): g.drop(columns=["year_month"]).to_dict(orient="records")
            for ym, g in df.groupby("year_month", sort=True)
        }


    @classmethod
    def write_separated_month(
            cls,
            records,
            domain: str,
            event: str,
            layer: Union[str, DataLayer],
            prefix: str = None,
            data_format: str = "%Y-%m",
            is_year: bool = True,
            is_month: bool = True,
            data_key: str = "published_at",
            tz: str | None = None,
            ):
        """Split the data by year-month and write one CSV file per month."""
        if not records or len(records) < 2:
            logger.warning("No records to process.")
            return
        header = records[0]
        data_rows = records[1:]

        df = cls._separate_month_to_df(header, data_rows, data_key, tz)
        if df is None:
            return

        for ym, g in df.groupby("year_month", sort=True):
            logger.info(f"Processing year-month: {ym}")
            y, m = str(ym).split("-")
            folder_path = CSVWriter.get_filepath(
                domain=domain,
                layer=layer)
            if is_year:
                folder_path = f"{folder_path}/y={y}"
            if is_month:
                folder_path = f"{folder_path}/m={m}"

            filename = CSVWriter.get_filename(
                event=event,
                prefix=prefix,
                date_format=data_format,
                dt=str(ym) + "-01",
                extension=".csv"
            )
            fpath = os.path.join(folder_path, filename)
            os.makedirs(folder_path, exist_ok=True)
            logger.info(f"Writing to file: {fpath}")
            g.drop(columns=["year_month"]).to_csv(fpath, index=False, encoding="utf-8")


# result = {}
# for year_month, group in df.groupby('year_month'):
#     year = year_month.year
#     month = year_month.month
#     logger.info(f"y={year}/m={month:02d}")
110 src/lib/csv_collector/csv_editor.py Normal file
@@ -0,0 +1,110 @@

# import os
# import csv
from typing import Optional, TypeVar, Callable
from dataclasses import dataclass
from .csv_reader import CSVReader


from lib.custom_logger import get_logger
logger = get_logger()

T = TypeVar("T")
ColCallback = Callable[[int, list, dict], T]


@dataclass
class CSVEditColumn():
    """Column definition used when editing CSV data."""
    name: str
    value: any = None
    key_name: str = None
    cb: Optional[ColCallback] = None

    def execute(self, row_index: int, row: list, header_map: dict) -> any:
        """Resolve the value for this column."""
        try:
            if self.cb:
                return self.cb(row_index, row, header_map)
            elif self.key_name and self.key_name in header_map:
                index = header_map[self.key_name]
                return row[index]
            else:
                return self.value
        except Exception as e:
            logger.error(f"Error in CSVEditColumn.execute: {e}")
            logger.error(f"row_index: {row_index}, row: {row}, header_map: {header_map}")
            logger.error(f"Column info - name: {self.name}, value: {self.value}, key_name: {self.key_name}, cb: {self.cb}")
            raise e


class CSVEditMapper:
    """Mapper that builds a new CSV layout from an existing one."""
    def __init__(self, header_map: dict = None):
        self.columns: list[CSVEditColumn] = []
        self.header_map: dict = header_map if header_map else {}

    def add(self, column: CSVEditColumn):
        self.columns.append(column)

    def add_column(self, name: str, key_name: str = None):
        if not key_name:
            key_name = name
        self.columns.append(CSVEditColumn(name, None, key_name))

    def add_value(self, name: str, value: any):
        self.columns.append(CSVEditColumn(name, value))

    def add_callback(self, name: str, cb: callable):
        self.columns.append(CSVEditColumn(name, cb=cb))

    def auto_columns(self):
        """Add columns automatically from the existing header information."""
        if not self.header_map or len(self.header_map) == 0:
            return

        # Add them automatically, ordered by their index in the header
        sorted_items = sorted(self.header_map.items(), key=lambda item: item[1])
        for key, idx in sorted_items:
            self.add_column(name=key, key_name=key)

    def get_column_values(self, key_name: str, row, null_value: any = None) -> any:
        idx = self.header_map[key_name]
        if idx is None or idx < 0:
            return null_value

        return row[idx]


    def edit(self, records: list[list]) -> list[list]:
        """Apply the configured columns to the CSV data."""
        new_records = []
        # Build the header row
        header = []
        for col in self.columns:
            header.append(col.name)
        new_records.append(header)
        if not records or len(records) < 2:
            return new_records

        if self.header_map is None or len(self.header_map) == 0:
            self.header_map = CSVReader.header_map(records[0])

        # Transform each data row
        for i, rows in enumerate(records[1:]):
            new_row = []
            for col in self.columns:
                _value = col.execute(i, rows, self.header_map)
                new_row.append(_value)
            new_records.append(new_row)

        return new_records
39 src/lib/csv_collector/csv_reader.py Normal file
@@ -0,0 +1,39 @@
import os
import csv
from typing import List, Union
from datetime import datetime
from utils.types import DataLayer

from lib.custom_logger import get_logger
logger = get_logger()

class CSVReader:
    """CSV file reading utility."""
    BASE_DIR = "data"

    @classmethod
    def read(cls, file_path: str) -> List[any]:
        """Read a CSV file as a list of rows."""
        if not os.path.exists(file_path):
            logger.warning(f"File not found: {file_path}")
            return []

        with open(file_path, mode="r", newline="", encoding="utf-8") as f:
            reader = csv.reader(f)
            return list(reader)

    @classmethod
    def read_dict(cls, file_path: str) -> List[dict]:
        """Read a CSV file as a list of dicts."""
        if not os.path.exists(file_path):
            logger.warning(f"File not found: {file_path}")
            return []

        with open(file_path, mode="r", newline="", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            return list(reader)

    @classmethod
    def header_map(cls, headers: list) -> dict[str, int]:
        """Build a name-to-index mapping from a CSV header row."""
        return {h: i for i, h in enumerate(headers)}
162 src/lib/csv_collector/csv_writer.py Normal file
@@ -0,0 +1,162 @@
import os
import csv
from typing import List, Union
from datetime import datetime
from utils.types import DataLayer

from lib.custom_logger import get_logger
logger = get_logger()


class CSVWriter:
    """CSV file writing utility."""
    BASE_DIR = "data"

    @classmethod
    def get_filepath(cls,
                     domain: str,
                     layer: Union[str, DataLayer],
                     is_year: bool = False,
                     is_month: bool = False,
                     is_day: bool = False,
                     is_hour: bool = False,
                     dt: Union[str, datetime] = None
                     ) -> str:
        """Build the folder path (data/<domain>/<layer>/y=YYYY/m=MM/...)."""
        parts = [cls.BASE_DIR]
        parts.append(domain)
        parts.append(layer)
        if dt is None:
            dt = datetime.now()
        elif isinstance(dt, str):
            dt = datetime.fromisoformat(dt)
        if is_year:
            parts.append(f"y={dt.strftime('%Y')}")
        if is_month:
            parts.append(f"m={dt.strftime('%m')}")
        if is_day:
            parts.append(f"d={dt.strftime('%d')}")
        if is_hour:
            parts.append(f"h={dt.strftime('%H')}")
        folder_path = os.path.join(*parts)
        logger.debug(f"Generated CSV folder path: {folder_path}")
        return os.path.join(*parts)


    @classmethod
    def get_filename(
            cls,
            event: str,
            prefix: str = None,
            date_format: str = "%Y-%m-%d",
            dt: Union[str, datetime] = None,
            part: int = None,
            extension: str = ".csv") -> str:
        """
        Build the CSV file name.

        Args:
            prefix (str, optional): File name prefix. Defaults to None.
            date_format (str, optional): Date format. Defaults to "%Y-%m-%d".
            dt (datetime, optional): Date to embed in the name. Defaults to None (now).
            part (int, optional): Partition number. Defaults to None.
            extension (str, optional): File extension. Defaults to ".csv".
        """
        file_names_part = []
        if prefix:
            file_names_part.append(prefix)
        file_names_part.append(event)

        if date_format:
            # Convert to a datetime
            if dt is None:
                dt = datetime.now()
            elif isinstance(dt, str):
                dt = datetime.fromisoformat(dt)
            date_str = dt.strftime(date_format)
            file_names_part.append(date_str)

        if part is not None:
            file_names_part.append(f"part-{part:03d}")
        file_name = "_".join(file_names_part) + extension
        logger.debug(f"Generated CSV file name: {file_name}")
        return file_name


    @classmethod
    def write(
            cls,
            records: List,
            domain: str,
            layer: Union[str, DataLayer],
            event: str,
            prefix: str = None,
            date_format: str = "%Y-%m-%d",
            dt: Union[str, datetime] = None,
            part: int = None,
            extension: str = ".csv",
            is_year: bool = False,
            is_month: bool = False,
            is_day: bool = False,
            is_hour: bool = False,
            is_update: bool = False,
            ) -> str:
        """Write the CSV records under the generated folder path and return the file path."""
        if not records:
            logger.warning("No records to write.")
            return ""
        folder_path = cls.get_filepath(
            domain=domain,
            layer=layer,
            is_year=is_year,
            is_month=is_month,
            is_day=is_day,
            is_hour=is_hour,
            dt=dt
        )

        filename = cls.get_filename(
            event=event,
            prefix=prefix,
            date_format=date_format,
            dt=dt,
            part=part,
            extension=extension)

        os.makedirs(folder_path, exist_ok=True)
        full_filename = os.path.join(folder_path, filename)

        if not is_update and os.path.exists(full_filename):
            logger.info(f"File already exists and will not be overwritten: {full_filename}")
            return full_filename

        with open(full_filename, mode="w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f, quoting=csv.QUOTE_ALL)
            writer.writerows(records)

        return full_filename

    @classmethod
    def write_with_filename(
            cls,
            records: List,
            filename: str,
            is_update: bool = False,
            ) -> str:
        """Write the CSV records to the given file path."""
        if not records:
            logger.warning("No records to write.")
            return ""

        os.makedirs(os.path.dirname(filename), exist_ok=True)

        if not is_update and os.path.exists(filename):
            logger.info(f"File already exists and will not be overwritten: {filename}")
            return filename

        with open(filename, mode="w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f, quoting=csv.QUOTE_ALL)
            writer.writerows(records)

        return filename
69 src/lib/rss_reader_client.py Normal file
@@ -0,0 +1,69 @@

from datetime import datetime
import feedparser
from feedparser import FeedParserDict
from models.csv_rss_item import RSSItem
from typing import Union
from dataclasses import dataclass, field

from lib.custom_logger import get_logger
logger = get_logger()

@dataclass
class Feed:
    """RSS feed information."""
    url: str
    title: str = ""
    company: str = ""
    # Optional metadata and fetched items, as used by the RSS providers
    language: str = ""
    tags: list = field(default_factory=list)
    feed_items: list = field(default_factory=list)


class RSSReaderClient:
    """RSS reader client."""
    @classmethod
    def fetch(
            cls,
            url: Union[str, Feed],
            from_at: Union[str, datetime] = None,
            to_at: Union[str, datetime] = None
            ) -> list[RSSItem]:
        """Fetch the entries of the given feed."""
        items = []
        url = url.url if isinstance(url, Feed) else url
        d: FeedParserDict = feedparser.parse(url)
        logger.info(f"Fetched {len(d.entries)} entries from {url}")
        if d.entries:
            logger.debug(f"item {d.entries[0]}")

        from_dt = cls._to_datetime(from_at)
        to_dt = cls._to_datetime(to_at)

        for e in d.entries:
            item = RSSItem(
                uid=e.get("id") or e.get("guid") or e.get("link"),
                title=e.get("title", "(no title)"),
                link=e.get("link"),
                author=e.get("author"),
                summary=e.get("summary") or e.get("description"),
                published=e.get("published") or e.get("updated"),
                published_parsed=e.get("published_parsed") or e.get("updated_parsed"),
            )
            if from_dt and item.published_parsed and item.published_parsed <= from_dt:
                continue
            if to_dt and item.published_parsed and item.published_parsed >= to_dt:
                continue

            # logger.debug(f"Published at: {item.published_parsed}")
            # logger.debug(f"> from dt: {from_dt}")
            # logger.debug(f"< to dt: {to_dt}")

            items.append(item)

        return items

    @staticmethod
    def _to_datetime(v):
        if v is None:
            return None
        if isinstance(v, datetime):
            return v
        # Expects ISO 8601 or "YYYY-MM-DD" strings
        return datetime.fromisoformat(v)
42 src/models/csv_model_base.py Normal file
@@ -0,0 +1,42 @@
from datetime import datetime
import json
from typing import ClassVar, Optional, List
from pydantic import BaseModel

class CSVBaseModel(BaseModel):
    """Base class that adds common CSV helpers to BaseModel."""
    # Each subclass can define its own set of excluded fields
    csv_excludes: ClassVar[List[str]] = []

    @classmethod
    def to_headers(cls, excepts: Optional[List[str]] = None) -> List[str]:
        """Generate the CSV header automatically."""
        fields = list(cls.model_fields.keys())  # keeps the declaration order
        if excepts:
            fields = [f for f in fields if f not in excepts]
        return fields

    def to_row(self, excepts: Optional[List[str]] = None) -> List[str]:
        """Convert the instance into a CSV row."""
        header = self.to_headers(excepts=excepts)
        row = []
        for f in header:
            val = getattr(self, f)
            if isinstance(val, (dict, list)):
                row.append(json.dumps(val, ensure_ascii=False))  # dict/list become JSON strings
            elif isinstance(val, datetime):
                row.append(val.isoformat())  # datetime becomes an ISO 8601 string
            elif val is None:
                row.append("")
            else:
                row.append(str(val))
        return row

    @staticmethod
    def to_csv_from_items(items: List['CSVBaseModel']) -> List:
        """Build the full CSV data (header plus rows) from a list of items."""
        if not items:
            return []
        headers = items[0].to_headers()
        rows = [item.to_row() for item in items]
        return [headers] + rows
35 src/models/csv_rss_item.py Normal file
@@ -0,0 +1,35 @@

from .csv_model_base import CSVBaseModel
from pydantic import field_validator
import time
from typing import Optional
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime  # robust for RFC 822 style dates
import calendar

class RSSItem(CSVBaseModel):
    uid: str
    title: str
    link: str
    author: Optional[str] = None
    summary: Optional[str] = None
    published: Optional[str] = None
    published_parsed: Optional[datetime] = None

    @field_validator("published_parsed", mode="before")
    def parse_published(cls, v):
        if v is None:
            return None
        if isinstance(v, datetime):
            return v
        if isinstance(v, time.struct_time):
            # struct_time is assumed to be UTC; calendar.timegm avoids timezone drift
            return datetime.fromtimestamp(calendar.timegm(v), tz=timezone.utc)
        if isinstance(v, str):
            try:
                dt = parsedate_to_datetime(v)
                return dt if dt.tzinfo else dt.replace(tzinfo=timezone.utc)
            except Exception:
                return None

        return v
20 src/models/csv_scrape_item.py Normal file
@@ -0,0 +1,20 @@

from .csv_model_base import CSVBaseModel
from pydantic import field_validator
import time
from typing import Optional
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime  # robust for RFC 822 style dates
import calendar

class ScrapeItem(CSVBaseModel):
    uid: str
    title: str
    link: str
    category: Optional[str] = None
    summary: Optional[str] = None
    published: Optional[str] = None
    published_parsed: Optional[datetime] = None
    detail_error: Optional[str] = None
    item_info: Optional[dict] = None
10 src/pipe_context.py Normal file
@@ -0,0 +1,10 @@
class PipeContext:
    """Holds the shared context for a pipeline run."""
    def __init__(self):
        self.context = {}

    def set(self, key: str, value: any):
        self.context[key] = value

    def get(self, key: str, default: any = None) -> any:
        return self.context.get(key, default)
19 src/pipeline/pipeline_base.py Normal file
@@ -0,0 +1,19 @@
from typing import List
from jobs.job_base import JobBase
from pipe_context import PipeContext
from lib.custom_logger import get_logger
logger = get_logger()

class PipelineBase:
    """Base class for pipelines."""
    def __init__(self):
        self.jobs: List[JobBase] = []
        self.logger = get_logger()
        self.context = PipeContext()

    def add_job(self, job: JobBase):
        self.jobs.append(job)

    def run(self):
        for job in self.jobs:
            job.execute()
35 src/providers/duck_db_provider.py Normal file
@@ -0,0 +1,35 @@
import duckdb

class DuckDBProvider:
    def __init__(self, db_path: str = ":memory:", read_only: bool = False):
        self.con = self.connect(db_path, read_only)

    def connect(self, db_path: str = ":memory:", read_only: bool = False):
        return duckdb.connect(database=db_path, read_only=read_only)

    def close(self):
        """Close the connection."""
        if self.con:
            self.con.close()

    def query_df(self, sql: str):
        """Run a SQL query and return the result as a DataFrame."""
        return self.con.execute(sql).df()

    def max_value(
            self,
            file_path: str,
            column: str,
            hive_partitioning: bool = True,
            union_by_name: bool = True,
            ) -> any:
        """Return the maximum value of the given column across the matched CSV files."""
        query = f"""
            SELECT MAX({column}) AS max_{column}
            FROM read_csv_auto('{file_path}',
                hive_partitioning={1 if hive_partitioning else 0},
                union_by_name={1 if union_by_name else 0}
            )
        """
        result = self.con.execute(query).fetchone()[0]
        return result
33 src/providers/rss/rss_openai_provider.py Normal file
@@ -0,0 +1,33 @@
from lib.rss_reader_client import RSSReaderClient, RSSItem, Feed
from lib.custom_logger import get_logger

class RSSOpenAIProvider:
    """RSS provider for the OpenAI feeds."""
    def __init__(self):
        self.logger = get_logger()
        self.feeds = [
            Feed(
                url="https://openai.com/news/rss.xml",
                title="Open AI News",
                company="OpenAI",
                language="en",
                tags=["ai", "openai", "news", "llm"]
            ),
            Feed(
                url="https://openai.com/blog/rss.xml",
                title="Open AI Blog",
                company="OpenAI",
                language="en",
                tags=["ai", "openai", "blog"]
            ),
        ]

    def fetch(self) -> list[Feed]:
        """Fetch the articles of every configured feed."""
        result = []
        for feed in self.feeds:
            feed_items = RSSReaderClient.fetch(feed)
            feed.feed_items = feed_items

        return self.feeds
105 src/providers/scraper/anthropic_scraper_provider.py Normal file
@@ -0,0 +1,105 @@
from .http_scraper_base import HttpScraperBase, ScrapeItem, _try_parse_date, urljoin
from selectolax.parser import HTMLParser

class AnthropicScraperProvider(HttpScraperBase):

    def __init__(self):
        base_url = "https://www.anthropic.com"
        start_url = "https://www.anthropic.com/news"
        super().__init__(
            base_url=base_url,
            start_url=start_url,
            cb_parse_list=self.parse_list,
            # cb_parse_detail=self.parse_detail
        )


    async def parse_list(self, tree: HTMLParser) -> list[ScrapeItem]:
        """Parse the news list page."""
        self.logger.info("Parsing list")
        items = []
        for node in tree.css('[class^="CardSpotlight_spotlightCard"]'):
            href = node.attrs.get("href")
            if not href:
                continue
            url = urljoin(self.base_url, href)
            # Title
            title_node = node.css_first("h2, h3, .title, span")
            title = title_node.text(strip=True) if title_node else node.text(strip=True)
            # category_node = node.css_first("p.detail-m:nth-of-type(1)")
            # category = category_node.text(strip=True) if category_node else ""
            # published_node = node.css_first(".detail-m.agate")
            # published = published_node.text(strip=True) if published_node else ""
            category = ""
            published = ""
            detail_nodes = node.css("p.detail-m")
            if len(detail_nodes) >= 2:
                category = detail_nodes[0].text(strip=True)
                published = detail_nodes[1].text(strip=True)
            published_parsed = _try_parse_date(published)
            self.logger.debug(f"Found URL: {url} Title: {title[:10]}")
            item = ScrapeItem(
                uid=url,
                link=url,
                title=title,
                category=category,
                published=published,
                published_parsed=published_parsed,
            )
            items.append(item)


        for node in tree.css('[class^="Card_linkRoot"]'):
            href = node.attrs.get("href")
            if not href:
                continue
            url = urljoin(self.base_url, href)
            # Title
            title_node = node.css_first("h2, h3, .title, span")
            title = title_node.text(strip=True) if title_node else node.text(strip=True)
            category_node = node.css_first(".detail-m")
            category = category_node.text(strip=True) if category_node else ""
            published_node = node.css_first(".detail-m.agate")
            published = published_node.text(strip=True) if published_node else ""
            published_parsed = _try_parse_date(published)
            self.logger.debug(f"Found URL: {url} Title: {title[:10]}")
            item = ScrapeItem(
                uid=url,
                link=url,
                title=title,
                category=category,
                published=published,
                published_parsed=published_parsed,
            )
            items.append(item)

        for node in tree.css('[class*="PostList_post-card"]'):
            href = node.attrs.get("href")
            if not href:
                continue
            url = urljoin(self.base_url, href)
            # Title
            title_node = node.css_first("h2, h3, .title, span")
            title = title_node.text(strip=True) if title_node else node.text(strip=True)
            category_node = node.css_first('[class*="category_node"]')
            category = category_node.text(strip=True) if category_node else ""
            published_node = node.css_first('[class*="PostList_post-date"]')
            published = published_node.text(strip=True) if published_node else ""
            published_parsed = _try_parse_date(published)
            self.logger.debug(f"Found URL: {url} Title: {title[:10]}")
            item = ScrapeItem(
                uid=url,
                link=url,
                title=title,
                category=category,
                published=published,
                published_parsed=published_parsed,
            )
            items.append(item)
        return items

    async def parse_detail(self, tree: HTMLParser, item: ScrapeItem):
        """Parse a detail page."""
        self.logger.info("Parsing detail")
        # content_node = tree.css_first('article')
        # if content_node:
        #     item.summary = content_node.text(strip=True)
142 src/providers/scraper/http_scraper_base.py Normal file
@@ -0,0 +1,142 @@
from datetime import datetime
from typing import Union, Callable
import random
import asyncio
import httpx
from selectolax.parser import HTMLParser
from models.csv_scrape_item import ScrapeItem
from urllib.parse import urljoin

from lib.custom_logger import get_logger

# ---- date parsing helper
def _try_parse_date(s: str | None):
    if not s:
        return None
    s = s.strip()
    # Common English formats, e.g. "Mar 30, 2023"
    for fmt in ("%b %d, %Y", "%B %d, %Y", "%Y-%m-%d"):
        try:
            return datetime.strptime(s, fmt).isoformat()
        except Exception:
            pass
    # If nothing matches, give up and return None
    return None


class HttpScraperBase():

    def __init__(self,
                 base_url: str,
                 start_url: str,
                 concurrency: int = 8,
                 min_delay=0.5,
                 max_delay=1.5,
                 cb_parse_list: Callable = None,
                 cb_purse_next_url: Callable = None,
                 cb_parse_detail: Callable = None,
                 ):
        self.logger = get_logger()
        self.base_url = base_url
        self.start_url = start_url
        self.headers = {"user-agent": "NewsScraper/1.0"}
        self.cb_parse_list = cb_parse_list
        self.cb_purse_next_url = cb_purse_next_url
        self.cb_parse_detail = cb_parse_detail
        self.min_delay = min_delay
        self.max_delay = max_delay
        self.concurrency = concurrency

    async def polite_wait(self):
        await asyncio.sleep(random.uniform(self.min_delay, self.max_delay))

    async def fetch_text(self, client: httpx.AsyncClient, url: str, max_retries: int = 3) -> str:
        """Fetch the HTML of the given URL, retrying with backoff on errors."""
        attempt = 0
        while True:
            try:
                await self.polite_wait()
                r = await client.get(url, headers=self.headers, timeout=30, follow_redirects=True)
                if r.status_code == 429:
                    retry_after = r.headers.get("Retry-After")
                    if retry_after:
                        try:
                            wait = int(retry_after)
                        except ValueError:
                            wait = 5
                    else:
                        wait = min(60, (2 ** attempt) + random.uniform(0, 1))
                    attempt += 1
                    if attempt > max_retries:
                        r.raise_for_status()
                    await asyncio.sleep(wait)
                    continue
                r.raise_for_status()
                return r.text
            except httpx.HTTPError as e:
                self.logger.warning(f"HTTP error fetching {url}: {e}")
                attempt += 1
                if attempt > max_retries:
                    raise
                await asyncio.sleep(min(60, (2 ** attempt) + random.uniform(0, 1)))

    async def _parse(self, html: str) -> tuple[list[ScrapeItem], str | None]:
        """Parse the HTML and return the list of items plus the next page URL."""
        self.logger.info("Parsing HTML")
        tree = HTMLParser(html)
        items = await self.cb_parse_list(tree)
        next_url = self.purse_next_url(tree)
        return items, next_url


    def purse_next_url(self, tree: HTMLParser) -> Union[str, None]:
        """Return the link of the next page (override this or pass a callback)."""
        if self.cb_purse_next_url:
            return self.cb_purse_next_url(tree)
        return None

    async def enrich_with_details(self, items: list[ScrapeItem]):
        self.logger.info("Enriching items with details")
        # Limit how many detail pages are fetched at the same time.
        # A Semaphore keeps a counter of how many tasks may run concurrently
        # and makes the rest wait once that number is reached.
        sem = asyncio.Semaphore(self.concurrency)
        async def fetch_and_parse(client: httpx.AsyncClient, it: ScrapeItem):
            async with sem:
                try:
                    self.logger.info(f"Fetching detail for {it.link}")
                    html = await self.fetch_text(client, it.link)
                    tree = HTMLParser(html)
                    await self.cb_parse_detail(tree, it)
                except Exception as e:
                    self.logger.exception(f"Error fetching detail for {it.link}: {e}")
                    it.detail_error = str(e)

        async with httpx.AsyncClient(http2=True, headers=self.headers) as client:
            await asyncio.gather(*(fetch_and_parse(client, it) for it in items))

        return items


    async def crawl(self):
        # Crawl the list pages, starting from start_url
        results = []
        self.logger.info("async crawl started")
        async with httpx.AsyncClient(http2=True, headers=self.headers) as client:
            url = self.start_url
            while url:
                html = await self.fetch_text(client, url)
                self.logger.info(f"Fetched {url} (length: {len(html)})")
                # Parse the HTML
                items, next_url = await self._parse(html)
                if items and self.cb_parse_detail:
                    await self.enrich_with_details(items)
                results.extend(items)
                url = next_url
        return results


    def crawl_sync(self):
        """Run the crawl synchronously."""
        return asyncio.run(self.crawl())
190
src/providers/sns/api_sns_x.py
Normal file
190
src/providers/sns/api_sns_x.py
Normal file
@ -0,0 +1,190 @@
import os
from requests_oauthlib import OAuth1
import requests

from lib.custom_logger import get_logger

logger = get_logger()


class APISNSX:
    """X (formerly Twitter) API interaction class."""

    X_API_KEY = os.getenv("X_API_KEY")
    X_API_KEY_SECRET = os.getenv("X_API_KEY_SECRET")
    X_ACCESS_TOKEN = os.getenv("X_ACCESS_TOKEN")
    X_ACCESS_TOKEN_SECRET = os.getenv("X_ACCESS_TOKEN_SECRET")
    # The Bearer token is read-only
    X_BEARER_TOKEN = os.getenv("X_BEARER_TOKEN")

    @classmethod
    def post(
        cls,
        content: str,
        reply: object = None,
        quote_tweet_id: str = None,
        poll: object = None,
        media: object = None,
    ):
        """Post to X.

        Args:
            content (str): Message body.
            reply (object): Reply target object.
                - e.g. "reply": {"in_reply_to_tweet_id": "1234567890123456789"}
            quote_tweet_id (str): ID of the tweet to quote.
            poll (object): Poll options.
                - e.g. "poll": {"options": ["Python", "JavaScript"], "duration_minutes": 60}
            media (object): Media object.
                - e.g. "media": {"media_ids": ["123456789012345678"]}

        Notes:
            - A 403 error is returned if the app lacks permissions.
            - Read and Write permission is required.
            - content must be 280 characters or fewer.
        """
        logger.info(f"post to X: {content[:15]}...")

        if len(content) > 280:
            raise ValueError("Content exceeds 280 characters.")

        if not all([cls.X_API_KEY, cls.X_API_KEY_SECRET, cls.X_ACCESS_TOKEN, cls.X_ACCESS_TOKEN_SECRET]):
            raise ValueError("API keys and tokens must be set in environment variables.")

        url = "https://api.twitter.com/2/tweets"
        auth = OAuth1(
            cls.X_API_KEY,
            cls.X_API_KEY_SECRET,
            cls.X_ACCESS_TOKEN,
            cls.X_ACCESS_TOKEN_SECRET,
        )
        payload = {"text": content}
        if reply:
            payload["reply"] = reply
        if quote_tweet_id:
            payload["quote_tweet_id"] = quote_tweet_id
        if poll:
            payload["poll"] = poll
        if media:
            payload["media"] = media
        response = requests.post(url, auth=auth, json=payload)
        response.raise_for_status()
        logger.info("Successfully posted to X.")
        json_data = response.json()
        logger.debug(f"Response: {json_data}")
        return json_data

    @classmethod
    def _headers(cls):
        return {
            "Authorization": f"Bearer {cls.X_BEARER_TOKEN}",
            "Content-Type": "application/json",
        }

    @classmethod
    def get_tweet(cls, tweet_id: str):
        """Fetch a tweet by ID.

        Args:
            tweet_id (str): Tweet ID.
        """
        logger.info(f"Get tweet by ID: {tweet_id}")
        if not cls.X_BEARER_TOKEN:
            raise ValueError("Bearer token must be set in environment variables.")

        # Comma-separated list of additional fields to return for the tweet
        params = {
            "tweet.fields": "created_at,author_id,public_metrics,conversation_id,referenced_tweets",
            "expansions": "author_id",  # expand author_id so the user object is returned in includes.users
            "user.fields": "username,name,verified",
        }
        url = f"https://api.twitter.com/2/tweets/{tweet_id}"
        headers = cls._headers()
        response = requests.get(url, headers=headers, params=params)
        response.raise_for_status()
        logger.debug(f"Get tweet response: {response.json()}")
        return response.json()

    @classmethod
    def get_user_by_username(cls, username: str):
        """Get user information by username."""
        logger.info(f"Get user by username: {username}")
        if not cls.X_BEARER_TOKEN:
            raise ValueError("Bearer token must be set in environment variables.")
        params = {"user.fields": "name,username,verified,created_at"}
        url = f"https://api.twitter.com/2/users/by/username/{username}"
        headers = cls._headers()
        response = requests.get(url, headers=headers, params=params)
        response.raise_for_status()
        logger.debug(f"Get user response: {response.json()}")
        return response.json()

    @classmethod
    def get_user_tweets(cls, user_id: str, max_results=10, pagination_token=None):
        """Fetch the recent timeline of any user (yourself included).

        Args:
            user_id (str): User ID.
            max_results (int): Maximum number of tweets to fetch (5-100).
            pagination_token (str): Token for fetching the next page.
        """
        logger.info(f"Get tweets for user ID: {user_id}")
        if not cls.X_BEARER_TOKEN:
            raise ValueError("Bearer token must be set in environment variables.")

        url = f"https://api.twitter.com/2/users/{user_id}/tweets"
        params = {
            "max_results": max_results,  # 5-100
            "pagination_token": pagination_token,  # set when fetching the next page
            "tweet.fields": "created_at,public_metrics,conversation_id,referenced_tweets",
            "expansions": "referenced_tweets.id",
        }
        # Do not send None values
        params = {k: v for k, v in params.items() if v is not None}
        headers = cls._headers()
        response = requests.get(url, headers=headers, params=params)
        response.raise_for_status()
        return response.json()

    @classmethod
    def search_recent_tweets(cls, query: str, max_results=10, next_token=None):
        """Search recent tweets.

        Args:
            query (str): Search query.
            max_results (int): Maximum number of tweets to fetch (10-100).
            next_token (str): Token for fetching the next page.

        Notes:
            - Query examples:
                - Posts from a specific user only: from:elonmusk
                - Replies to a specific user only: to:elonmusk
                - Conversations mentioning a user: @OpenAI
                - lang:ja restricts results to tweets detected as Japanese
                - Exclusions: -is:retweet (exclude retweets), -is:reply (exclude replies)
                - User plus keyword: from:OpenAI langchain
            - The Free and Basic plans limit calls to /2/tweets/search/recent.
        """
        logger.info(f"Search recent tweets with query: {query}")
        if not cls.X_BEARER_TOKEN:
            raise ValueError("Bearer token must be set in environment variables.")

        url = "https://api.twitter.com/2/tweets/search/recent"
        params = {
            "query": query,  # e.g. "AI langchain -is:retweet"
            "max_results": max_results,  # 10-100
            "next_token": next_token,  # set when fetching the next page
            "tweet.fields": "created_at,author_id,public_metrics,conversation_id,referenced_tweets",
            "expansions": "author_id",
            "user.fields": "username,name,verified",
        }
        # Do not send None values
        params = {k: v for k, v in params.items() if v is not None}
        headers = cls._headers()
        response = requests.get(url, headers=headers, params=params)
        response.raise_for_status()
        return response.json()
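A usage sketch for the client above. The import path and environment setup are assumptions based on the repository layout (src/ on sys.path, the X_* variables loaded from .env), not code that ships with this commit.

```python
# Hypothetical usage: requires the X_* environment variables and src/ on sys.path.
from providers.sns.api_sns_x import APISNSX

# Post a short message (needs Read and Write permission on the app)
created = APISNSX.post("Hello from the X API")
tweet_id = created["data"]["id"]

# Read it back with the Bearer token and inspect the expanded author
detail = APISNSX.get_tweet(tweet_id)
print(detail["data"]["text"], detail["includes"]["users"][0]["username"])

# Search recent Japanese posts about LangChain, excluding retweets
found = APISNSX.search_recent_tweets("langchain lang:ja -is:retweet", max_results=10)
for t in found.get("data", []):
    print(t["created_at"], t["text"][:40])
```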
249  src/providers/sns/x_sns_scraper.py  Normal file
@ -0,0 +1,249 @@
import re
import asyncio
import json
from pathlib import Path
from typing import Optional, List, Dict, Any
from playwright.async_api import async_playwright, Browser, BrowserContext, Page, TimeoutError
from urllib.parse import quote

# TWEET_RX = re.compile(r"/i/api/graphql/.+/(TweetDetail|TweetResultByRestId|ConversationTimeline)")
TWEET_RX = re.compile(r"/i/api/graphql/.+/(TweetDetail|TweetResultByRestId|ConversationTimeline|SearchTimeline)")


def _sg(d, path, default=None):
    cur = d
    for p in path:
        if isinstance(cur, dict) and p in cur:
            cur = cur[p]
        else:
            return default
    return cur


def _emit_from_node(node):
    # Dump the whole node when debugging
    if not isinstance(node, dict):
        return None

    legacy = _sg(node, ["legacy"], {})
    user_result = _sg(node, ["core", "user_results", "result"]) or {}
    # Fix: look at core as well, not only legacy
    username = (
        _sg(user_result, ["legacy", "screen_name"])
        or _sg(user_result, ["core", "screen_name"])  # added
    )
    name = (
        _sg(user_result, ["legacy", "name"])
        or _sg(user_result, ["core", "name"])  # added
    )

    # Extra safety: apply the same fallback on the author_results side
    if not username or not name:
        author = _sg(node, ["author_results", "result"]) or {}
        username = username or _sg(author, ["legacy", "screen_name"]) or _sg(author, ["core", "screen_name"])
        name = name or _sg(author, ["legacy", "name"]) or _sg(author, ["core", "name"])

    tid = node.get("rest_id") or legacy.get("id_str")
    if not tid:
        return None

    public_metrics = {
        "retweet_count": legacy.get("retweet_count", 0),
        "reply_count": legacy.get("reply_count", 0),
        "like_count": legacy.get("favorite_count", 0),
        "quote_count": legacy.get("quote_count", 0),
    }

    return {
        "id": tid,
        "text": legacy.get("full_text") or legacy.get("text"),
        "created_at": legacy.get("created_at"),
        "username": username,
        "name": name,
        "permalink": f"https://x.com/{username}/status/{tid}" if username else None,
        "public_metrics": public_metrics,
    }


async def _collect_graphql(page, sec=2.0):
    buf = []

    async def on_response(res):
        try:
            # Check the Content-Type safely with .get(...)
            if TWEET_RX.search(res.url) and "application/json" in (res.headers.get("content-type") or ""):
                buf.append(await res.json())
        except Exception:
            pass

    page.on("response", on_response)
    try:
        # Listen only for a fixed amount of time
        await asyncio.sleep(sec)
    finally:
        # Python Playwright has no off(), so use remove_listener
        page.remove_listener("response", on_response)

    return buf


def _extract(payload):
    out = []

    # 1) Timeline payloads (Search, Conversation, etc.)
    tl: dict = _sg(payload, ["data", "search_by_raw_query", "search_timeline", "timeline"]) \
        or _sg(payload, ["data", "conversation_timeline", "timeline"])
    if tl:
        for ins in tl.get("instructions", []):
            entries = ins.get("entries") or _sg(ins, ["entry", "content", "items"]) or []
            for ent in entries:
                content = ent.get("content") or _sg(ent, ["item", "itemContent"]) or {}
                # Result directly under the content
                r = _sg(content, ["itemContent", "tweet_results", "result"]) or _sg(content, ["tweet_results", "result"])
                if r:
                    t = _emit_from_node(r)
                    if t:
                        out.append(t)
                # Results nested under items
                for it in content.get("items", []):
                    r2 = _sg(it, ["item", "itemContent", "tweet_results", "result"])
                    if r2:
                        t = _emit_from_node(r2)
                        if t:
                            out.append(t)

    # 2) Single TweetDetail payloads
    r = _sg(payload, ["data", "tweetResult", "result"]) \
        or _sg(payload, ["data", "tweetResultByRestId", "result"]) \
        or _sg(payload, ["data", "tweetresultbyrestid", "result"])
    if r:
        t = _emit_from_node(r)
        if t:
            out.append(t)

    # dedup by id
    m = {}
    for t in out:
        m[t["id"]] = t
    return list(m.values())


async def _goto_and_scrape(page: Page, url, warm=1.5, shot=2.0):
    await page.goto(url, wait_until="domcontentloaded")
    await asyncio.sleep(warm)
    payloads = await _collect_graphql(page, sec=shot)
    items = []
    for p in payloads:
        items.extend(_extract(p))
    return items


async def _scroll_more(page, times=2, wait=1.0):
    got = []
    for _ in range(times):
        fut = asyncio.create_task(_collect_graphql(page, sec=wait))
        await page.evaluate("window.scrollBy(0, document.body.scrollHeight);")
        payloads = await fut
        for p in payloads:
            got.extend(_extract(p))
    # dedup
    m = {t["id"]: t for t in got}
    return list(m.values())


async def _fill_with_scroll(page, base_list, limit, tries=5):
    items = {t["id"]: t for t in base_list}
    k = lambda t: t.get("created_at") or ""
    i = 0
    while len(items) < limit and i < tries:
        more = await _scroll_more(page, times=2, wait=1.0)
        for t in more:
            items[t["id"]] = t
        i += 1
    out = list(items.values())
    out.sort(key=k, reverse=True)
    return out[:limit]


class XScraper:
    """
    - First run: log in via login_manual(), then save_state()
    - Subsequent runs: just load storage_state and call start()
    """

    def __init__(
        self,
        storage_state: str = "x_cookies.json",
        headless: bool = True,
        slow_mo: int = 0,
        user_agent: Optional[str] = None,
        locale: str = "ja-JP",
        timezone_id: str = "Asia/Tokyo",
        viewport: Optional[dict] = None,
    ):
        self.storage_state = storage_state
        self.headless = headless
        self.slow_mo = slow_mo
        self.user_agent = user_agent
        self.locale = locale
        self.timezone_id = timezone_id
        self.viewport = viewport or {"width": 1280, "height": 900}

        self._p = None
        self._browser: Optional[Browser] = None
        self._ctx: Optional[BrowserContext] = None
        self._page: Optional[Page] = None

    # ---- lifecycle ----
    async def start(self):
        """Launch with storage_state if it exists, otherwise start with an empty state."""
        self._p = await async_playwright().start()
        self._browser = await self._p.chromium.launch(headless=self.headless, slow_mo=self.slow_mo)

        context_kwargs = dict(
            locale=self.locale,
            timezone_id=self.timezone_id,
            viewport=self.viewport,
        )
        if self.user_agent:
            context_kwargs["user_agent"] = self.user_agent

        if Path(self.storage_state).exists():
            context_kwargs["storage_state"] = self.storage_state

        self._ctx = await self._browser.new_context(**context_kwargs)
        self._page = await self._ctx.new_page()

    async def stop(self, save_state: bool = False):
        """Save the state if requested, then shut down."""
        if save_state and self._ctx:
            await self._ctx.storage_state(path=self.storage_state)
        if self._ctx:
            await self._ctx.close()
        if self._browser:
            await self._browser.close()
        if self._p:
            await self._p.stop()

    # ---- helpers ----
    @property
    def page(self) -> Page:
        assert self._page is not None, "Call start() first"
        return self._page

    async def is_logged_in(self, timeout_ms: int = 6000) -> bool:
        """Decide by whether the account menu is visible on the home timeline."""
        await self.page.goto("https://x.com/home", wait_until="domcontentloaded")
        try:
            await self.page.wait_for_selector(
                '[aria-label="Account menu"], [data-testid="SideNav_AccountSwitcher_Button"]',
                timeout=timeout_ms,
            )
            return True
        except TimeoutError:
            return False

    async def login_manual(self):
        """For manual login. The caller is expected to wait with input()/sleep, etc."""
        await self.page.goto("https://x.com/login", wait_until="domcontentloaded")

    async def save_state(self):
        """Save the current context state."""
        assert self._ctx is not None
        await self._ctx.storage_state(path=self.storage_state)

    # ---- example usage ----

    async def search_tweets(self, query: str, limit: int = 50) -> List[Dict[str, Any]]:
        q = quote(query, safe="")
        url = f"https://x.com/search?q={q}&src=typed_query&f=live"
        first = await _goto_and_scrape(self.page, url)
        return await _fill_with_scroll(self.page, first, limit)
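A usage sketch for XScraper. The import path is an assumption, and the first run is expected to be interactive: log in by hand, then the cookies are saved to x_cookies.json so later runs can go straight to searching.

```python
# Hypothetical usage: module path assumed, first run requires a manual login.
import asyncio

from providers.sns.x_sns_scraper import XScraper


async def main():
    scraper = XScraper(storage_state="x_cookies.json", headless=False)
    await scraper.start()
    if not await scraper.is_logged_in():
        await scraper.login_manual()
        input("Log in in the opened browser, then press Enter...")
        await scraper.save_state()
    tweets = await scraper.search_tweets("langchain lang:ja", limit=30)
    for t in tweets:
        print(t["created_at"], t["username"], (t["text"] or "")[:40])
    await scraper.stop(save_state=True)


asyncio.run(main())
```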
66  src/utils/translate_argos.py  Normal file
@ -0,0 +1,66 @@
# argos-translate --from-lang ja --to-lang en install
# argos-translate --from-lang en --to-lang ja install
import re
import argostranslate.package
import argostranslate.translate

from lib.custom_logger import get_logger


class ArgosTranslateClient():
    """Argos Translate client"""

    SUPPORT_LANG = [
        ("ja", "en"),
        ("en", "ja"),
    ]

    @staticmethod
    def _normalize_text(s: str) -> str:
        # Tidy up extra whitespace around punctuation (light preprocessing to reduce mistranslations)
        s = " ".join(s.split())
        s = re.sub(r"\s+([,.:;!?])", r"\1", s)
        s = re.sub(r"([(\[“‘'\")\])])\s+", r"\1 ", s)
        return s

    def __init__(self, init_install=False):
        self.logger = get_logger()
        if init_install:
            self.install_models()

    def install_models(self):
        """Install translation models for the supported language pairs."""
        self.logger.info("Installing translation models...")
        installed_languages = argostranslate.translate.get_installed_languages()
        installed_codes = {lang.code for lang in installed_languages}

        for from_lang, to_lang in self.SUPPORT_LANG:
            if from_lang in installed_codes and to_lang in installed_codes:
                self.logger.info(f"Translation model for {from_lang} to {to_lang} is already installed.")
                continue
            else:
                available_packages = argostranslate.package.get_available_packages()
                package_to_install = next(
                    (pkg for pkg in available_packages if pkg.from_code == from_lang and pkg.to_code == to_lang),
                    None
                )
                if package_to_install:
                    self.logger.info(f"Installing package: {package_to_install}")
                    argostranslate.package.install_from_path(package_to_install.download())
                else:
                    self.logger.warning(f"No available package found for {from_lang} to {to_lang}")

    def translate(self, text, from_lang, to_lang):
        """Translate text."""
        text = self._normalize_text(text)
        return argostranslate.translate.translate(text, from_lang, to_lang)

    def list_installed_languages(self):
        """List the installed translation models."""
        plgs = argostranslate.package.get_installed_packages()
        ret = []
        for p in plgs:
            self.logger.debug(f"{p.from_code} -> {p.to_code} | {getattr(p, 'version', '?')}")
            ret.append((p.from_code, p.to_code, getattr(p, 'version', 'None')))
        return ret
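A usage sketch, assuming src/ is on sys.path; the first run with init_install=True downloads the ja&lt;-&gt;en Argos models.

```python
# Hypothetical usage: module path assumed.
from utils.translate_argos import ArgosTranslateClient

client = ArgosTranslateClient(init_install=True)
print(client.list_installed_languages())
print(client.translate("これはテストです。", "ja", "en"))
```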
35  src/utils/translate_deepl.py  Normal file
@ -0,0 +1,35 @@
import os
import deepl

from lib.custom_logger import get_logger


class DeepLTranslateClient():
    """DeepL translation client"""
    DEEPL_API_KEY = os.getenv("DEEPL_API_KEY", "")

    def __init__(self, api_key: str = None):
        self.logger = get_logger()
        self.api_key = api_key or os.getenv("DEEPL_API_KEY", "")
        if not self.api_key:
            raise ValueError("DeepL API key is not set. Set the DEEPL_API_KEY environment variable.")

        # Create the Translator instance
        self.translator = deepl.Translator(self.api_key)

    def translate(self, text: str, from_lang: str, to_lang: str) -> str:
        """
        Translate text.
        :param text: String to translate
        :param from_lang: Source language (e.g. 'EN', 'JA')
        :param to_lang: Target language (e.g. 'JA', 'EN')
        :return: Translated text
        """
        if not text:
            return ""
        result = self.translator.translate_text(
            text,
            source_lang=from_lang.upper(),
            target_lang=to_lang.upper(),
        )
        return result.text
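A usage sketch, assuming DEEPL_API_KEY is set and src/ is on sys.path.

```python
# Hypothetical usage: module path assumed, DEEPL_API_KEY must be set.
from utils.translate_deepl import DeepLTranslateClient

client = DeepLTranslateClient()
print(client.translate("Hello, world.", "EN", "JA"))
```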
90  src/utils/translate_marian_mt.py  Normal file
@ -0,0 +1,90 @@
from __future__ import annotations
from typing import Iterable, List, Tuple, Dict
import re
import torch
from transformers import MarianMTModel, MarianTokenizer


def _norm_lang(code: str) -> str:
    aliases = {"jp": "ja", "ja-jp": "ja", "en-us": "en", "en-gb": "en"}
    c = code.lower().strip()
    return aliases.get(c, c.split("-")[0])


class MarianMTClient():

    # Language pair -> model name (add more as needed)
    MODEL_MAP: Dict[Tuple[str, str], str] = {
        ("en", "ja"): "staka/fugumt-en-ja",
        ("ja", "en"): "staka/fugumt-ja-en",
        # ("en", "ja"): "Helsinki-NLP/opus-mt-en-jap",
        # ("ja", "en"): "Helsinki-NLP/opus-mt-ja-en",
    }
    # https://huggingface.co/Helsinki-NLP/opus-mt-ja-en
    # https://huggingface.co/Helsinki-NLP/opus-mt-en-ja

    @staticmethod
    def _normalize_text(s: str) -> str:
        # Tidy up extra whitespace around punctuation (light preprocessing to reduce mistranslations)
        s = " ".join(s.split())
        s = re.sub(r"\s+([,.:;!?])", r"\1", s)
        s = re.sub(r"([(\[“‘'\")\])])\s+", r"\1 ", s)
        return s

    def __init__(
        self,
        pairs: Iterable[Tuple[str, str]] = (("en", "ja"), ("ja", "en")),
        device: str | None = None,  # "cpu" / "cuda" / None (auto)
        num_beams: int = 4,  # favor quality (use 1-2 for more speed)
        max_new_tokens: int = 256,
        no_repeat_ngram_size: int = 3,
    ):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.gen_kwargs = dict(
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            no_repeat_ngram_size=no_repeat_ngram_size,
        )
        self._tok: Dict[str, MarianTokenizer] = {}
        self._mdl: Dict[str, MarianMTModel] = {}

        # Preload the required pairs
        for f, t in pairs:
            self._ensure_loaded(_norm_lang(f), _norm_lang(t))

        # CPU optimization (optional)
        if self.device == "cpu":
            torch.set_num_threads(max(1, torch.get_num_threads()))

    # Internal: load and cache a model
    def _ensure_loaded(self, f: str, t: str) -> str:
        key = (f, t)
        model_id = self.MODEL_MAP.get(key)
        if not model_id:
            raise ValueError(f"No Marian model mapping for {f}->{t}. Add it to MODEL_MAP.")

        if model_id in self._mdl:
            return model_id

        tok = MarianTokenizer.from_pretrained(model_id)
        mdl = MarianMTModel.from_pretrained(model_id)
        mdl.to(self.device).eval()

        self._tok[model_id] = tok
        self._mdl[model_id] = mdl
        return model_id

    def translate(self, text: str, from_lang: str, to_lang: str) -> str:
        if not text:
            return ""
        f, t = _norm_lang(from_lang), _norm_lang(to_lang)
        model_id = self._ensure_loaded(f, t)
        tok, mdl = self._tok[model_id], self._mdl[model_id]

        s = self._normalize_text(text)
        with torch.no_grad():
            batch = tok([s], return_tensors="pt")
            batch = {k: v.to(self.device) for k, v in batch.items()}
            out = mdl.generate(**batch, **self.gen_kwargs)
        return tok.decode(out[0], skip_special_tokens=True)
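A usage sketch, assuming src/ is on sys.path; the first call downloads the FuguMT models from the Hugging Face Hub.

```python
# Hypothetical usage: module path assumed, models are fetched on first use.
from utils.translate_marian_mt import MarianMTClient

client = MarianMTClient(pairs=(("en", "ja"), ("ja", "en")), device="cpu")
print(client.translate("This is a test.", "en", "ja"))
print(client.translate("これはテストです。", "ja", "en"))
```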
6  src/utils/types.py  Normal file
@ -0,0 +1,6 @@
from enum import Enum


class DataLayer(str, Enum):
    BRONZE = "bronze"
    SILVER = "silver"
    GOLD = "gold"
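A small usage sketch for the enum; the data/&lt;layer&gt; directory layout is an assumption for illustration, not something defined in this commit.

```python
# Hypothetical usage: path layout assumed.
from utils.types import DataLayer


def layer_dir(layer: DataLayer) -> str:
    return f"data/{layer.value}"


print(layer_dir(DataLayer.BRONZE))  # data/bronze
```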