phase 14-15: docker build validation and helm deployment
This commit is contained in:
@@ -1 +1,45 @@
|
||||
# Ingestion Adapters
|
||||
from .base import AdapterResult, BaseAdapter
|
||||
from .resilient import ResilientAdapter, RetryConfig, RetryStats, compute_delay
|
||||
from .broker_adapter import (
|
||||
AccountInfo,
|
||||
AlpacaBrokerAdapter,
|
||||
BrokerDataAdapter,
|
||||
OrderEventType,
|
||||
OrderRequest,
|
||||
OrderResponse,
|
||||
OrderSide,
|
||||
OrderStatus,
|
||||
OrderType,
|
||||
PositionInfo,
|
||||
TradingMode,
|
||||
)
|
||||
from .filings_adapter import FilingsDataAdapter, SECEdgarAdapter
|
||||
from .market_adapter import MarketDataAdapter, PolygonMarketAdapter
|
||||
from .news_adapter import NewsDataAdapter, PolygonNewsAdapter
|
||||
|
||||
__all__ = [
|
||||
"AccountInfo",
|
||||
"AdapterResult",
|
||||
"AlpacaBrokerAdapter",
|
||||
"BaseAdapter",
|
||||
"BrokerDataAdapter",
|
||||
"FilingsDataAdapter",
|
||||
"MarketDataAdapter",
|
||||
"NewsDataAdapter",
|
||||
"OrderEventType",
|
||||
"OrderRequest",
|
||||
"OrderResponse",
|
||||
"OrderSide",
|
||||
"OrderStatus",
|
||||
"OrderType",
|
||||
"PolygonMarketAdapter",
|
||||
"PolygonNewsAdapter",
|
||||
"PositionInfo",
|
||||
"ResilientAdapter",
|
||||
"RetryConfig",
|
||||
"RetryStats",
|
||||
"SECEdgarAdapter",
|
||||
"TradingMode",
|
||||
"compute_delay",
|
||||
]
|
||||
|
||||
@@ -1,29 +1,84 @@
|
||||
"""Base adapter interface for all external API integrations."""
|
||||
"""Base adapter interface for all external API integrations.
|
||||
|
||||
All ingestion adapters follow the same contract:
|
||||
1. Fetch external payloads for a given ticker/source config.
|
||||
2. Return a structured result with raw bytes, parsed items, and metadata.
|
||||
3. The ingestion worker handles MinIO upload, PostgreSQL metadata, and downstream job emission.
|
||||
|
||||
Requirements: 2.1, 2.2, 2.3, 2.4, 2.5, 3.1, 3.2, 3.3, 3.4
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdapterResult:
|
||||
"""Result of a single adapter fetch operation."""
|
||||
|
||||
source_type: str
|
||||
ticker: str
|
||||
items: List[Dict[str, Any]]
|
||||
items: list[dict[str, Any]]
|
||||
raw_payload: bytes
|
||||
content_hash: str
|
||||
fetched_at: datetime
|
||||
error: Optional[str] = None
|
||||
error: str | None = None
|
||||
# HTTP metadata for observability
|
||||
http_status: int | None = None
|
||||
response_time_ms: float | None = None
|
||||
# Additional metadata the adapter wants to pass downstream
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def ok(self) -> bool:
|
||||
"""True if the fetch succeeded without error."""
|
||||
return self.error is None and len(self.items) > 0
|
||||
|
||||
@property
|
||||
def item_count(self) -> int:
|
||||
return len(self.items)
|
||||
|
||||
|
||||
class BaseAdapter(ABC):
|
||||
"""Interface for all ingestion adapters."""
|
||||
"""Interface for all ingestion adapters.
|
||||
|
||||
Subclasses implement fetch() for their specific API and source_type()
|
||||
to identify the adapter class. The ingestion worker orchestrates
|
||||
persistence and downstream job emission.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def fetch(self, ticker: str, config: Dict[str, Any]) -> AdapterResult:
|
||||
"""Fetch data for a given ticker using source config."""
|
||||
async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult:
|
||||
"""Fetch data for a given ticker using source config.
|
||||
|
||||
Args:
|
||||
ticker: The company ticker symbol.
|
||||
config: Source-specific configuration from the sources table.
|
||||
|
||||
Returns:
|
||||
AdapterResult with raw payload, parsed items, and metadata.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def source_type(self) -> str:
|
||||
"""Return the source type identifier for this adapter (e.g. 'market_api')."""
|
||||
...
|
||||
|
||||
def bucket_name(self) -> str:
|
||||
"""Return the MinIO bucket name for raw artifact storage.
|
||||
|
||||
Override in subclasses if the bucket differs from the default pattern.
|
||||
"""
|
||||
return f"stonks-raw-{self.source_type().replace('_api', '').replace('_', '-')}"
|
||||
|
||||
def artifact_path(self, ticker: str, document_id: str, now: datetime) -> str:
|
||||
"""Build the MinIO object path for a raw artifact.
|
||||
|
||||
Pattern: /{source_type}/{ticker}/{yyyy}/{mm}/{dd}/{document_id}/raw.json
|
||||
"""
|
||||
return (
|
||||
f"{self.source_type()}/{ticker}/"
|
||||
f"{now.strftime('%Y/%m/%d')}/{document_id}/raw.json"
|
||||
)
|
||||
|
||||
@@ -1,9 +1,19 @@
|
||||
"""Broker API adapter - paper/live trading, orders, positions, balances."""
|
||||
"""Broker API adapter interface for paper trading and order events.
|
||||
|
||||
The BrokerDataAdapter is the abstract interface for all broker integrations.
|
||||
AlpacaBrokerAdapter is the first concrete implementation, targeting the
|
||||
Alpaca Markets REST API for paper and live trading.
|
||||
|
||||
Requirements: 2.4, 2.5, 8.1, 8.3, 8.5
|
||||
"""
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -12,97 +22,584 @@ from .base import AdapterResult, BaseAdapter
|
||||
logger = logging.getLogger("broker_adapter")
|
||||
|
||||
|
||||
class BrokerAdapter(BaseAdapter):
|
||||
"""Broker API adapter supporting paper and live modes."""
|
||||
# --- Broker-specific enums ---
|
||||
|
||||
def __init__(self, api_key: str = "", api_secret: str = "", base_url: str = "", mode: str = "paper"):
|
||||
self.api_key = api_key
|
||||
self.api_secret = api_secret
|
||||
self.base_url = base_url
|
||||
self.mode = mode # paper | live
|
||||
|
||||
class OrderSide(str, Enum):
|
||||
BUY = "buy"
|
||||
SELL = "sell"
|
||||
|
||||
|
||||
class OrderType(str, Enum):
|
||||
MARKET = "market"
|
||||
LIMIT = "limit"
|
||||
STOP = "stop"
|
||||
STOP_LIMIT = "stop_limit"
|
||||
|
||||
|
||||
class OrderStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
SUBMITTED = "submitted"
|
||||
ACCEPTED = "accepted"
|
||||
PARTIALLY_FILLED = "partially_filled"
|
||||
FILLED = "filled"
|
||||
CANCELLED = "cancelled"
|
||||
REJECTED = "rejected"
|
||||
EXPIRED = "expired"
|
||||
|
||||
|
||||
class TradingMode(str, Enum):
|
||||
PAPER = "paper"
|
||||
LIVE = "live"
|
||||
|
||||
|
||||
class OrderEventType(str, Enum):
|
||||
SUBMITTED = "submitted"
|
||||
ACCEPTED = "accepted"
|
||||
REJECTED = "rejected"
|
||||
FILL = "fill"
|
||||
PARTIAL_FILL = "partial_fill"
|
||||
CANCELLED = "cancelled"
|
||||
EXPIRED = "expired"
|
||||
|
||||
|
||||
# --- Data structures ---
|
||||
|
||||
|
||||
class OrderRequest:
|
||||
"""Represents an order to be submitted to a broker."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ticker: str,
|
||||
side: OrderSide,
|
||||
quantity: float,
|
||||
order_type: OrderType = OrderType.MARKET,
|
||||
limit_price: float | None = None,
|
||||
stop_price: float | None = None,
|
||||
time_in_force: str = "day",
|
||||
idempotency_key: str | None = None,
|
||||
) -> None:
|
||||
self.ticker = ticker
|
||||
self.side = side
|
||||
self.quantity = quantity
|
||||
self.order_type = order_type
|
||||
self.limit_price = limit_price
|
||||
self.stop_price = stop_price
|
||||
self.time_in_force = time_in_force
|
||||
self.idempotency_key = idempotency_key or str(uuid.uuid4())
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Serialize to a dict for audit/persistence."""
|
||||
d: dict[str, Any] = {
|
||||
"ticker": self.ticker,
|
||||
"side": self.side.value,
|
||||
"quantity": self.quantity,
|
||||
"order_type": self.order_type.value,
|
||||
"time_in_force": self.time_in_force,
|
||||
"idempotency_key": self.idempotency_key,
|
||||
}
|
||||
if self.limit_price is not None:
|
||||
d["limit_price"] = self.limit_price
|
||||
if self.stop_price is not None:
|
||||
d["stop_price"] = self.stop_price
|
||||
return d
|
||||
|
||||
|
||||
class OrderResponse:
|
||||
"""Represents a broker's response to an order submission."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
broker_order_id: str,
|
||||
status: OrderStatus,
|
||||
ticker: str,
|
||||
side: OrderSide,
|
||||
quantity: float,
|
||||
filled_quantity: float = 0.0,
|
||||
filled_avg_price: float | None = None,
|
||||
submitted_at: datetime | None = None,
|
||||
raw_response: dict[str, Any] | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
self.broker_order_id = broker_order_id
|
||||
self.status = status
|
||||
self.ticker = ticker
|
||||
self.side = side
|
||||
self.quantity = quantity
|
||||
self.filled_quantity = filled_quantity
|
||||
self.filled_avg_price = filled_avg_price
|
||||
self.submitted_at = submitted_at or datetime.now(timezone.utc)
|
||||
self.raw_response = raw_response or {}
|
||||
self.error = error
|
||||
|
||||
@property
|
||||
def ok(self) -> bool:
|
||||
return self.error is None and self.status not in (
|
||||
OrderStatus.REJECTED,
|
||||
OrderStatus.CANCELLED,
|
||||
OrderStatus.EXPIRED,
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"broker_order_id": self.broker_order_id,
|
||||
"status": self.status.value,
|
||||
"ticker": self.ticker,
|
||||
"side": self.side.value,
|
||||
"quantity": self.quantity,
|
||||
"filled_quantity": self.filled_quantity,
|
||||
"filled_avg_price": self.filled_avg_price,
|
||||
"submitted_at": self.submitted_at.isoformat(),
|
||||
"error": self.error,
|
||||
}
|
||||
|
||||
|
||||
class PositionInfo:
|
||||
"""Represents a current position from the broker."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ticker: str,
|
||||
quantity: float,
|
||||
avg_entry_price: float,
|
||||
current_price: float,
|
||||
unrealized_pnl: float,
|
||||
market_value: float,
|
||||
side: str = "long",
|
||||
) -> None:
|
||||
self.ticker = ticker
|
||||
self.quantity = quantity
|
||||
self.avg_entry_price = avg_entry_price
|
||||
self.current_price = current_price
|
||||
self.unrealized_pnl = unrealized_pnl
|
||||
self.market_value = market_value
|
||||
self.side = side
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"ticker": self.ticker,
|
||||
"quantity": self.quantity,
|
||||
"avg_entry_price": self.avg_entry_price,
|
||||
"current_price": self.current_price,
|
||||
"unrealized_pnl": self.unrealized_pnl,
|
||||
"market_value": self.market_value,
|
||||
"side": self.side,
|
||||
}
|
||||
|
||||
|
||||
class AccountInfo:
|
||||
"""Represents broker account summary."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
account_id: str,
|
||||
buying_power: float,
|
||||
cash: float,
|
||||
portfolio_value: float,
|
||||
currency: str = "USD",
|
||||
mode: TradingMode = TradingMode.PAPER,
|
||||
) -> None:
|
||||
self.account_id = account_id
|
||||
self.buying_power = buying_power
|
||||
self.cash = cash
|
||||
self.portfolio_value = portfolio_value
|
||||
self.currency = currency
|
||||
self.mode = mode
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"account_id": self.account_id,
|
||||
"buying_power": self.buying_power,
|
||||
"cash": self.cash,
|
||||
"portfolio_value": self.portfolio_value,
|
||||
"currency": self.currency,
|
||||
"mode": self.mode.value,
|
||||
}
|
||||
|
||||
|
||||
# --- Abstract interface ---
|
||||
|
||||
|
||||
class BrokerDataAdapter(BaseAdapter, ABC):
|
||||
"""Abstract interface for broker API integrations.
|
||||
|
||||
Extends BaseAdapter with broker-specific operations:
|
||||
- submit_order: place an order with idempotency key
|
||||
- cancel_order: cancel an existing order
|
||||
- get_order_status: check order state
|
||||
- get_positions: list current positions
|
||||
- get_account: retrieve account summary
|
||||
|
||||
All concrete adapters must enforce:
|
||||
- Idempotent order submission via idempotency_key (Req 8.5)
|
||||
- Paper/live mode separation (Req 8.1)
|
||||
- Fail-closed on broker unavailability (Req 8.5)
|
||||
"""
|
||||
|
||||
def __init__(self, mode: TradingMode = TradingMode.PAPER) -> None:
|
||||
self._mode = mode
|
||||
|
||||
@property
|
||||
def mode(self) -> TradingMode:
|
||||
return self._mode
|
||||
|
||||
def source_type(self) -> str:
|
||||
return "broker"
|
||||
|
||||
def _headers(self) -> Dict[str, str]:
|
||||
@abstractmethod
|
||||
async def submit_order(self, order: OrderRequest) -> OrderResponse:
|
||||
"""Submit an order to the broker.
|
||||
|
||||
Must use order.idempotency_key to prevent duplicate submissions.
|
||||
Must fail closed if the broker is unavailable or returns ambiguous state.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def cancel_order(self, broker_order_id: str) -> OrderResponse:
|
||||
"""Cancel an existing order by broker order ID."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_order_status(self, broker_order_id: str) -> OrderResponse:
|
||||
"""Get the current status of an order."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_positions(self) -> list[PositionInfo]:
|
||||
"""Get all current positions."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_account(self) -> AccountInfo:
|
||||
"""Get account summary (balance, buying power, etc.)."""
|
||||
...
|
||||
|
||||
|
||||
# --- Concrete Alpaca implementation ---
|
||||
|
||||
|
||||
class AlpacaBrokerAdapter(BrokerDataAdapter):
|
||||
"""Concrete broker adapter for the Alpaca Markets REST API.
|
||||
|
||||
Supports:
|
||||
- Paper trading via paper-api.alpaca.markets
|
||||
- Live trading via api.alpaca.markets
|
||||
- Order submission, cancellation, and status
|
||||
- Position and account queries
|
||||
|
||||
Config options for fetch():
|
||||
endpoint: One of "positions", "orders", "account" (default "positions")
|
||||
"""
|
||||
|
||||
PAPER_BASE_URL: str = "https://paper-api.alpaca.markets"
|
||||
LIVE_BASE_URL: str = "https://api.alpaca.markets"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
api_secret: str,
|
||||
mode: TradingMode = TradingMode.PAPER,
|
||||
base_url: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(mode=mode)
|
||||
self.api_key = api_key
|
||||
self.api_secret = api_secret
|
||||
if base_url:
|
||||
self.base_url = base_url.rstrip("/")
|
||||
elif mode == TradingMode.LIVE:
|
||||
self.base_url = self.LIVE_BASE_URL
|
||||
else:
|
||||
self.base_url = self.PAPER_BASE_URL
|
||||
|
||||
def _headers(self) -> dict[str, str]:
|
||||
return {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"APCA-API-KEY-ID": self.api_key,
|
||||
"APCA-API-SECRET-KEY": self.api_secret,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async def fetch(self, ticker: str, config: Dict[str, Any]) -> AdapterResult:
|
||||
"""Fetch positions and recent orders for a ticker."""
|
||||
async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult:
|
||||
"""Fetch positions or recent orders for a ticker from Alpaca.
|
||||
|
||||
This satisfies the BaseAdapter contract for the ingestion pipeline.
|
||||
The broker adapter uses fetch() to pull position/order snapshots
|
||||
that get persisted as raw artifacts.
|
||||
"""
|
||||
endpoint = config.get("endpoint", "positions")
|
||||
url = self._build_fetch_url(ticker, endpoint)
|
||||
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/v2/positions/{ticker}",
|
||||
headers=self._headers(),
|
||||
)
|
||||
resp = await client.get(url, headers=self._headers())
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
resp.raise_for_status()
|
||||
|
||||
raw = resp.content
|
||||
data = resp.json() if resp.status_code == 200 else {}
|
||||
data = resp.json()
|
||||
content_hash = hashlib.sha256(raw).hexdigest()
|
||||
items = [data] if isinstance(data, dict) else data if isinstance(data, list) else []
|
||||
|
||||
return AdapterResult(
|
||||
source_type="broker",
|
||||
ticker=ticker,
|
||||
items=[data] if data else [],
|
||||
items=items,
|
||||
raw_payload=raw,
|
||||
content_hash=content_hash,
|
||||
fetched_at=datetime.utcnow(),
|
||||
fetched_at=datetime.now(timezone.utc),
|
||||
http_status=resp.status_code,
|
||||
response_time_ms=round(elapsed_ms, 1),
|
||||
metadata={
|
||||
"provider": "alpaca",
|
||||
"mode": self._mode.value,
|
||||
"endpoint": endpoint,
|
||||
},
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
logger.error("Alpaca HTTP error for %s: %s", ticker, e)
|
||||
return self._error_result(
|
||||
ticker, str(e), elapsed_ms,
|
||||
http_status=e.response.status_code if e.response else None,
|
||||
raw=e.response.content if e.response else b"",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Broker fetch failed for {ticker}: {e}")
|
||||
return AdapterResult(
|
||||
source_type="broker",
|
||||
ticker=ticker,
|
||||
items=[],
|
||||
raw_payload=b"",
|
||||
content_hash="",
|
||||
fetched_at=datetime.utcnow(),
|
||||
error=str(e),
|
||||
)
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
logger.error("Alpaca fetch failed for %s: %s", ticker, e)
|
||||
return self._error_result(ticker, str(e), elapsed_ms)
|
||||
|
||||
async def submit_order(
|
||||
self,
|
||||
ticker: str,
|
||||
side: str,
|
||||
qty: float,
|
||||
order_type: str = "market",
|
||||
limit_price: Optional[float] = None,
|
||||
idempotency_key: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Submit an order to the broker. Returns broker response."""
|
||||
if self.mode == "live":
|
||||
logger.warning("LIVE order submission")
|
||||
def _build_fetch_url(self, ticker: str, endpoint: str) -> str:
|
||||
"""Build the URL for a fetch operation."""
|
||||
if endpoint == "orders":
|
||||
return f"{self.base_url}/v2/orders?symbols={ticker}&status=all&limit=50"
|
||||
if endpoint == "account":
|
||||
return f"{self.base_url}/v2/account"
|
||||
# Default: positions for ticker
|
||||
return f"{self.base_url}/v2/positions/{ticker}"
|
||||
|
||||
idem_key = idempotency_key or str(uuid.uuid4())
|
||||
payload = {
|
||||
"symbol": ticker,
|
||||
"qty": str(qty),
|
||||
"side": side,
|
||||
"type": order_type,
|
||||
"time_in_force": "day",
|
||||
async def submit_order(self, order: OrderRequest) -> OrderResponse:
|
||||
"""Submit an order to Alpaca with idempotency key.
|
||||
|
||||
Fails closed: any network error or ambiguous response returns
|
||||
a rejected OrderResponse rather than risking duplicate orders.
|
||||
"""
|
||||
if self._mode == TradingMode.LIVE:
|
||||
logger.warning("LIVE order submission: %s %s %s", order.side.value, order.quantity, order.ticker)
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"symbol": order.ticker,
|
||||
"qty": str(order.quantity),
|
||||
"side": order.side.value,
|
||||
"type": order.order_type.value,
|
||||
"time_in_force": order.time_in_force,
|
||||
}
|
||||
if limit_price and order_type == "limit":
|
||||
payload["limit_price"] = str(limit_price)
|
||||
if order.limit_price is not None and order.order_type in (OrderType.LIMIT, OrderType.STOP_LIMIT):
|
||||
payload["limit_price"] = str(order.limit_price)
|
||||
if order.stop_price is not None and order.order_type in (OrderType.STOP, OrderType.STOP_LIMIT):
|
||||
payload["stop_price"] = str(order.stop_price)
|
||||
|
||||
headers = {**self._headers(), "Idempotency-Key": order.idempotency_key}
|
||||
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
try:
|
||||
resp = await client.post(
|
||||
f"{self.base_url}/v2/orders",
|
||||
headers={**self._headers(), "Idempotency-Key": idem_key},
|
||||
headers=headers,
|
||||
json=payload,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
data = resp.json()
|
||||
return self._parse_order_response(data)
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"Order rejected: {e.response.text}")
|
||||
return {"error": e.response.text, "status": e.response.status_code}
|
||||
error_body = e.response.text if e.response else "unknown"
|
||||
logger.error("Order rejected by Alpaca: %s", error_body)
|
||||
return OrderResponse(
|
||||
broker_order_id="",
|
||||
status=OrderStatus.REJECTED,
|
||||
ticker=order.ticker,
|
||||
side=order.side,
|
||||
quantity=order.quantity,
|
||||
error=f"HTTP {e.response.status_code}: {error_body}" if e.response else str(e),
|
||||
raw_response={"error": error_body},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Order submission failed: {e}")
|
||||
return {"error": str(e)}
|
||||
# Fail closed: treat any unexpected error as rejection
|
||||
logger.error("Order submission failed (fail-closed): %s", e)
|
||||
return OrderResponse(
|
||||
broker_order_id="",
|
||||
status=OrderStatus.REJECTED,
|
||||
ticker=order.ticker,
|
||||
side=order.side,
|
||||
quantity=order.quantity,
|
||||
error=f"fail-closed: {e}",
|
||||
)
|
||||
|
||||
async def get_account(self) -> Dict[str, Any]:
|
||||
async def cancel_order(self, broker_order_id: str) -> OrderResponse:
|
||||
"""Cancel an order on Alpaca."""
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
resp = await client.get(f"{self.base_url}/v2/account", headers=self._headers())
|
||||
return resp.json()
|
||||
try:
|
||||
resp = await client.delete(
|
||||
f"{self.base_url}/v2/orders/{broker_order_id}",
|
||||
headers=self._headers(),
|
||||
)
|
||||
if resp.status_code == 204:
|
||||
return OrderResponse(
|
||||
broker_order_id=broker_order_id,
|
||||
status=OrderStatus.CANCELLED,
|
||||
ticker="",
|
||||
side=OrderSide.BUY,
|
||||
quantity=0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return self._parse_order_response(data)
|
||||
except Exception as e:
|
||||
logger.error("Cancel failed for %s: %s", broker_order_id, e)
|
||||
return OrderResponse(
|
||||
broker_order_id=broker_order_id,
|
||||
status=OrderStatus.REJECTED,
|
||||
ticker="",
|
||||
side=OrderSide.BUY,
|
||||
quantity=0,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
async def get_order_status(self, broker_order_id: str) -> OrderResponse:
|
||||
"""Get order status from Alpaca."""
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
try:
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/v2/orders/{broker_order_id}",
|
||||
headers=self._headers(),
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return self._parse_order_response(data)
|
||||
except Exception as e:
|
||||
logger.error("Get order status failed for %s: %s", broker_order_id, e)
|
||||
return OrderResponse(
|
||||
broker_order_id=broker_order_id,
|
||||
status=OrderStatus.REJECTED,
|
||||
ticker="",
|
||||
side=OrderSide.BUY,
|
||||
quantity=0,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
async def get_positions(self) -> list[PositionInfo]:
|
||||
"""Get all current positions from Alpaca."""
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
try:
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/v2/positions",
|
||||
headers=self._headers(),
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
if not isinstance(data, list):
|
||||
return []
|
||||
return [self._parse_position(p) for p in data if isinstance(p, dict)]
|
||||
except Exception as e:
|
||||
logger.error("Get positions failed: %s", e)
|
||||
return []
|
||||
|
||||
async def get_account(self) -> AccountInfo:
|
||||
"""Get account summary from Alpaca."""
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
try:
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/v2/account",
|
||||
headers=self._headers(),
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return AccountInfo(
|
||||
account_id=str(data.get("id", "")),
|
||||
buying_power=float(data.get("buying_power", 0)),
|
||||
cash=float(data.get("cash", 0)),
|
||||
portfolio_value=float(data.get("portfolio_value", 0)),
|
||||
currency=str(data.get("currency", "USD")),
|
||||
mode=self._mode,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Get account failed: %s", e)
|
||||
return AccountInfo(
|
||||
account_id="",
|
||||
buying_power=0,
|
||||
cash=0,
|
||||
portfolio_value=0,
|
||||
mode=self._mode,
|
||||
)
|
||||
|
||||
def _parse_order_response(self, data: dict[str, Any]) -> OrderResponse:
|
||||
"""Parse an Alpaca order response into an OrderResponse."""
|
||||
status_map: dict[str, OrderStatus] = {
|
||||
"new": OrderStatus.SUBMITTED,
|
||||
"accepted": OrderStatus.ACCEPTED,
|
||||
"partially_filled": OrderStatus.PARTIALLY_FILLED,
|
||||
"filled": OrderStatus.FILLED,
|
||||
"done_for_day": OrderStatus.FILLED,
|
||||
"canceled": OrderStatus.CANCELLED,
|
||||
"expired": OrderStatus.EXPIRED,
|
||||
"replaced": OrderStatus.SUBMITTED,
|
||||
"pending_new": OrderStatus.PENDING,
|
||||
"pending_cancel": OrderStatus.PENDING,
|
||||
"pending_replace": OrderStatus.PENDING,
|
||||
"rejected": OrderStatus.REJECTED,
|
||||
}
|
||||
raw_status = str(data.get("status", "pending"))
|
||||
status = status_map.get(raw_status, OrderStatus.PENDING)
|
||||
|
||||
side_str = str(data.get("side", "buy"))
|
||||
side = OrderSide.SELL if side_str == "sell" else OrderSide.BUY
|
||||
|
||||
filled_qty = float(data.get("filled_qty", 0) or 0)
|
||||
filled_avg = data.get("filled_avg_price")
|
||||
filled_avg_price = float(filled_avg) if filled_avg else None
|
||||
|
||||
return OrderResponse(
|
||||
broker_order_id=str(data.get("id", "")),
|
||||
status=status,
|
||||
ticker=str(data.get("symbol", "")),
|
||||
side=side,
|
||||
quantity=float(data.get("qty", 0) or 0),
|
||||
filled_quantity=filled_qty,
|
||||
filled_avg_price=filled_avg_price,
|
||||
raw_response=data,
|
||||
)
|
||||
|
||||
def _parse_position(self, data: dict[str, Any]) -> PositionInfo:
|
||||
"""Parse an Alpaca position response into a PositionInfo."""
|
||||
return PositionInfo(
|
||||
ticker=str(data.get("symbol", "")),
|
||||
quantity=float(data.get("qty", 0) or 0),
|
||||
avg_entry_price=float(data.get("avg_entry_price", 0) or 0),
|
||||
current_price=float(data.get("current_price", 0) or 0),
|
||||
unrealized_pnl=float(data.get("unrealized_pl", 0) or 0),
|
||||
market_value=float(data.get("market_value", 0) or 0),
|
||||
side=str(data.get("side", "long")),
|
||||
)
|
||||
|
||||
def _error_result(
|
||||
self,
|
||||
ticker: str,
|
||||
error: str,
|
||||
elapsed_ms: float,
|
||||
http_status: int | None = None,
|
||||
raw: bytes = b"",
|
||||
) -> AdapterResult:
|
||||
"""Build an error AdapterResult for broker fetches."""
|
||||
return AdapterResult(
|
||||
source_type="broker",
|
||||
ticker=ticker,
|
||||
items=[],
|
||||
raw_payload=raw,
|
||||
content_hash="",
|
||||
fetched_at=datetime.now(timezone.utc),
|
||||
error=error,
|
||||
http_status=http_status,
|
||||
response_time_ms=round(elapsed_ms, 1),
|
||||
metadata={"provider": "alpaca", "mode": self._mode.value},
|
||||
)
|
||||
|
||||
@@ -0,0 +1,832 @@
|
||||
"""Broker adapter service - standalone worker for sandbox order execution.
|
||||
|
||||
Runs the Alpaca broker adapter in sandbox (paper) mode, processing order
|
||||
requests from the broker queue, evaluating them through the risk engine,
|
||||
submitting to Alpaca's paper trading API, and persisting the full audit trail.
|
||||
|
||||
Also periodically syncs positions and account state from Alpaca.
|
||||
|
||||
Implements idempotent order submission keys and duplicate prevention:
|
||||
- Deterministic idempotency key generation from job attributes
|
||||
- Redis-based fast-path duplicate detection before broker submission
|
||||
- PostgreSQL UNIQUE constraint on idempotency_key as durable fallback
|
||||
|
||||
Requirements: 2.4, 8.1, 8.3, 8.5
|
||||
Design: Section 4.9 - Broker Adapter
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import asyncpg
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from services.adapters.broker_adapter import (
|
||||
AlpacaBrokerAdapter,
|
||||
OrderRequest,
|
||||
OrderResponse,
|
||||
OrderSide,
|
||||
OrderStatus,
|
||||
OrderType,
|
||||
TradingMode,
|
||||
)
|
||||
from services.risk.engine import (
|
||||
AccountRiskState,
|
||||
PortfolioRiskConfig,
|
||||
ProposedOrder,
|
||||
evaluate_order,
|
||||
)
|
||||
from services.risk.approval import (
|
||||
ApprovalRequest,
|
||||
ApprovalStatus,
|
||||
compute_expiry,
|
||||
create_approval_request,
|
||||
requires_approval,
|
||||
)
|
||||
from services.shared.audit import (
|
||||
audit_approval_requested,
|
||||
audit_duplicate_prevented,
|
||||
audit_order_filled,
|
||||
audit_order_rejected,
|
||||
audit_order_submitted,
|
||||
audit_risk_evaluated,
|
||||
)
|
||||
from services.lake_publisher.worker import (
|
||||
publish_trade_order,
|
||||
publish_trade_fill,
|
||||
publish_positions_daily_batch,
|
||||
LAKEHOUSE_BUCKET,
|
||||
)
|
||||
from services.shared.config import load_config
|
||||
from services.shared.db import get_pg_pool, get_redis
|
||||
from services.shared.logging import Span, new_trace_id, set_trace_context, setup_logging
|
||||
from services.shared.metrics import (
|
||||
ORDERS_DUPLICATES_PREVENTED,
|
||||
ORDERS_FILLED,
|
||||
ORDERS_REJECTED,
|
||||
ORDERS_SUBMITTED,
|
||||
POSITIONS_SYNCED,
|
||||
RISK_CHECK_FAILURES,
|
||||
RISK_EVALUATIONS_TOTAL,
|
||||
)
|
||||
from services.shared.redis_keys import QUEUE_BROKER, queue_key
|
||||
|
||||
logger = logging.getLogger("broker_service")
|
||||
|
||||
POSITION_SYNC_INTERVAL = 60 # seconds
|
||||
|
||||
# Redis TTL for idempotency markers (24 hours)
|
||||
ORDER_IDEMPOTENCY_TTL = 86400
|
||||
ORDER_IDEMPOTENCY_PREFIX = "stonks:order_idempotency"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DB persistence helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_UPSERT_BROKER_ACCOUNT = """
|
||||
INSERT INTO broker_accounts (id, provider, account_id, mode, config, active)
|
||||
VALUES ($1::uuid, $2, $3, $4, $5::jsonb, TRUE)
|
||||
ON CONFLICT (id) DO UPDATE SET
|
||||
config = EXCLUDED.config,
|
||||
mode = EXCLUDED.mode,
|
||||
active = TRUE
|
||||
"""
|
||||
|
||||
_INSERT_ORDER = """
|
||||
INSERT INTO orders (
|
||||
id, recommendation_id, broker_account_id, ticker, side, order_type,
|
||||
quantity, limit_price, stop_price, status, idempotency_key,
|
||||
broker_order_id, decision_trace, submitted_at, filled_at,
|
||||
fill_price, fill_quantity
|
||||
) VALUES (
|
||||
$1::uuid, $2, $3::uuid, $4, $5, $6,
|
||||
$7, $8, $9, $10, $11,
|
||||
$12, $13::jsonb, $14, $15,
|
||||
$16, $17
|
||||
)
|
||||
ON CONFLICT (idempotency_key) DO UPDATE SET
|
||||
status = EXCLUDED.status,
|
||||
broker_order_id = EXCLUDED.broker_order_id,
|
||||
filled_at = EXCLUDED.filled_at,
|
||||
fill_price = EXCLUDED.fill_price,
|
||||
fill_quantity = EXCLUDED.fill_quantity,
|
||||
updated_at = NOW()
|
||||
"""
|
||||
|
||||
_INSERT_ORDER_EVENT = """
|
||||
INSERT INTO order_events (order_id, event_type, data, broker_timestamp)
|
||||
VALUES ($1::uuid, $2, $3::jsonb, $4)
|
||||
"""
|
||||
|
||||
_INSERT_RISK_EVALUATION = """
|
||||
INSERT INTO risk_evaluations (id, recommendation_id, eligible, allowed_mode, rejection_reasons, risk_checks, evaluated_at)
|
||||
VALUES ($1::uuid, $2::uuid, $3, $4, $5::jsonb, $6::jsonb, $7)
|
||||
"""
|
||||
|
||||
_UPSERT_POSITION = """
|
||||
INSERT INTO positions (broker_account_id, ticker, quantity, avg_entry_price, current_price, unrealized_pnl, updated_at)
|
||||
VALUES ($1::uuid, $2, $3, $4, $5, $6, $7)
|
||||
ON CONFLICT (broker_account_id, ticker)
|
||||
DO UPDATE SET
|
||||
quantity = EXCLUDED.quantity,
|
||||
avg_entry_price = EXCLUDED.avg_entry_price,
|
||||
current_price = EXCLUDED.current_price,
|
||||
unrealized_pnl = EXCLUDED.unrealized_pnl,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
"""
|
||||
|
||||
_LOAD_RISK_CONFIG = """
|
||||
SELECT config FROM risk_configs WHERE active = TRUE ORDER BY updated_at DESC LIMIT 1
|
||||
"""
|
||||
|
||||
_LOAD_DAILY_SNAPSHOT = """
|
||||
SELECT portfolio_value, daily_pnl, daily_trade_count, positions_by_sector
|
||||
FROM daily_risk_snapshots
|
||||
WHERE account_id = $1 AND snapshot_date = CURRENT_DATE
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
_CHECK_ORDER_BY_IDEMPOTENCY_KEY = """
|
||||
SELECT id, status, broker_order_id FROM orders
|
||||
WHERE idempotency_key = $1
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Idempotency helpers (Requirement 8.5)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def generate_idempotency_key(job: dict[str, Any]) -> str:
|
||||
"""Generate a deterministic idempotency key from job attributes.
|
||||
|
||||
If the job already carries an explicit idempotency_key, use it.
|
||||
Otherwise, derive a stable key from the combination of
|
||||
recommendation_id, ticker, side, quantity, and order_type so that
|
||||
replayed queue messages produce the same key and are detected as
|
||||
duplicates.
|
||||
"""
|
||||
explicit = job.get("idempotency_key")
|
||||
if explicit:
|
||||
return str(explicit)
|
||||
|
||||
# Build a deterministic key from job content
|
||||
parts = [
|
||||
str(job.get("recommendation_id", "")),
|
||||
str(job.get("ticker", "")),
|
||||
str(job.get("side", "buy")),
|
||||
str(job.get("quantity", 0)),
|
||||
str(job.get("order_type", "market")),
|
||||
str(job.get("limit_price", "")),
|
||||
str(job.get("stop_price", "")),
|
||||
]
|
||||
raw = "|".join(parts)
|
||||
return hashlib.sha256(raw.encode()).hexdigest()[:40]
|
||||
|
||||
|
||||
def _redis_idempotency_key(idempotency_key: str) -> str:
|
||||
"""Build the Redis key for an order idempotency marker."""
|
||||
return f"{ORDER_IDEMPOTENCY_PREFIX}:{idempotency_key}"
|
||||
|
||||
|
||||
async def check_idempotency_redis(
|
||||
rds: aioredis.Redis,
|
||||
idempotency_key: str,
|
||||
) -> str | None:
|
||||
"""Fast-path: check Redis for a previously processed idempotency key.
|
||||
|
||||
Returns the existing order_id if found, None otherwise.
|
||||
"""
|
||||
redis_key = _redis_idempotency_key(idempotency_key)
|
||||
cached = await rds.get(redis_key)
|
||||
if cached:
|
||||
return str(cached)
|
||||
return None
|
||||
|
||||
|
||||
async def check_idempotency_db(
|
||||
pool: asyncpg.Pool,
|
||||
idempotency_key: str,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Durable fallback: check PostgreSQL for an existing order with this key.
|
||||
|
||||
Returns a dict with id, status, broker_order_id if found, None otherwise.
|
||||
"""
|
||||
row = await pool.fetchrow(_CHECK_ORDER_BY_IDEMPOTENCY_KEY, idempotency_key)
|
||||
if row:
|
||||
return {
|
||||
"id": str(row["id"]),
|
||||
"status": str(row["status"]),
|
||||
"broker_order_id": str(row["broker_order_id"] or ""),
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
async def mark_idempotency_redis(
|
||||
rds: aioredis.Redis,
|
||||
idempotency_key: str,
|
||||
order_id: str,
|
||||
) -> None:
|
||||
"""Set the Redis idempotency marker after an order is processed."""
|
||||
redis_key = _redis_idempotency_key(idempotency_key)
|
||||
await rds.set(redis_key, order_id, ex=ORDER_IDEMPOTENCY_TTL)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core service logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def build_order_request(job: dict[str, Any]) -> OrderRequest:
|
||||
"""Build an OrderRequest from a broker queue job payload."""
|
||||
side = OrderSide.SELL if job.get("side", "buy") == "sell" else OrderSide.BUY
|
||||
order_type_str = job.get("order_type", "market")
|
||||
order_type_map = {
|
||||
"market": OrderType.MARKET,
|
||||
"limit": OrderType.LIMIT,
|
||||
"stop": OrderType.STOP,
|
||||
"stop_limit": OrderType.STOP_LIMIT,
|
||||
}
|
||||
return OrderRequest(
|
||||
ticker=job["ticker"],
|
||||
side=side,
|
||||
quantity=float(job.get("quantity", 0)),
|
||||
order_type=order_type_map.get(order_type_str, OrderType.MARKET),
|
||||
limit_price=job.get("limit_price"),
|
||||
stop_price=job.get("stop_price"),
|
||||
time_in_force=job.get("time_in_force", "day"),
|
||||
idempotency_key=generate_idempotency_key(job),
|
||||
)
|
||||
|
||||
|
||||
def build_proposed_order(job: dict[str, Any]) -> ProposedOrder:
|
||||
"""Build a ProposedOrder for risk evaluation from a broker queue job."""
|
||||
return ProposedOrder(
|
||||
recommendation_id=job.get("recommendation_id"),
|
||||
ticker=job["ticker"],
|
||||
sector=job.get("sector", ""),
|
||||
action=job.get("side", "buy"),
|
||||
quantity=float(job.get("quantity", 0)),
|
||||
estimated_value=float(job.get("estimated_value", 0)),
|
||||
confidence=float(job.get("confidence", 0)),
|
||||
)
|
||||
|
||||
|
||||
async def load_risk_config(pool: asyncpg.Pool) -> PortfolioRiskConfig:
|
||||
"""Load the active risk configuration from the database."""
|
||||
row = await pool.fetchrow(_LOAD_RISK_CONFIG)
|
||||
if row and row["config"]:
|
||||
data = row["config"] if isinstance(row["config"], dict) else json.loads(row["config"])
|
||||
return PortfolioRiskConfig.from_db_json(data)
|
||||
return PortfolioRiskConfig()
|
||||
|
||||
|
||||
async def load_account_risk_state(
|
||||
pool: asyncpg.Pool,
|
||||
adapter: AlpacaBrokerAdapter,
|
||||
account_uuid: str,
|
||||
) -> AccountRiskState:
|
||||
"""Build an AccountRiskState from the broker and daily snapshot."""
|
||||
state = AccountRiskState(account_id=account_uuid)
|
||||
|
||||
# Get live account info from Alpaca
|
||||
try:
|
||||
acct = await adapter.get_account()
|
||||
state.portfolio_value = acct.portfolio_value
|
||||
state.cash = acct.cash
|
||||
state.buying_power = acct.buying_power
|
||||
except Exception as e:
|
||||
logger.warning("Failed to fetch account from Alpaca: %s", e)
|
||||
|
||||
# Get positions from Alpaca
|
||||
try:
|
||||
positions = await adapter.get_positions()
|
||||
for pos in positions:
|
||||
state.positions_by_symbol[pos.ticker] = pos.market_value
|
||||
state.open_position_count = len(positions)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to fetch positions from Alpaca: %s", e)
|
||||
|
||||
# Overlay daily snapshot from DB
|
||||
row = await pool.fetchrow(_LOAD_DAILY_SNAPSHOT, account_uuid)
|
||||
if row:
|
||||
state.daily_pnl = float(row["daily_pnl"] or 0)
|
||||
state.daily_trade_count = int(row["daily_trade_count"] or 0)
|
||||
sector_data = row["positions_by_sector"]
|
||||
if sector_data:
|
||||
state.positions_by_sector = (
|
||||
sector_data if isinstance(sector_data, dict) else json.loads(sector_data)
|
||||
)
|
||||
|
||||
return state
|
||||
|
||||
|
||||
async def persist_order(
|
||||
pool: asyncpg.Pool,
|
||||
order_id: str,
|
||||
order: OrderRequest,
|
||||
resp: OrderResponse,
|
||||
account_uuid: str,
|
||||
risk_eval: dict[str, Any],
|
||||
recommendation_id: str | None = None,
|
||||
) -> None:
|
||||
"""Persist order, events, and risk evaluation to PostgreSQL."""
|
||||
now = datetime.now(timezone.utc)
|
||||
filled_at = now if resp.status == OrderStatus.FILLED else None
|
||||
|
||||
decision_trace = {
|
||||
"risk_evaluation": risk_eval,
|
||||
"order_request": order.to_dict(),
|
||||
"broker_response": resp.to_dict(),
|
||||
}
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
async with conn.transaction():
|
||||
await conn.execute(
|
||||
_INSERT_ORDER,
|
||||
order_id,
|
||||
recommendation_id,
|
||||
account_uuid,
|
||||
order.ticker,
|
||||
order.side.value,
|
||||
order.order_type.value,
|
||||
order.quantity,
|
||||
order.limit_price,
|
||||
order.stop_price,
|
||||
resp.status.value,
|
||||
order.idempotency_key,
|
||||
resp.broker_order_id,
|
||||
json.dumps(decision_trace),
|
||||
resp.submitted_at or now,
|
||||
filled_at,
|
||||
resp.filled_avg_price,
|
||||
resp.filled_quantity,
|
||||
)
|
||||
|
||||
# Record order events
|
||||
for event_type in ["submitted"]:
|
||||
await conn.execute(
|
||||
_INSERT_ORDER_EVENT,
|
||||
order_id,
|
||||
event_type,
|
||||
json.dumps({"ticker": order.ticker, "side": order.side.value}),
|
||||
now,
|
||||
)
|
||||
|
||||
if resp.status == OrderStatus.FILLED:
|
||||
await conn.execute(
|
||||
_INSERT_ORDER_EVENT,
|
||||
order_id,
|
||||
"fill",
|
||||
json.dumps({
|
||||
"fill_price": resp.filled_avg_price,
|
||||
"fill_qty": resp.filled_quantity,
|
||||
}),
|
||||
now,
|
||||
)
|
||||
elif resp.status == OrderStatus.REJECTED:
|
||||
await conn.execute(
|
||||
_INSERT_ORDER_EVENT,
|
||||
order_id,
|
||||
"rejected",
|
||||
json.dumps({"error": resp.error}),
|
||||
now,
|
||||
)
|
||||
|
||||
|
||||
async def sync_positions(
|
||||
adapter: AlpacaBrokerAdapter,
|
||||
pool: asyncpg.Pool,
|
||||
account_uuid: str,
|
||||
minio_client: Any | None = None,
|
||||
) -> None:
|
||||
"""Sync current positions from Alpaca to PostgreSQL and publish to lake."""
|
||||
now = datetime.now(timezone.utc)
|
||||
try:
|
||||
positions = await adapter.get_positions()
|
||||
async with pool.acquire() as conn:
|
||||
for pos in positions:
|
||||
await conn.execute(
|
||||
_UPSERT_POSITION,
|
||||
account_uuid,
|
||||
pos.ticker,
|
||||
pos.quantity,
|
||||
pos.avg_entry_price,
|
||||
pos.current_price,
|
||||
pos.unrealized_pnl,
|
||||
now,
|
||||
)
|
||||
logger.info("Synced %d positions from Alpaca", len(positions))
|
||||
POSITIONS_SYNCED.inc()
|
||||
|
||||
# Publish positions snapshot to analytical lake
|
||||
if minio_client is not None and positions:
|
||||
try:
|
||||
pos_dicts = [
|
||||
{
|
||||
"ticker": p.ticker,
|
||||
"quantity": p.quantity,
|
||||
"avg_entry_price": p.avg_entry_price,
|
||||
"close_price": p.current_price,
|
||||
"unrealized_pnl": p.unrealized_pnl,
|
||||
}
|
||||
for p in positions
|
||||
]
|
||||
publish_positions_daily_batch(
|
||||
minio_client, pos_dicts, account_uuid, now,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish positions to lake: %s", e)
|
||||
except Exception as e:
|
||||
logger.error("Position sync failed: %s", e)
|
||||
|
||||
|
||||
async def register_broker_account(
|
||||
pool: asyncpg.Pool,
|
||||
account_uuid: str,
|
||||
adapter: AlpacaBrokerAdapter,
|
||||
) -> None:
|
||||
"""Register or update the broker account in PostgreSQL."""
|
||||
try:
|
||||
acct = await adapter.get_account()
|
||||
config_json = json.dumps({
|
||||
"provider": "alpaca",
|
||||
"buying_power": acct.buying_power,
|
||||
"cash": acct.cash,
|
||||
"portfolio_value": acct.portfolio_value,
|
||||
})
|
||||
await pool.execute(
|
||||
_UPSERT_BROKER_ACCOUNT,
|
||||
account_uuid,
|
||||
"alpaca",
|
||||
acct.account_id or account_uuid,
|
||||
adapter.mode.value,
|
||||
config_json,
|
||||
)
|
||||
logger.info(
|
||||
"Registered Alpaca account: id=%s mode=%s portfolio=%.2f",
|
||||
acct.account_id, adapter.mode.value, acct.portfolio_value,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to register broker account: %s", e)
|
||||
|
||||
|
||||
async def process_order_job(
|
||||
job: dict[str, Any],
|
||||
adapter: AlpacaBrokerAdapter,
|
||||
pool: asyncpg.Pool,
|
||||
account_uuid: str,
|
||||
rds: aioredis.Redis | None = None,
|
||||
minio_client: Any | None = None,
|
||||
) -> None:
|
||||
"""Process a single order job from the broker queue.
|
||||
|
||||
1. Generate deterministic idempotency key
|
||||
2. Check Redis + DB for duplicate (Req 8.5)
|
||||
3. Build proposed order and run risk evaluation
|
||||
4. If risk passes, submit to Alpaca
|
||||
5. Persist order, events, and risk evaluation
|
||||
6. Set Redis idempotency marker
|
||||
"""
|
||||
ticker = job.get("ticker", "???")
|
||||
order_id = str(uuid.uuid4())
|
||||
idempotency_key = generate_idempotency_key(job)
|
||||
|
||||
# --- Duplicate prevention (Requirement 8.5) ---
|
||||
# Fast path: Redis check
|
||||
if rds is not None:
|
||||
existing_order_id = await check_idempotency_redis(rds, idempotency_key)
|
||||
if existing_order_id:
|
||||
logger.info(
|
||||
"Duplicate order detected (redis) for %s key=%s existing=%s",
|
||||
ticker, idempotency_key[:16], existing_order_id,
|
||||
)
|
||||
ORDERS_DUPLICATES_PREVENTED.labels(detected_via="redis").inc()
|
||||
await audit_duplicate_prevented(
|
||||
pool, existing_order_id, ticker, idempotency_key, detected_via="redis",
|
||||
)
|
||||
return
|
||||
|
||||
# Durable fallback: DB check
|
||||
existing = await check_idempotency_db(pool, idempotency_key)
|
||||
if existing:
|
||||
logger.info(
|
||||
"Duplicate order detected (db) for %s key=%s existing=%s status=%s",
|
||||
ticker, idempotency_key[:16], existing["id"], existing["status"],
|
||||
)
|
||||
ORDERS_DUPLICATES_PREVENTED.labels(detected_via="db").inc()
|
||||
await audit_duplicate_prevented(
|
||||
pool, existing["id"], ticker, idempotency_key, detected_via="db",
|
||||
)
|
||||
# Warm Redis cache for future fast-path hits
|
||||
if rds is not None:
|
||||
await mark_idempotency_redis(rds, idempotency_key, existing["id"])
|
||||
return
|
||||
|
||||
# Risk evaluation
|
||||
risk_config = await load_risk_config(pool)
|
||||
risk_state = await load_account_risk_state(pool, adapter, account_uuid)
|
||||
proposed = build_proposed_order(job)
|
||||
evaluation = evaluate_order(proposed, risk_config, risk_state)
|
||||
|
||||
risk_eval_dict = {
|
||||
"evaluation_id": evaluation.evaluation_id,
|
||||
"eligible": evaluation.eligible,
|
||||
"allowed_mode": evaluation.allowed_mode.value,
|
||||
"rejection_reasons": evaluation.rejection_reasons,
|
||||
"checks": [c.model_dump(mode="json") for c in evaluation.checks],
|
||||
}
|
||||
|
||||
# Persist risk evaluation
|
||||
rec_id = job.get("recommendation_id")
|
||||
try:
|
||||
await pool.execute(
|
||||
_INSERT_RISK_EVALUATION,
|
||||
evaluation.evaluation_id,
|
||||
rec_id,
|
||||
evaluation.eligible,
|
||||
evaluation.allowed_mode.value,
|
||||
json.dumps(evaluation.rejection_reasons),
|
||||
json.dumps(risk_eval_dict["checks"]),
|
||||
evaluation.evaluated_at,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to persist risk evaluation: %s", e)
|
||||
|
||||
# Audit: risk evaluation result
|
||||
await audit_risk_evaluated(
|
||||
pool,
|
||||
evaluation_id=evaluation.evaluation_id,
|
||||
recommendation_id=rec_id,
|
||||
ticker=ticker,
|
||||
eligible=evaluation.eligible,
|
||||
allowed_mode=evaluation.allowed_mode.value,
|
||||
rejection_reasons=evaluation.rejection_reasons,
|
||||
check_count=len(evaluation.checks),
|
||||
)
|
||||
|
||||
if not evaluation.eligible:
|
||||
RISK_EVALUATIONS_TOTAL.labels(result="rejected").inc()
|
||||
for check in evaluation.checks:
|
||||
if check.result.value == "fail":
|
||||
RISK_CHECK_FAILURES.labels(check_name=check.check_name).inc()
|
||||
ORDERS_REJECTED.labels(reason_category="risk_engine").inc()
|
||||
logger.info(
|
||||
"Order rejected by risk engine for %s: %s",
|
||||
ticker, evaluation.rejection_reasons,
|
||||
)
|
||||
# Persist the rejected order for audit
|
||||
order_req = build_order_request(job)
|
||||
rejected_resp = OrderResponse(
|
||||
broker_order_id="",
|
||||
status=OrderStatus.REJECTED,
|
||||
ticker=ticker,
|
||||
side=OrderSide.SELL if job.get("side") == "sell" else OrderSide.BUY,
|
||||
quantity=float(job.get("quantity", 0)),
|
||||
error=f"Risk rejected: {'; '.join(evaluation.rejection_reasons)}",
|
||||
)
|
||||
await persist_order(
|
||||
pool, order_id, order_req, rejected_resp,
|
||||
account_uuid, risk_eval_dict, rec_id,
|
||||
)
|
||||
# Publish rejected order fact to analytical lake
|
||||
if minio_client is not None:
|
||||
try:
|
||||
publish_trade_order(
|
||||
minio_client, order_id, ticker,
|
||||
side=job.get("side", "buy"),
|
||||
order_type=job.get("order_type", "market"),
|
||||
quantity=float(job.get("quantity", 0)),
|
||||
limit_price=job.get("limit_price"),
|
||||
status="rejected",
|
||||
broker_account=account_uuid,
|
||||
submitted_at=datetime.now(timezone.utc),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish rejected order to lake: %s", e)
|
||||
# Audit: order rejected by risk engine
|
||||
await audit_order_rejected(
|
||||
pool, order_id, ticker,
|
||||
reason=f"Risk rejected: {'; '.join(evaluation.rejection_reasons)}",
|
||||
source="risk_engine",
|
||||
)
|
||||
# Mark idempotency even for rejected orders to prevent reprocessing
|
||||
if rds is not None:
|
||||
await mark_idempotency_redis(rds, idempotency_key, order_id)
|
||||
return
|
||||
|
||||
# --- Operator approval gate (Requirement 8.2) ---
|
||||
if requires_approval(risk_config, evaluation.allowed_mode):
|
||||
expiry = compute_expiry(risk_config)
|
||||
approval_req = ApprovalRequest(
|
||||
order_job=job,
|
||||
recommendation_id=rec_id,
|
||||
ticker=ticker,
|
||||
side=job.get("side", "buy"),
|
||||
quantity=float(job.get("quantity", 0)),
|
||||
estimated_value=float(job.get("estimated_value", 0)),
|
||||
risk_evaluation_id=evaluation.evaluation_id,
|
||||
expires_at=expiry,
|
||||
)
|
||||
try:
|
||||
await create_approval_request(pool, approval_req)
|
||||
logger.info(
|
||||
"Order for %s held for operator approval (id=%s, expires=%s)",
|
||||
ticker, approval_req.approval_id, expiry.isoformat(),
|
||||
)
|
||||
await audit_approval_requested(
|
||||
pool,
|
||||
approval_id=approval_req.approval_id,
|
||||
ticker=ticker,
|
||||
side=approval_req.side,
|
||||
quantity=approval_req.quantity,
|
||||
estimated_value=approval_req.estimated_value,
|
||||
recommendation_id=rec_id,
|
||||
expires_at=expiry.isoformat(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to create approval request for %s: %s", ticker, e)
|
||||
# Do NOT mark idempotency — the job will be re-submitted after approval
|
||||
return
|
||||
|
||||
# Submit to Alpaca
|
||||
order_req = build_order_request(job)
|
||||
RISK_EVALUATIONS_TOTAL.labels(result="passed").inc()
|
||||
|
||||
# Audit: order submitted to broker
|
||||
await audit_order_submitted(
|
||||
pool,
|
||||
order_id=order_id,
|
||||
ticker=ticker,
|
||||
side=order_req.side.value,
|
||||
quantity=order_req.quantity,
|
||||
order_type=order_req.order_type.value,
|
||||
idempotency_key=order_req.idempotency_key,
|
||||
recommendation_id=rec_id,
|
||||
evaluation_id=evaluation.evaluation_id,
|
||||
)
|
||||
|
||||
resp = await adapter.submit_order(order_req)
|
||||
|
||||
await persist_order(
|
||||
pool, order_id, order_req, resp,
|
||||
account_uuid, risk_eval_dict, rec_id,
|
||||
)
|
||||
|
||||
# Publish order fact to analytical lake
|
||||
if minio_client is not None:
|
||||
try:
|
||||
publish_trade_order(
|
||||
minio_client, order_id, ticker,
|
||||
side=order_req.side.value,
|
||||
order_type=order_req.order_type.value,
|
||||
quantity=order_req.quantity,
|
||||
limit_price=order_req.limit_price,
|
||||
status=resp.status.value,
|
||||
broker_account=account_uuid,
|
||||
submitted_at=resp.submitted_at or datetime.now(timezone.utc),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish order to lake: %s", e)
|
||||
|
||||
# Publish fill fact if the order was filled
|
||||
if resp.status == OrderStatus.FILLED and resp.filled_avg_price is not None:
|
||||
try:
|
||||
fill_id = str(uuid.uuid4())
|
||||
publish_trade_fill(
|
||||
minio_client, fill_id, order_id, ticker,
|
||||
side=order_req.side.value,
|
||||
fill_price=resp.filled_avg_price,
|
||||
fill_quantity=resp.filled_quantity,
|
||||
broker_account=account_uuid,
|
||||
filled_at=datetime.now(timezone.utc),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish fill to lake: %s", e)
|
||||
|
||||
# Mark idempotency after successful persistence
|
||||
if rds is not None:
|
||||
await mark_idempotency_redis(rds, idempotency_key, order_id)
|
||||
|
||||
if resp.ok:
|
||||
mode = "paper" if adapter.mode == TradingMode.PAPER else "live"
|
||||
ORDERS_SUBMITTED.labels(
|
||||
side=order_req.side.value,
|
||||
order_type=order_req.order_type.value,
|
||||
mode=mode,
|
||||
).inc()
|
||||
logger.info(
|
||||
"Order submitted to Alpaca: %s %s %.0f %s @ %s | broker_id=%s",
|
||||
resp.status.value, order_req.side.value, order_req.quantity,
|
||||
ticker, resp.filled_avg_price, resp.broker_order_id,
|
||||
)
|
||||
# Audit: order filled
|
||||
if resp.status == OrderStatus.FILLED:
|
||||
ORDERS_FILLED.labels(side=order_req.side.value).inc()
|
||||
await audit_order_filled(
|
||||
pool, order_id, ticker,
|
||||
side=order_req.side.value,
|
||||
fill_quantity=resp.filled_quantity,
|
||||
fill_price=resp.filled_avg_price,
|
||||
broker_order_id=resp.broker_order_id,
|
||||
)
|
||||
else:
|
||||
ORDERS_REJECTED.labels(reason_category="broker").inc()
|
||||
logger.warning(
|
||||
"Order failed for %s: %s (status=%s)",
|
||||
ticker, resp.error, resp.status.value,
|
||||
)
|
||||
# Audit: order rejected by broker
|
||||
await audit_order_rejected(
|
||||
pool, order_id, ticker,
|
||||
reason=resp.error or f"Broker status: {resp.status.value}",
|
||||
source="broker",
|
||||
)
|
||||
|
||||
|
||||
|
||||
async def position_sync_loop(
|
||||
adapter: AlpacaBrokerAdapter,
|
||||
pool: asyncpg.Pool,
|
||||
account_uuid: str,
|
||||
minio_client: Any | None = None,
|
||||
) -> None:
|
||||
"""Periodically sync positions from Alpaca to PostgreSQL and lake."""
|
||||
while True:
|
||||
await sync_positions(adapter, pool, account_uuid, minio_client)
|
||||
await asyncio.sleep(POSITION_SYNC_INTERVAL)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
config = load_config()
|
||||
setup_logging("broker_service", level=config.log_level, json_output=config.json_logs)
|
||||
|
||||
pool = await get_pg_pool(config)
|
||||
rds = get_redis(config)
|
||||
|
||||
# Initialize MinIO client for lake publishing
|
||||
from minio import Minio
|
||||
minio_client = Minio(
|
||||
config.minio.endpoint,
|
||||
access_key=config.minio.access_key,
|
||||
secret_key=config.minio.secret_key,
|
||||
secure=config.minio.secure,
|
||||
)
|
||||
# Ensure lakehouse bucket exists
|
||||
if not minio_client.bucket_exists(LAKEHOUSE_BUCKET):
|
||||
minio_client.make_bucket(LAKEHOUSE_BUCKET)
|
||||
|
||||
# Determine mode — default to paper for safety (Req 8.1)
|
||||
mode = TradingMode.LIVE if config.broker.mode == "live" else TradingMode.PAPER
|
||||
if mode == TradingMode.LIVE:
|
||||
logger.warning("LIVE trading mode enabled — orders will be submitted to real broker")
|
||||
|
||||
adapter = AlpacaBrokerAdapter(
|
||||
api_key=config.broker.api_key or "",
|
||||
api_secret=config.broker.api_secret or "",
|
||||
mode=mode,
|
||||
base_url=config.broker.base_url,
|
||||
)
|
||||
|
||||
# Generate a stable account UUID from the API key
|
||||
account_uuid = str(uuid.uuid5(uuid.NAMESPACE_DNS, f"alpaca-{config.broker.api_key or 'default'}"))
|
||||
|
||||
# Register broker account on startup
|
||||
await register_broker_account(pool, account_uuid, adapter)
|
||||
|
||||
# Start position sync in background
|
||||
sync_task = asyncio.create_task(
|
||||
position_sync_loop(adapter, pool, account_uuid, minio_client)
|
||||
)
|
||||
|
||||
queue = queue_key(QUEUE_BROKER)
|
||||
logger.info("Broker service started (mode=%s)", mode.value)
|
||||
|
||||
try:
|
||||
while True:
|
||||
result = await rds.lpop(queue)
|
||||
raw = str(result) if result else None
|
||||
if raw:
|
||||
try:
|
||||
job = json.loads(raw)
|
||||
await process_order_job(job, adapter, pool, account_uuid, rds, minio_client)
|
||||
except Exception:
|
||||
logger.exception("Error processing broker job")
|
||||
else:
|
||||
await asyncio.sleep(2)
|
||||
finally:
|
||||
sync_task.cancel()
|
||||
await pool.close()
|
||||
await rds.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,8 +1,17 @@
|
||||
"""Filings / Regulatory API adapter - fetches SEC-style submissions."""
|
||||
"""Filings / Regulatory API adapter interface and concrete SEC EDGAR provider.
|
||||
|
||||
The FilingsDataAdapter is the abstract interface for all filings data providers.
|
||||
SECEdgarAdapter is the first concrete implementation, targeting the SEC EDGAR
|
||||
full-text search system (EFTS) for company filings discovery.
|
||||
|
||||
Requirements: 2.3, 2.5, 3.1, 3.2, 3.3
|
||||
"""
|
||||
import hashlib
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
import time
|
||||
from abc import ABC
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -11,48 +20,182 @@ from .base import AdapterResult, BaseAdapter
|
||||
logger = logging.getLogger("filings_adapter")
|
||||
|
||||
|
||||
class FilingsAdapter(BaseAdapter):
|
||||
"""Concrete adapter for SEC EDGAR or similar filings API."""
|
||||
class FilingsDataAdapter(BaseAdapter, ABC):
|
||||
"""Abstract interface for filings / regulatory data providers.
|
||||
|
||||
def __init__(self, base_url: str = "https://efts.sec.gov", user_agent: str = "StonksOracle/1.0"):
|
||||
self.base_url = base_url
|
||||
self.user_agent = user_agent
|
||||
Subclasses implement fetch() for their specific filings API.
|
||||
source_type() is concrete here since all filings adapters share the same type.
|
||||
"""
|
||||
|
||||
def source_type(self) -> str:
|
||||
return "filings_api"
|
||||
|
||||
async def fetch(self, ticker: str, config: Dict[str, Any]) -> AdapterResult:
|
||||
_cik = config.get("cik", "")
|
||||
endpoint = config.get("endpoint", f"/LATEST/search-index?q=%22{ticker}%22&dateRange=custom&startdt=2026-01-01&forms=8-K,10-Q,10-K")
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
|
||||
headers = {"User-Agent": self.user_agent}
|
||||
class SECEdgarAdapter(FilingsDataAdapter):
|
||||
"""Concrete adapter for the SEC EDGAR full-text search system (EFTS).
|
||||
|
||||
Supports:
|
||||
- Full-text search (/LATEST/search-index) for 8-K, 10-Q, 10-K, and other forms
|
||||
- Filtering by date range, form type, and entity
|
||||
|
||||
The SEC EDGAR EFTS API is public and does not require an API key,
|
||||
but requires a descriptive User-Agent header per SEC fair-access policy.
|
||||
|
||||
Config options:
|
||||
cik: Company CIK number (optional, narrows search)
|
||||
forms: Comma-separated form types to search (default "8-K,10-Q,10-K")
|
||||
start_date: Only filings on or after this date, YYYY-MM-DD (optional)
|
||||
end_date: Only filings on or before this date, YYYY-MM-DD (optional)
|
||||
query: Custom search query override (optional, replaces ticker-based query)
|
||||
"""
|
||||
|
||||
SEARCH_ENDPOINT: str = "/LATEST/search-index"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str = "https://efts.sec.gov",
|
||||
user_agent: str = "StonksOracle/1.0 ([email])",
|
||||
) -> None:
|
||||
self.base_url: str = base_url.rstrip("/")
|
||||
self.user_agent: str = user_agent
|
||||
|
||||
async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult:
|
||||
"""Fetch filings from SEC EDGAR EFTS for a given ticker.
|
||||
|
||||
Args:
|
||||
ticker: The company ticker symbol.
|
||||
config: Source-specific configuration from the sources table.
|
||||
|
||||
Returns:
|
||||
AdapterResult with raw payload, parsed filing items, and metadata.
|
||||
"""
|
||||
url, params, headers = self._build_request(ticker, config)
|
||||
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
resp = await client.get(url, headers=headers)
|
||||
resp = await client.get(url, params=params, headers=headers)
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
resp.raise_for_status()
|
||||
|
||||
raw = resp.content
|
||||
data = resp.json()
|
||||
content_hash = hashlib.sha256(raw).hexdigest()
|
||||
items = self._extract_items(data)
|
||||
|
||||
hits = data.get("hits", {}).get("hits", [])
|
||||
return AdapterResult(
|
||||
source_type="filings_api",
|
||||
ticker=ticker,
|
||||
items=hits,
|
||||
items=items,
|
||||
raw_payload=raw,
|
||||
content_hash=content_hash,
|
||||
fetched_at=datetime.utcnow(),
|
||||
fetched_at=datetime.now(timezone.utc),
|
||||
http_status=resp.status_code,
|
||||
response_time_ms=round(elapsed_ms, 1),
|
||||
metadata={
|
||||
"provider": "sec_edgar",
|
||||
"results_count": len(items),
|
||||
"total_hits": self._total_hits(data),
|
||||
"query": params.get("q", ""),
|
||||
"forms": params.get("forms", ""),
|
||||
},
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
logger.error("SEC EDGAR HTTP error for %s: %s", ticker, e)
|
||||
return self._error_result(
|
||||
ticker, str(e), elapsed_ms,
|
||||
http_status=e.response.status_code if e.response else None,
|
||||
raw=e.response.content if e.response else b"",
|
||||
)
|
||||
except httpx.TimeoutException as e:
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
logger.error("SEC EDGAR timeout for %s: %s", ticker, e)
|
||||
return self._error_result(ticker, f"timeout: {e}", elapsed_ms)
|
||||
except Exception as e:
|
||||
logger.error(f"Filings fetch failed for {ticker}: {e}")
|
||||
return AdapterResult(
|
||||
source_type="filings_api",
|
||||
ticker=ticker,
|
||||
items=[],
|
||||
raw_payload=b"",
|
||||
content_hash="",
|
||||
fetched_at=datetime.utcnow(),
|
||||
error=str(e),
|
||||
)
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
logger.error("SEC EDGAR fetch failed for %s: %s", ticker, e)
|
||||
return self._error_result(ticker, str(e), elapsed_ms)
|
||||
|
||||
def _build_request(
|
||||
self, ticker: str, config: dict[str, Any]
|
||||
) -> tuple[str, dict[str, str], dict[str, str]]:
|
||||
"""Build the URL, query params, and headers for an EDGAR EFTS request."""
|
||||
params: dict[str, str] = {}
|
||||
headers: dict[str, str] = {"User-Agent": self.user_agent}
|
||||
|
||||
# Query: use custom override or default to ticker-based search
|
||||
query = config.get("query")
|
||||
if query:
|
||||
params["q"] = str(query)
|
||||
else:
|
||||
params["q"] = f'"{ticker}"'
|
||||
|
||||
# Form types filter
|
||||
forms = config.get("forms", "8-K,10-Q,10-K")
|
||||
params["forms"] = str(forms)
|
||||
|
||||
# Date range
|
||||
if config.get("start_date"):
|
||||
params["dateRange"] = "custom"
|
||||
params["startdt"] = str(config["start_date"])
|
||||
if config.get("end_date"):
|
||||
params["dateRange"] = "custom"
|
||||
params["enddt"] = str(config["end_date"])
|
||||
|
||||
# CIK filter (entity-level narrowing)
|
||||
cik = config.get("cik")
|
||||
if cik:
|
||||
params["q"] = f'{params["q"]} AND cik:{cik}'
|
||||
|
||||
url = f"{self.base_url}{self.SEARCH_ENDPOINT}"
|
||||
return url, params, headers
|
||||
|
||||
def _extract_items(self, data: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""Extract the filing hits from an EDGAR EFTS response.
|
||||
|
||||
EFTS returns results under hits.hits as a list of objects,
|
||||
each containing _source with fields like file_date, form_type,
|
||||
entity_name, file_num, and period_of_report.
|
||||
"""
|
||||
hits_wrapper = data.get("hits", {})
|
||||
if not isinstance(hits_wrapper, dict):
|
||||
return []
|
||||
hits = hits_wrapper.get("hits", [])
|
||||
if isinstance(hits, list):
|
||||
return hits
|
||||
return []
|
||||
|
||||
def _total_hits(self, data: dict[str, Any]) -> int:
|
||||
"""Extract total hit count from EFTS response."""
|
||||
hits_wrapper = data.get("hits", {})
|
||||
if not isinstance(hits_wrapper, dict):
|
||||
return 0
|
||||
total = hits_wrapper.get("total", {})
|
||||
if isinstance(total, dict):
|
||||
return int(total.get("value", 0))
|
||||
if isinstance(total, int):
|
||||
return total
|
||||
return 0
|
||||
|
||||
def _error_result(
|
||||
self,
|
||||
ticker: str,
|
||||
error: str,
|
||||
elapsed_ms: float,
|
||||
http_status: int | None = None,
|
||||
raw: bytes = b"",
|
||||
) -> AdapterResult:
|
||||
"""Build an error AdapterResult for filings fetches."""
|
||||
return AdapterResult(
|
||||
source_type="filings_api",
|
||||
ticker=ticker,
|
||||
items=[],
|
||||
raw_payload=raw,
|
||||
content_hash="",
|
||||
fetched_at=datetime.now(timezone.utc),
|
||||
error=error,
|
||||
http_status=http_status,
|
||||
response_time_ms=round(elapsed_ms, 1),
|
||||
metadata={"provider": "sec_edgar"},
|
||||
)
|
||||
|
||||
@@ -1,8 +1,16 @@
|
||||
"""Market data API adapter - fetches quotes, bars, and reference data."""
|
||||
"""Market data API adapter interface and concrete Polygon.io provider.
|
||||
|
||||
The MarketDataAdapter is the abstract interface for all market data providers.
|
||||
PolygonMarketAdapter is the first concrete implementation, targeting the
|
||||
Polygon.io REST API for previous-day bars, quotes, and ticker details.
|
||||
|
||||
Requirements: 2.1, 2.5, 3.1, 3.2, 3.3
|
||||
"""
|
||||
import hashlib
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -12,48 +20,158 @@ logger = logging.getLogger("market_adapter")
|
||||
|
||||
|
||||
class MarketDataAdapter(BaseAdapter):
|
||||
"""Concrete adapter for a market data provider (e.g., Alpha Vantage, Polygon, Yahoo)."""
|
||||
"""Abstract interface for market data providers.
|
||||
|
||||
def __init__(self, api_key: str = "", base_url: str = ""):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
Subclasses implement fetch() for their specific market data API.
|
||||
"""
|
||||
|
||||
def source_type(self) -> str:
|
||||
return "market_api"
|
||||
|
||||
async def fetch(self, ticker: str, config: Dict[str, Any]) -> AdapterResult:
|
||||
endpoint = config.get("endpoint", "/v2/aggs/ticker/{ticker}/prev")
|
||||
url = f"{self.base_url}{endpoint.format(ticker=ticker)}"
|
||||
params = config.get("params", {})
|
||||
if self.api_key:
|
||||
params["apiKey"] = self.api_key
|
||||
|
||||
class PolygonMarketAdapter(MarketDataAdapter):
|
||||
"""Concrete adapter for the Polygon.io REST API.
|
||||
|
||||
Supports:
|
||||
- Previous-day aggregate bars (/v2/aggs/ticker/{ticker}/prev)
|
||||
- Grouped daily bars (/v2/aggs/grouped/locale/us/market/stocks/{date})
|
||||
- Ticker details (/v3/reference/tickers/{ticker})
|
||||
|
||||
The endpoint is selected via the source config's "endpoint" field,
|
||||
defaulting to previous-day bars.
|
||||
"""
|
||||
|
||||
PREV_BARS = "/v2/aggs/ticker/{ticker}/prev"
|
||||
RANGE_BARS = "/v2/aggs/ticker/{ticker}/range/{multiplier}/{timespan}/{from_date}/{to_date}"
|
||||
TICKER_DETAILS = "/v3/reference/tickers/{ticker}"
|
||||
|
||||
def __init__(self, api_key: str, base_url: str = "https://api.polygon.io") -> None:
|
||||
self.api_key: str = api_key
|
||||
self.base_url: str = base_url.rstrip("/")
|
||||
|
||||
async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult:
|
||||
"""Fetch market data from Polygon.io for a given ticker.
|
||||
|
||||
Config options:
|
||||
endpoint: One of "prev_bars" (default), "range_bars", "ticker_details"
|
||||
multiplier: Bar multiplier for range queries (default 1)
|
||||
timespan: Bar timespan for range queries (default "day")
|
||||
from_date: Start date for range queries (YYYY-MM-DD)
|
||||
to_date: End date for range queries (YYYY-MM-DD)
|
||||
adjusted: Whether bars are adjusted for splits (default true)
|
||||
"""
|
||||
endpoint_key = config.get("endpoint", "prev_bars")
|
||||
url, params = self._build_request(ticker, endpoint_key, config)
|
||||
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
resp = await client.get(url, params=params)
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
resp.raise_for_status()
|
||||
|
||||
raw = resp.content
|
||||
data = resp.json()
|
||||
content_hash = hashlib.sha256(raw).hexdigest()
|
||||
|
||||
items = data.get("results", [data]) if isinstance(data, dict) else data
|
||||
items = self._extract_items(data, endpoint_key)
|
||||
|
||||
return AdapterResult(
|
||||
source_type="market_api",
|
||||
ticker=ticker,
|
||||
items=items if isinstance(items, list) else [items],
|
||||
items=items,
|
||||
raw_payload=raw,
|
||||
content_hash=content_hash,
|
||||
fetched_at=datetime.utcnow(),
|
||||
fetched_at=datetime.now(timezone.utc),
|
||||
http_status=resp.status_code,
|
||||
response_time_ms=round(elapsed_ms, 1),
|
||||
metadata={
|
||||
"provider": "polygon",
|
||||
"endpoint": endpoint_key,
|
||||
"results_count": data.get("resultsCount", len(items)),
|
||||
"request_id": data.get("request_id", ""),
|
||||
},
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
logger.error("Polygon HTTP error for %s: %s", ticker, e)
|
||||
return self._error_result(
|
||||
ticker, str(e), elapsed_ms,
|
||||
http_status=e.response.status_code if e.response else None,
|
||||
raw=e.response.content if e.response else b"",
|
||||
)
|
||||
except httpx.TimeoutException as e:
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
logger.error("Polygon timeout for %s: %s", ticker, e)
|
||||
return self._error_result(ticker, f"timeout: {e}", elapsed_ms)
|
||||
except Exception as e:
|
||||
logger.error(f"Market fetch failed for {ticker}: {e}")
|
||||
return AdapterResult(
|
||||
source_type="market_api",
|
||||
ticker=ticker,
|
||||
items=[],
|
||||
raw_payload=b"",
|
||||
content_hash="",
|
||||
fetched_at=datetime.utcnow(),
|
||||
error=str(e),
|
||||
)
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
logger.error("Polygon fetch failed for %s: %s", ticker, e)
|
||||
return self._error_result(ticker, str(e), elapsed_ms)
|
||||
|
||||
def _build_request(
|
||||
self, ticker: str, endpoint_key: str, config: dict[str, Any]
|
||||
) -> tuple[str, dict[str, str]]:
|
||||
"""Build the URL and query params for a Polygon request."""
|
||||
params: dict[str, str] = {"apiKey": self.api_key}
|
||||
|
||||
if endpoint_key == "range_bars":
|
||||
multiplier = str(config.get("multiplier", 1))
|
||||
timespan = config.get("timespan", "day")
|
||||
from_date = config.get("from_date", "")
|
||||
to_date = config.get("to_date", "")
|
||||
path = self.RANGE_BARS.format(
|
||||
ticker=ticker,
|
||||
multiplier=multiplier,
|
||||
timespan=timespan,
|
||||
from_date=from_date,
|
||||
to_date=to_date,
|
||||
)
|
||||
if config.get("adjusted") is not None:
|
||||
params["adjusted"] = str(config["adjusted"]).lower()
|
||||
if config.get("sort"):
|
||||
params["sort"] = config["sort"]
|
||||
if config.get("limit"):
|
||||
params["limit"] = str(config["limit"])
|
||||
elif endpoint_key == "ticker_details":
|
||||
path = self.TICKER_DETAILS.format(ticker=ticker)
|
||||
else:
|
||||
# Default: previous-day bars
|
||||
path = self.PREV_BARS.format(ticker=ticker)
|
||||
if config.get("adjusted") is not None:
|
||||
params["adjusted"] = str(config["adjusted"]).lower()
|
||||
|
||||
return f"{self.base_url}{path}", params
|
||||
|
||||
def _extract_items(self, data: dict[str, Any], endpoint_key: str) -> list[dict[str, Any]]:
|
||||
"""Extract the relevant items list from a Polygon response."""
|
||||
if endpoint_key == "ticker_details":
|
||||
results = data.get("results", {})
|
||||
return [results] if isinstance(results, dict) and results else []
|
||||
|
||||
# Aggregate endpoints return results as a list
|
||||
results = data.get("results", [])
|
||||
if isinstance(results, list):
|
||||
return results
|
||||
return [results] if results else []
|
||||
|
||||
def _error_result(
|
||||
self,
|
||||
ticker: str,
|
||||
error: str,
|
||||
elapsed_ms: float,
|
||||
http_status: int | None = None,
|
||||
raw: bytes = b"",
|
||||
) -> AdapterResult:
|
||||
"""Build an error AdapterResult."""
|
||||
return AdapterResult(
|
||||
source_type="market_api",
|
||||
ticker=ticker,
|
||||
items=[],
|
||||
raw_payload=raw,
|
||||
content_hash="",
|
||||
fetched_at=datetime.now(timezone.utc),
|
||||
error=error,
|
||||
http_status=http_status,
|
||||
response_time_ms=round(elapsed_ms, 1),
|
||||
metadata={"provider": "polygon"},
|
||||
)
|
||||
|
||||
@@ -1,8 +1,17 @@
|
||||
"""News API adapter - fetches company-linked headlines and article metadata."""
|
||||
"""News API adapter interface and concrete Polygon.io news provider.
|
||||
|
||||
The NewsDataAdapter is the abstract interface for all news data providers.
|
||||
PolygonNewsAdapter is the first concrete implementation, targeting the
|
||||
Polygon.io REST API for company-linked news articles and headlines.
|
||||
|
||||
Requirements: 2.2, 2.5, 3.1, 3.2, 3.3
|
||||
"""
|
||||
import hashlib
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
import time
|
||||
from abc import ABC
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -11,51 +20,147 @@ from .base import AdapterResult, BaseAdapter
|
||||
logger = logging.getLogger("news_adapter")
|
||||
|
||||
|
||||
class NewsApiAdapter(BaseAdapter):
|
||||
"""Concrete adapter for a news API provider."""
|
||||
class NewsDataAdapter(BaseAdapter, ABC):
|
||||
"""Abstract interface for news data providers.
|
||||
|
||||
def __init__(self, api_key: str = "", base_url: str = ""):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
Subclasses implement fetch() for their specific news API.
|
||||
source_type() is concrete here since all news adapters share the same type.
|
||||
"""
|
||||
|
||||
def source_type(self) -> str:
|
||||
return "news_api"
|
||||
|
||||
async def fetch(self, ticker: str, config: Dict[str, Any]) -> AdapterResult:
|
||||
endpoint = config.get("endpoint", "/v2/everything")
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
params = config.get("params", {})
|
||||
params.setdefault("q", ticker)
|
||||
params.setdefault("sortBy", "publishedAt")
|
||||
params.setdefault("pageSize", 20)
|
||||
if self.api_key:
|
||||
params["apiKey"] = self.api_key
|
||||
|
||||
class PolygonNewsAdapter(NewsDataAdapter):
|
||||
"""Concrete adapter for the Polygon.io ticker news endpoint.
|
||||
|
||||
Supports:
|
||||
- Ticker news (/v2/reference/news?ticker={ticker})
|
||||
|
||||
Config options:
|
||||
limit: Max articles to return per request (default 20, max 1000)
|
||||
published_utc_gte: Only articles published on or after this date (YYYY-MM-DD)
|
||||
published_utc_lte: Only articles published on or before this date (YYYY-MM-DD)
|
||||
order: Sort order for results, "asc" or "desc" (default "desc")
|
||||
"""
|
||||
|
||||
NEWS_ENDPOINT = "/v2/reference/news"
|
||||
|
||||
def __init__(self, api_key: str, base_url: str = "https://api.polygon.io") -> None:
|
||||
self.api_key: str = api_key
|
||||
self.base_url: str = base_url.rstrip("/")
|
||||
|
||||
async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult:
|
||||
"""Fetch news articles from Polygon.io for a given ticker.
|
||||
|
||||
Args:
|
||||
ticker: The company ticker symbol.
|
||||
config: Source-specific configuration from the sources table.
|
||||
|
||||
Returns:
|
||||
AdapterResult with raw payload, parsed article items, and metadata.
|
||||
"""
|
||||
url, params = self._build_request(ticker, config)
|
||||
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
resp = await client.get(url, params=params)
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
resp.raise_for_status()
|
||||
|
||||
raw = resp.content
|
||||
data = resp.json()
|
||||
content_hash = hashlib.sha256(raw).hexdigest()
|
||||
items = self._extract_items(data)
|
||||
|
||||
articles = data.get("articles", [])
|
||||
return AdapterResult(
|
||||
source_type="news_api",
|
||||
ticker=ticker,
|
||||
items=articles,
|
||||
items=items,
|
||||
raw_payload=raw,
|
||||
content_hash=content_hash,
|
||||
fetched_at=datetime.utcnow(),
|
||||
fetched_at=datetime.now(timezone.utc),
|
||||
http_status=resp.status_code,
|
||||
response_time_ms=round(elapsed_ms, 1),
|
||||
metadata={
|
||||
"provider": "polygon",
|
||||
"results_count": data.get("count", len(items)),
|
||||
"next_url": data.get("next_url", ""),
|
||||
"request_id": data.get("request_id", ""),
|
||||
},
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
logger.error("Polygon news HTTP error for %s: %s", ticker, e)
|
||||
return self._error_result(
|
||||
ticker, str(e), elapsed_ms,
|
||||
http_status=e.response.status_code if e.response else None,
|
||||
raw=e.response.content if e.response else b"",
|
||||
)
|
||||
except httpx.TimeoutException as e:
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
logger.error("Polygon news timeout for %s: %s", ticker, e)
|
||||
return self._error_result(ticker, f"timeout: {e}", elapsed_ms)
|
||||
except Exception as e:
|
||||
logger.error(f"News fetch failed for {ticker}: {e}")
|
||||
return AdapterResult(
|
||||
source_type="news_api",
|
||||
ticker=ticker,
|
||||
items=[],
|
||||
raw_payload=b"",
|
||||
content_hash="",
|
||||
fetched_at=datetime.utcnow(),
|
||||
error=str(e),
|
||||
)
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
logger.error("Polygon news fetch failed for %s: %s", ticker, e)
|
||||
return self._error_result(ticker, str(e), elapsed_ms)
|
||||
|
||||
def _build_request(
|
||||
self, ticker: str, config: dict[str, Any]
|
||||
) -> tuple[str, dict[str, str]]:
|
||||
"""Build the URL and query params for a Polygon news request."""
|
||||
params: dict[str, str] = {
|
||||
"apiKey": self.api_key,
|
||||
"ticker": ticker,
|
||||
}
|
||||
|
||||
limit = config.get("limit", 20)
|
||||
params["limit"] = str(min(int(limit), 1000))
|
||||
|
||||
if config.get("order"):
|
||||
params["order"] = config["order"]
|
||||
|
||||
if config.get("published_utc_gte"):
|
||||
params["published_utc.gte"] = config["published_utc_gte"]
|
||||
|
||||
if config.get("published_utc_lte"):
|
||||
params["published_utc.lte"] = config["published_utc_lte"]
|
||||
|
||||
url = f"{self.base_url}{self.NEWS_ENDPOINT}"
|
||||
return url, params
|
||||
|
||||
def _extract_items(self, data: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""Extract the article list from a Polygon news response.
|
||||
|
||||
Polygon returns articles under the "results" key as a list of objects,
|
||||
each containing fields like id, publisher, title, article_url, tickers,
|
||||
published_utc, description, and keywords.
|
||||
"""
|
||||
results = data.get("results", [])
|
||||
if isinstance(results, list):
|
||||
return results
|
||||
return []
|
||||
|
||||
def _error_result(
|
||||
self,
|
||||
ticker: str,
|
||||
error: str,
|
||||
elapsed_ms: float,
|
||||
http_status: int | None = None,
|
||||
raw: bytes = b"",
|
||||
) -> AdapterResult:
|
||||
"""Build an error AdapterResult for news fetches."""
|
||||
return AdapterResult(
|
||||
source_type="news_api",
|
||||
ticker=ticker,
|
||||
items=[],
|
||||
raw_payload=raw,
|
||||
content_hash="",
|
||||
fetched_at=datetime.now(timezone.utc),
|
||||
error=error,
|
||||
http_status=http_status,
|
||||
response_time_ms=round(elapsed_ms, 1),
|
||||
metadata={"provider": "polygon"},
|
||||
)
|
||||
|
||||
@@ -0,0 +1,603 @@
|
||||
"""Paper trading adapter - local order simulation and state sync.
|
||||
|
||||
Implements a fully local paper trading engine that simulates order
|
||||
execution without requiring a real broker API. Tracks positions,
|
||||
account balance, fills, and order events in-memory with PostgreSQL
|
||||
persistence for state sync and audit trail.
|
||||
|
||||
Requirements: 8.1, 8.3, 8.5, 2.4
|
||||
Design: Section 4.9 - Broker Adapter (paper mode)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import asyncpg
|
||||
|
||||
from services.adapters.broker_adapter import (
|
||||
AccountInfo,
|
||||
BrokerDataAdapter,
|
||||
OrderEventType,
|
||||
OrderRequest,
|
||||
OrderResponse,
|
||||
OrderSide,
|
||||
OrderStatus,
|
||||
OrderType,
|
||||
PositionInfo,
|
||||
TradingMode,
|
||||
)
|
||||
from services.adapters.base import AdapterResult
|
||||
|
||||
logger = logging.getLogger("paper_trading")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# In-memory paper trading state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class PaperPosition:
|
||||
"""Tracks a single paper position."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ticker: str,
|
||||
quantity: float = 0.0,
|
||||
avg_entry_price: float = 0.0,
|
||||
realized_pnl: float = 0.0,
|
||||
) -> None:
|
||||
self.ticker = ticker
|
||||
self.quantity = quantity
|
||||
self.avg_entry_price = avg_entry_price
|
||||
self.realized_pnl = realized_pnl
|
||||
|
||||
def apply_fill(self, side: OrderSide, fill_qty: float, fill_price: float) -> float:
|
||||
"""Apply a fill to this position. Returns realized PnL from the fill."""
|
||||
realized = 0.0
|
||||
|
||||
if side == OrderSide.BUY:
|
||||
# Buying: average up the entry price
|
||||
total_cost = self.avg_entry_price * self.quantity + fill_price * fill_qty
|
||||
self.quantity += fill_qty
|
||||
if self.quantity > 0:
|
||||
self.avg_entry_price = total_cost / self.quantity
|
||||
else:
|
||||
# Selling: realize PnL on the sold shares
|
||||
if self.quantity > 0:
|
||||
sell_qty = min(fill_qty, self.quantity)
|
||||
realized = sell_qty * (fill_price - self.avg_entry_price)
|
||||
self.quantity -= sell_qty
|
||||
self.realized_pnl += realized
|
||||
if self.quantity <= 0:
|
||||
self.quantity = 0.0
|
||||
self.avg_entry_price = 0.0
|
||||
|
||||
return realized
|
||||
|
||||
@property
|
||||
def is_open(self) -> bool:
|
||||
return self.quantity > 0
|
||||
|
||||
def to_position_info(self, current_price: float | None = None) -> PositionInfo:
|
||||
"""Convert to a PositionInfo for the broker interface."""
|
||||
price = current_price if current_price is not None else self.avg_entry_price
|
||||
unrealized = (price - self.avg_entry_price) * self.quantity if self.quantity > 0 else 0.0
|
||||
market_value = price * self.quantity
|
||||
return PositionInfo(
|
||||
ticker=self.ticker,
|
||||
quantity=self.quantity,
|
||||
avg_entry_price=self.avg_entry_price,
|
||||
current_price=price,
|
||||
unrealized_pnl=round(unrealized, 4),
|
||||
market_value=round(market_value, 4),
|
||||
side="long" if self.quantity > 0 else "flat",
|
||||
)
|
||||
|
||||
|
||||
class PaperAccount:
|
||||
"""In-memory paper trading account state."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
account_id: str = "paper-default",
|
||||
initial_cash: float = 100_000.0,
|
||||
) -> None:
|
||||
self.account_id = account_id
|
||||
self.initial_cash = initial_cash
|
||||
self.cash = initial_cash
|
||||
self.positions: dict[str, PaperPosition] = {}
|
||||
self.orders: dict[str, OrderResponse] = {}
|
||||
self.order_events: list[dict[str, Any]] = []
|
||||
self._seen_idempotency_keys: dict[str, str] = {} # key -> order_id
|
||||
|
||||
@property
|
||||
def portfolio_value(self) -> float:
|
||||
position_value = sum(
|
||||
p.quantity * p.avg_entry_price for p in self.positions.values() if p.is_open
|
||||
)
|
||||
return self.cash + position_value
|
||||
|
||||
@property
|
||||
def buying_power(self) -> float:
|
||||
return self.cash
|
||||
|
||||
def get_position(self, ticker: str) -> PaperPosition:
|
||||
if ticker not in self.positions:
|
||||
self.positions[ticker] = PaperPosition(ticker=ticker)
|
||||
return self.positions[ticker]
|
||||
|
||||
def to_account_info(self) -> AccountInfo:
|
||||
return AccountInfo(
|
||||
account_id=self.account_id,
|
||||
buying_power=round(self.buying_power, 2),
|
||||
cash=round(self.cash, 2),
|
||||
portfolio_value=round(self.portfolio_value, 2),
|
||||
currency="USD",
|
||||
mode=TradingMode.PAPER,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Paper trading adapter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class PaperTradingAdapter(BrokerDataAdapter):
|
||||
"""Local paper trading adapter that simulates order execution.
|
||||
|
||||
All orders are filled immediately at the estimated price (market orders)
|
||||
or at the limit/stop price when applicable. No real broker API is called.
|
||||
|
||||
Features:
|
||||
- Idempotent order submission via idempotency_key (Req 8.5)
|
||||
- Full order event trail for audit (Req 8.3)
|
||||
- Position tracking with average entry price
|
||||
- Cash balance management
|
||||
- State sync to/from PostgreSQL
|
||||
|
||||
The adapter operates in PAPER mode only and rejects any attempt
|
||||
to switch to LIVE mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
account_id: str = "paper-default",
|
||||
initial_cash: float = 100_000.0,
|
||||
simulated_slippage_pct: float = 0.001,
|
||||
) -> None:
|
||||
super().__init__(mode=TradingMode.PAPER)
|
||||
self.account = PaperAccount(account_id=account_id, initial_cash=initial_cash)
|
||||
self.slippage_pct = simulated_slippage_pct
|
||||
|
||||
def source_type(self) -> str:
|
||||
return "broker"
|
||||
|
||||
async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult:
|
||||
"""Fetch paper positions/account as a raw artifact snapshot."""
|
||||
endpoint = config.get("endpoint", "positions")
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
if endpoint == "account":
|
||||
data = self.account.to_account_info().to_dict()
|
||||
items = [data]
|
||||
elif endpoint == "orders":
|
||||
items = [
|
||||
resp.to_dict()
|
||||
for resp in self.account.orders.values()
|
||||
if resp.ticker == ticker or ticker == "*"
|
||||
]
|
||||
else:
|
||||
pos = self.account.get_position(ticker)
|
||||
data = pos.to_position_info().to_dict()
|
||||
items = [data] if pos.is_open else []
|
||||
|
||||
raw = json.dumps(items).encode()
|
||||
return AdapterResult(
|
||||
source_type="broker",
|
||||
ticker=ticker,
|
||||
items=items,
|
||||
raw_payload=raw,
|
||||
content_hash="",
|
||||
fetched_at=now,
|
||||
metadata={"provider": "paper", "mode": "paper", "endpoint": endpoint},
|
||||
)
|
||||
|
||||
async def submit_order(self, order: OrderRequest) -> OrderResponse:
|
||||
"""Simulate order submission and immediate fill.
|
||||
|
||||
Idempotency: if the same idempotency_key was already used,
|
||||
return the original response (Req 8.5).
|
||||
"""
|
||||
# Idempotency check
|
||||
existing_id = self.account._seen_idempotency_keys.get(order.idempotency_key)
|
||||
if existing_id and existing_id in self.account.orders:
|
||||
logger.info("Duplicate order key %s — returning cached response", order.idempotency_key)
|
||||
return self.account.orders[existing_id]
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
order_id = str(uuid.uuid4())
|
||||
|
||||
# Determine fill price based on order type
|
||||
fill_price = self._compute_fill_price(order)
|
||||
|
||||
# Check if we have enough cash for buys
|
||||
if order.side == OrderSide.BUY:
|
||||
required_cash = fill_price * order.quantity
|
||||
if required_cash > self.account.cash:
|
||||
resp = OrderResponse(
|
||||
broker_order_id=order_id,
|
||||
status=OrderStatus.REJECTED,
|
||||
ticker=order.ticker,
|
||||
side=order.side,
|
||||
quantity=order.quantity,
|
||||
submitted_at=now,
|
||||
error=f"Insufficient cash: need {required_cash:.2f}, have {self.account.cash:.2f}",
|
||||
)
|
||||
self._record_event(order_id, OrderEventType.REJECTED, resp.to_dict(), now)
|
||||
self.account.orders[order_id] = resp
|
||||
self.account._seen_idempotency_keys[order.idempotency_key] = order_id
|
||||
return resp
|
||||
|
||||
# Check if we have enough shares for sells
|
||||
if order.side == OrderSide.SELL:
|
||||
pos = self.account.get_position(order.ticker)
|
||||
if pos.quantity < order.quantity:
|
||||
resp = OrderResponse(
|
||||
broker_order_id=order_id,
|
||||
status=OrderStatus.REJECTED,
|
||||
ticker=order.ticker,
|
||||
side=order.side,
|
||||
quantity=order.quantity,
|
||||
submitted_at=now,
|
||||
error=f"Insufficient shares: need {order.quantity}, have {pos.quantity}",
|
||||
)
|
||||
self._record_event(order_id, OrderEventType.REJECTED, resp.to_dict(), now)
|
||||
self.account.orders[order_id] = resp
|
||||
self.account._seen_idempotency_keys[order.idempotency_key] = order_id
|
||||
return resp
|
||||
|
||||
# Simulate immediate fill
|
||||
position = self.account.get_position(order.ticker)
|
||||
realized_pnl = position.apply_fill(order.side, order.quantity, fill_price)
|
||||
|
||||
# Update cash
|
||||
if order.side == OrderSide.BUY:
|
||||
self.account.cash -= fill_price * order.quantity
|
||||
else:
|
||||
self.account.cash += fill_price * order.quantity
|
||||
|
||||
resp = OrderResponse(
|
||||
broker_order_id=order_id,
|
||||
status=OrderStatus.FILLED,
|
||||
ticker=order.ticker,
|
||||
side=order.side,
|
||||
quantity=order.quantity,
|
||||
filled_quantity=order.quantity,
|
||||
filled_avg_price=fill_price,
|
||||
submitted_at=now,
|
||||
raw_response={
|
||||
"realized_pnl": round(realized_pnl, 4),
|
||||
"cash_after": round(self.account.cash, 2),
|
||||
"position_qty_after": position.quantity,
|
||||
"simulated": True,
|
||||
},
|
||||
)
|
||||
|
||||
# Record events
|
||||
self._record_event(order_id, OrderEventType.SUBMITTED, {"ticker": order.ticker}, now)
|
||||
self._record_event(order_id, OrderEventType.ACCEPTED, {"ticker": order.ticker}, now)
|
||||
self._record_event(order_id, OrderEventType.FILL, {
|
||||
"fill_price": fill_price,
|
||||
"fill_qty": order.quantity,
|
||||
"realized_pnl": round(realized_pnl, 4),
|
||||
}, now)
|
||||
|
||||
self.account.orders[order_id] = resp
|
||||
self.account._seen_idempotency_keys[order.idempotency_key] = order_id
|
||||
|
||||
logger.info(
|
||||
"Paper fill: %s %s %.0f %s @ %.2f | cash=%.2f pnl=%.4f",
|
||||
order_id[:8], order.side.value, order.quantity,
|
||||
order.ticker, fill_price, self.account.cash, realized_pnl,
|
||||
)
|
||||
|
||||
return resp
|
||||
|
||||
async def cancel_order(self, broker_order_id: str) -> OrderResponse:
|
||||
"""Cancel a paper order. Only pending orders can be cancelled."""
|
||||
existing = self.account.orders.get(broker_order_id)
|
||||
if existing is None:
|
||||
return OrderResponse(
|
||||
broker_order_id=broker_order_id,
|
||||
status=OrderStatus.REJECTED,
|
||||
ticker="",
|
||||
side=OrderSide.BUY,
|
||||
quantity=0,
|
||||
error=f"Order {broker_order_id} not found",
|
||||
)
|
||||
|
||||
# Paper orders fill immediately, so they can't be cancelled
|
||||
if existing.status == OrderStatus.FILLED:
|
||||
return OrderResponse(
|
||||
broker_order_id=broker_order_id,
|
||||
status=OrderStatus.REJECTED,
|
||||
ticker=existing.ticker,
|
||||
side=existing.side,
|
||||
quantity=existing.quantity,
|
||||
error="Cannot cancel a filled order",
|
||||
)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
cancelled = OrderResponse(
|
||||
broker_order_id=broker_order_id,
|
||||
status=OrderStatus.CANCELLED,
|
||||
ticker=existing.ticker,
|
||||
side=existing.side,
|
||||
quantity=existing.quantity,
|
||||
submitted_at=existing.submitted_at,
|
||||
)
|
||||
self.account.orders[broker_order_id] = cancelled
|
||||
self._record_event(broker_order_id, OrderEventType.CANCELLED, {}, now)
|
||||
return cancelled
|
||||
|
||||
async def get_order_status(self, broker_order_id: str) -> OrderResponse:
|
||||
"""Get the status of a paper order."""
|
||||
existing = self.account.orders.get(broker_order_id)
|
||||
if existing is None:
|
||||
return OrderResponse(
|
||||
broker_order_id=broker_order_id,
|
||||
status=OrderStatus.REJECTED,
|
||||
ticker="",
|
||||
side=OrderSide.BUY,
|
||||
quantity=0,
|
||||
error=f"Order {broker_order_id} not found",
|
||||
)
|
||||
return existing
|
||||
|
||||
async def get_positions(self) -> list[PositionInfo]:
|
||||
"""Get all open paper positions."""
|
||||
return [
|
||||
p.to_position_info()
|
||||
for p in self.account.positions.values()
|
||||
if p.is_open
|
||||
]
|
||||
|
||||
async def get_account(self) -> AccountInfo:
|
||||
"""Get paper account summary."""
|
||||
return self.account.to_account_info()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def _compute_fill_price(self, order: OrderRequest) -> float:
|
||||
"""Determine the simulated fill price for an order.
|
||||
|
||||
Market orders use the limit_price as a proxy (or 0 if not set).
|
||||
Limit orders fill at the limit price.
|
||||
Stop orders fill at the stop price.
|
||||
A small slippage is applied to market orders.
|
||||
"""
|
||||
if order.order_type == OrderType.LIMIT and order.limit_price is not None:
|
||||
return order.limit_price
|
||||
if order.order_type == OrderType.STOP and order.stop_price is not None:
|
||||
return order.stop_price
|
||||
if order.order_type == OrderType.STOP_LIMIT and order.limit_price is not None:
|
||||
return order.limit_price
|
||||
|
||||
# Market order: use limit_price as estimate, or a default
|
||||
base_price = order.limit_price if order.limit_price is not None else 100.0
|
||||
if order.side == OrderSide.BUY:
|
||||
return round(base_price * (1 + self.slippage_pct), 4)
|
||||
return round(base_price * (1 - self.slippage_pct), 4)
|
||||
|
||||
def _record_event(
|
||||
self,
|
||||
order_id: str,
|
||||
event_type: OrderEventType,
|
||||
data: dict[str, Any],
|
||||
timestamp: datetime,
|
||||
) -> None:
|
||||
"""Record an order event for audit trail."""
|
||||
self.account.order_events.append({
|
||||
"order_id": order_id,
|
||||
"event_type": event_type.value,
|
||||
"data": data,
|
||||
"timestamp": timestamp.isoformat(),
|
||||
})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# State sync: persist and restore paper trading state to/from PostgreSQL
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# SQL for persisting paper orders to the orders table
|
||||
_INSERT_PAPER_ORDER = """
|
||||
INSERT INTO orders (
|
||||
id, recommendation_id, broker_account_id, ticker, side, order_type,
|
||||
quantity, limit_price, stop_price, status, idempotency_key,
|
||||
broker_order_id, decision_trace, submitted_at, filled_at,
|
||||
fill_price, fill_quantity
|
||||
) VALUES (
|
||||
$1::uuid, $2, $3, $4, $5, $6,
|
||||
$7, $8, $9, $10, $11,
|
||||
$12, $13::jsonb, $14, $15,
|
||||
$16, $17
|
||||
)
|
||||
ON CONFLICT (idempotency_key) DO NOTHING
|
||||
"""
|
||||
|
||||
_INSERT_PAPER_ORDER_EVENT = """
|
||||
INSERT INTO order_events (order_id, event_type, data, broker_timestamp)
|
||||
VALUES ($1::uuid, $2, $3::jsonb, $4)
|
||||
"""
|
||||
|
||||
_UPSERT_PAPER_POSITION = """
|
||||
INSERT INTO positions (broker_account_id, ticker, quantity, avg_entry_price, realized_pnl, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (broker_account_id, ticker)
|
||||
DO UPDATE SET
|
||||
quantity = EXCLUDED.quantity,
|
||||
avg_entry_price = EXCLUDED.avg_entry_price,
|
||||
realized_pnl = EXCLUDED.realized_pnl,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
"""
|
||||
|
||||
_UPSERT_PAPER_ACCOUNT = """
|
||||
INSERT INTO broker_accounts (id, provider, account_id, mode, config, active)
|
||||
VALUES ($1::uuid, 'paper', $2, 'paper', $3::jsonb, TRUE)
|
||||
ON CONFLICT (id) DO UPDATE SET
|
||||
config = EXCLUDED.config,
|
||||
active = TRUE
|
||||
"""
|
||||
|
||||
_LOAD_PAPER_POSITIONS = """
|
||||
SELECT ticker, quantity, avg_entry_price, COALESCE(realized_pnl, 0) AS realized_pnl
|
||||
FROM positions
|
||||
WHERE broker_account_id = $1 AND quantity > 0
|
||||
"""
|
||||
|
||||
_LOAD_PAPER_ACCOUNT_CONFIG = """
|
||||
SELECT config FROM broker_accounts
|
||||
WHERE account_id = $1 AND mode = 'paper' AND active = TRUE
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
_LOAD_PAPER_ORDERS = """
|
||||
SELECT
|
||||
id, ticker, side, order_type, quantity, status,
|
||||
idempotency_key, broker_order_id, fill_price, fill_quantity,
|
||||
submitted_at
|
||||
FROM orders
|
||||
WHERE broker_account_id = (
|
||||
SELECT id FROM broker_accounts WHERE account_id = $1 AND mode = 'paper' LIMIT 1
|
||||
)
|
||||
ORDER BY submitted_at DESC
|
||||
LIMIT 500
|
||||
"""
|
||||
|
||||
|
||||
async def sync_state_to_db(
|
||||
adapter: PaperTradingAdapter,
|
||||
pool: asyncpg.Pool,
|
||||
broker_account_uuid: str | None = None,
|
||||
) -> None:
|
||||
"""Persist the current paper trading state to PostgreSQL.
|
||||
|
||||
Writes:
|
||||
- broker_accounts row for the paper account
|
||||
- positions rows for all open positions
|
||||
- orders rows for all orders (idempotent via ON CONFLICT)
|
||||
- order_events for audit trail
|
||||
|
||||
This enables state recovery after restarts and provides the
|
||||
full execution audit trail (Requirement 8.3).
|
||||
"""
|
||||
acct = adapter.account
|
||||
now = datetime.now(timezone.utc)
|
||||
acct_uuid = broker_account_uuid or str(uuid.uuid5(uuid.NAMESPACE_DNS, acct.account_id))
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
async with conn.transaction():
|
||||
# 1. Upsert broker account
|
||||
config_json = json.dumps({
|
||||
"initial_cash": acct.initial_cash,
|
||||
"current_cash": round(acct.cash, 2),
|
||||
"portfolio_value": round(acct.portfolio_value, 2),
|
||||
"slippage_pct": adapter.slippage_pct,
|
||||
})
|
||||
await conn.execute(_UPSERT_PAPER_ACCOUNT, acct_uuid, acct.account_id, config_json)
|
||||
|
||||
# 2. Upsert positions
|
||||
for ticker, pos in acct.positions.items():
|
||||
await conn.execute(
|
||||
_UPSERT_PAPER_POSITION,
|
||||
acct_uuid, ticker,
|
||||
pos.quantity, pos.avg_entry_price, pos.realized_pnl,
|
||||
now,
|
||||
)
|
||||
|
||||
# 3. Insert orders (idempotent)
|
||||
for order_id, resp in acct.orders.items():
|
||||
filled_at = now if resp.status == OrderStatus.FILLED else None
|
||||
await conn.execute(
|
||||
_INSERT_PAPER_ORDER,
|
||||
order_id,
|
||||
None, # recommendation_id
|
||||
acct_uuid,
|
||||
resp.ticker,
|
||||
resp.side.value,
|
||||
"market", # paper orders are always market-simulated
|
||||
resp.quantity,
|
||||
resp.filled_avg_price, # limit_price
|
||||
None, # stop_price
|
||||
resp.status.value,
|
||||
order_id, # use order_id as idempotency_key fallback
|
||||
order_id,
|
||||
json.dumps(resp.raw_response),
|
||||
resp.submitted_at,
|
||||
filled_at,
|
||||
resp.filled_avg_price,
|
||||
resp.filled_quantity,
|
||||
)
|
||||
|
||||
# 4. Insert order events
|
||||
for event in acct.order_events:
|
||||
await conn.execute(
|
||||
_INSERT_PAPER_ORDER_EVENT,
|
||||
event["order_id"],
|
||||
event["event_type"],
|
||||
json.dumps(event["data"]),
|
||||
datetime.fromisoformat(event["timestamp"]),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Synced paper state to DB: account=%s positions=%d orders=%d events=%d",
|
||||
acct.account_id, len(acct.positions), len(acct.orders), len(acct.order_events),
|
||||
)
|
||||
|
||||
# Clear events after sync to avoid re-inserting
|
||||
acct.order_events.clear()
|
||||
|
||||
|
||||
async def load_state_from_db(
|
||||
adapter: PaperTradingAdapter,
|
||||
pool: asyncpg.Pool,
|
||||
) -> bool:
|
||||
"""Restore paper trading state from PostgreSQL.
|
||||
|
||||
Loads positions and account config from the DB so the adapter
|
||||
can resume after a restart. Returns True if state was found.
|
||||
"""
|
||||
acct = adapter.account
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
# Load account config
|
||||
row = await conn.fetchrow(_LOAD_PAPER_ACCOUNT_CONFIG, acct.account_id)
|
||||
if row is None:
|
||||
logger.info("No saved paper account state for %s", acct.account_id)
|
||||
return False
|
||||
|
||||
config = json.loads(row["config"]) if isinstance(row["config"], str) else row["config"]
|
||||
acct.cash = float(config.get("current_cash", acct.initial_cash))
|
||||
|
||||
# Load positions
|
||||
pos_rows = await conn.fetch(_LOAD_PAPER_POSITIONS, acct.account_id)
|
||||
for pr in pos_rows:
|
||||
ticker = pr["ticker"]
|
||||
acct.positions[ticker] = PaperPosition(
|
||||
ticker=ticker,
|
||||
quantity=float(pr["quantity"]),
|
||||
avg_entry_price=float(pr["avg_entry_price"] or 0),
|
||||
realized_pnl=float(pr["realized_pnl"]),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Loaded paper state from DB: account=%s cash=%.2f positions=%d",
|
||||
acct.account_id, acct.cash, len(acct.positions),
|
||||
)
|
||||
return True
|
||||
@@ -0,0 +1,241 @@
|
||||
"""Resilient adapter wrapper with rate-limit coordination, retries, and backoff.
|
||||
|
||||
Wraps any BaseAdapter with:
|
||||
- Per-source-type rate limiting via Redis (distributed across workers)
|
||||
- Exponential backoff with jitter on retryable failures
|
||||
- Configurable retry counts and retryable HTTP status codes
|
||||
- Graceful degradation when Redis is unavailable
|
||||
|
||||
Requirements: 2.5, 3.4
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from services.shared.redis_keys import rate_limit_key
|
||||
|
||||
from .base import AdapterResult, BaseAdapter
|
||||
|
||||
logger = logging.getLogger("resilient_adapter")
|
||||
|
||||
# HTTP status codes that are safe to retry
|
||||
RETRYABLE_STATUS_CODES: frozenset[int] = frozenset({429, 500, 502, 503, 504})
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetryConfig:
|
||||
"""Configuration for retry and rate-limit behavior."""
|
||||
|
||||
max_retries: int = 3
|
||||
base_delay: float = 1.0
|
||||
max_delay: float = 60.0
|
||||
jitter_factor: float = 0.5
|
||||
retryable_status_codes: frozenset[int] = RETRYABLE_STATUS_CODES
|
||||
# Rate limit: max requests per window per source type
|
||||
rate_limit_max: int = 30
|
||||
rate_limit_window_seconds: int = 60
|
||||
|
||||
|
||||
# Sensible defaults per source type
|
||||
DEFAULT_RETRY_CONFIGS: dict[str, RetryConfig] = {
|
||||
"market_api": RetryConfig(max_retries=3, rate_limit_max=30),
|
||||
"news_api": RetryConfig(max_retries=3, rate_limit_max=20),
|
||||
"filings_api": RetryConfig(max_retries=2, rate_limit_max=10, base_delay=2.0),
|
||||
"web_scrape": RetryConfig(max_retries=2, rate_limit_max=10, base_delay=2.0),
|
||||
"broker": RetryConfig(max_retries=2, rate_limit_max=60, base_delay=0.5),
|
||||
}
|
||||
|
||||
|
||||
def compute_delay(attempt: int, config: RetryConfig) -> float:
|
||||
"""Compute backoff delay with jitter for a given attempt number."""
|
||||
exp_delay = config.base_delay * (2 ** attempt)
|
||||
capped = min(exp_delay, config.max_delay)
|
||||
jitter = capped * config.jitter_factor * random.random()
|
||||
return capped + jitter
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetryStats:
|
||||
"""Tracks retry statistics for observability."""
|
||||
|
||||
attempts: int = 0
|
||||
total_delay: float = 0.0
|
||||
rate_limited_waits: int = 0
|
||||
last_error: str | None = None
|
||||
retryable: bool = False
|
||||
|
||||
|
||||
class ResilientAdapter:
|
||||
"""Wraps a BaseAdapter with rate-limit coordination, retries, and backoff.
|
||||
|
||||
Usage:
|
||||
adapter = PolygonMarketAdapter(api_key="...")
|
||||
resilient = ResilientAdapter(adapter, redis=rds)
|
||||
result = await resilient.fetch(ticker, config)
|
||||
|
||||
If redis is None, rate limiting is skipped (local dev / testing).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
adapter: BaseAdapter,
|
||||
redis: aioredis.Redis | None = None,
|
||||
retry_config: RetryConfig | None = None,
|
||||
) -> None:
|
||||
self._adapter = adapter
|
||||
self._redis = redis
|
||||
source_type = adapter.source_type()
|
||||
self._config = retry_config or DEFAULT_RETRY_CONFIGS.get(
|
||||
source_type, RetryConfig()
|
||||
)
|
||||
|
||||
@property
|
||||
def adapter(self) -> BaseAdapter:
|
||||
"""Access the underlying adapter."""
|
||||
return self._adapter
|
||||
|
||||
@property
|
||||
def config(self) -> RetryConfig:
|
||||
return self._config
|
||||
|
||||
def source_type(self) -> str:
|
||||
return self._adapter.source_type()
|
||||
|
||||
async def _check_rate_limit(self) -> float:
|
||||
"""Check distributed rate limit via Redis.
|
||||
|
||||
Returns 0.0 if allowed, or the number of seconds to wait.
|
||||
"""
|
||||
if self._redis is None:
|
||||
return 0.0
|
||||
|
||||
source_type = self._adapter.source_type()
|
||||
window_sec = self._config.rate_limit_window_seconds
|
||||
# Use a time-bucketed key so counters auto-expire
|
||||
bucket = int(time.time()) // window_sec
|
||||
key = rate_limit_key(source_type, str(bucket))
|
||||
|
||||
try:
|
||||
count = await self._redis.incr(key)
|
||||
if count == 1:
|
||||
await self._redis.expire(key, window_sec * 2)
|
||||
if count > self._config.rate_limit_max:
|
||||
# Over limit — compute how long until the window rolls over
|
||||
elapsed_in_window = time.time() % window_sec
|
||||
wait = window_sec - elapsed_in_window
|
||||
return max(wait, 0.5)
|
||||
except Exception:
|
||||
# Redis unavailable — degrade gracefully, allow the request
|
||||
logger.warning("Redis rate-limit check failed, allowing request")
|
||||
return 0.0
|
||||
|
||||
def _is_retryable(self, result: AdapterResult) -> bool:
|
||||
"""Determine if a failed result is worth retrying."""
|
||||
if result.ok:
|
||||
return False
|
||||
# Retry on known retryable HTTP status codes
|
||||
if result.http_status and result.http_status in self._config.retryable_status_codes:
|
||||
return True
|
||||
# Retry on timeouts
|
||||
if result.error and "timeout" in result.error.lower():
|
||||
return True
|
||||
# Retry on connection errors
|
||||
if result.error and any(
|
||||
kw in result.error.lower()
|
||||
for kw in ("connection", "connect", "reset", "refused")
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _extract_retry_after(self, result: AdapterResult) -> float | None:
|
||||
"""Extract Retry-After hint from result metadata if present."""
|
||||
retry_after = result.metadata.get("retry_after")
|
||||
if retry_after is not None:
|
||||
try:
|
||||
return float(retry_after)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
return None
|
||||
|
||||
async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult:
|
||||
"""Fetch with rate-limit coordination, retries, and exponential backoff.
|
||||
|
||||
Returns the AdapterResult from the underlying adapter. On retryable
|
||||
failures, retries up to max_retries times with exponential backoff
|
||||
and jitter. Rate-limit waits are applied before each attempt.
|
||||
|
||||
The returned result's metadata includes retry stats under the
|
||||
"retry_stats" key.
|
||||
"""
|
||||
stats = RetryStats()
|
||||
last_result: AdapterResult | None = None
|
||||
|
||||
for attempt in range(self._config.max_retries + 1):
|
||||
stats.attempts = attempt + 1
|
||||
|
||||
# Rate limit check
|
||||
wait = await self._check_rate_limit()
|
||||
if wait > 0:
|
||||
stats.rate_limited_waits += 1
|
||||
logger.info(
|
||||
"Rate limited for %s/%s, waiting %.1fs",
|
||||
self.source_type(), ticker, wait,
|
||||
)
|
||||
stats.total_delay += wait
|
||||
await asyncio.sleep(wait)
|
||||
|
||||
# Execute the fetch
|
||||
result = await self._adapter.fetch(ticker, config)
|
||||
last_result = result
|
||||
|
||||
# Success — attach stats and return
|
||||
if result.ok:
|
||||
result.metadata["retry_stats"] = {
|
||||
"attempts": stats.attempts,
|
||||
"total_delay": round(stats.total_delay, 2),
|
||||
"rate_limited_waits": stats.rate_limited_waits,
|
||||
}
|
||||
return result
|
||||
|
||||
# Check if retryable
|
||||
if not self._is_retryable(result):
|
||||
stats.last_error = result.error
|
||||
stats.retryable = False
|
||||
break
|
||||
|
||||
stats.retryable = True
|
||||
stats.last_error = result.error
|
||||
|
||||
# Don't sleep after the last attempt
|
||||
if attempt < self._config.max_retries:
|
||||
# Respect Retry-After header for 429s
|
||||
retry_after = self._extract_retry_after(result)
|
||||
if result.http_status == 429 and retry_after is not None:
|
||||
delay = min(retry_after, self._config.max_delay)
|
||||
else:
|
||||
delay = compute_delay(attempt, self._config)
|
||||
|
||||
logger.info(
|
||||
"Retrying %s/%s (attempt %d/%d) after %.1fs: %s",
|
||||
self.source_type(), ticker, attempt + 1,
|
||||
self._config.max_retries + 1, delay, result.error,
|
||||
)
|
||||
stats.total_delay += delay
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# All retries exhausted — return last result with stats
|
||||
assert last_result is not None
|
||||
last_result.metadata["retry_stats"] = {
|
||||
"attempts": stats.attempts,
|
||||
"total_delay": round(stats.total_delay, 2),
|
||||
"rate_limited_waits": stats.rate_limited_waits,
|
||||
"exhausted": True,
|
||||
"last_error": stats.last_error,
|
||||
}
|
||||
return last_result
|
||||
@@ -0,0 +1,321 @@
|
||||
"""Web scrape adapter for curated URLs and article pages.
|
||||
|
||||
Fetches full article HTML from curated URLs (investor relations pages,
|
||||
press releases, earnings transcripts, etc.) using BeautifulSoup + requests
|
||||
with retry adapters, content hashing, boilerplate awareness, and quality scoring.
|
||||
|
||||
Inspired by Noctipede crawler patterns: BeautifulSoup + requests with retry
|
||||
adapters, content hashing, boilerplate stripping, quality scoring.
|
||||
|
||||
Requirements: 1.2, 2.5, 3.1, 3.2, 3.3, 3.4
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from urllib.parse import urlparse
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from services.shared.content import content_hash, normalize_url
|
||||
|
||||
from .base import AdapterResult, BaseAdapter
|
||||
|
||||
logger = logging.getLogger("web_scrape_adapter")
|
||||
|
||||
# Default request settings
|
||||
DEFAULT_TIMEOUT = 30
|
||||
DEFAULT_USER_AGENT = "StonksOracle/1.0 (+https://stonks-oracle.celestium.life)"
|
||||
MAX_CONTENT_LENGTH = 10 * 1024 * 1024 # 10MB cap
|
||||
|
||||
|
||||
def extract_metadata_from_html(html: str, url: str) -> dict[str, str | None]:
|
||||
"""Extract title, author, publisher, published date, and links from HTML."""
|
||||
soup = BeautifulSoup(html, "html.parser")
|
||||
meta: dict[str, str | None] = {}
|
||||
|
||||
# Title: prefer og:title, then <title>
|
||||
og_title = soup.find("meta", property="og:title")
|
||||
if og_title and og_title.get("content"):
|
||||
content = og_title["content"]
|
||||
meta["title"] = content.strip() if isinstance(content, str) else ""
|
||||
elif soup.title and soup.title.string:
|
||||
meta["title"] = soup.title.string.strip()
|
||||
else:
|
||||
meta["title"] = ""
|
||||
|
||||
# Author
|
||||
author_tag = soup.find("meta", attrs={"name": "author"})
|
||||
if author_tag and author_tag.get("content"):
|
||||
content = author_tag["content"]
|
||||
meta["author"] = content.strip() if isinstance(content, str) else ""
|
||||
else:
|
||||
meta["author"] = ""
|
||||
|
||||
# Publisher: og:site_name
|
||||
site_name = soup.find("meta", property="og:site_name")
|
||||
if site_name and site_name.get("content"):
|
||||
content = site_name["content"]
|
||||
meta["publisher"] = content.strip() if isinstance(content, str) else ""
|
||||
else:
|
||||
meta["publisher"] = urlparse(url).hostname or ""
|
||||
|
||||
# Published date: article:published_time or datePublished
|
||||
pub_time = soup.find("meta", property="article:published_time")
|
||||
if pub_time and pub_time.get("content"):
|
||||
content = pub_time["content"]
|
||||
meta["published_at"] = content.strip() if isinstance(content, str) else None
|
||||
else:
|
||||
# Try JSON-LD datePublished
|
||||
for script in soup.find_all("script", type="application/ld+json"):
|
||||
if script.string and "datePublished" in script.string:
|
||||
try:
|
||||
ld = json.loads(script.string)
|
||||
if isinstance(ld, dict) and "datePublished" in ld:
|
||||
meta["published_at"] = str(ld["datePublished"])
|
||||
break
|
||||
if isinstance(ld, list):
|
||||
for item in ld:
|
||||
if isinstance(item, dict) and "datePublished" in item:
|
||||
meta["published_at"] = str(item["datePublished"])
|
||||
break
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
if "published_at" not in meta:
|
||||
meta["published_at"] = None
|
||||
|
||||
# Canonical URL
|
||||
canonical = soup.find("link", rel="canonical")
|
||||
if canonical and canonical.get("href"):
|
||||
href = canonical["href"]
|
||||
meta["canonical_url"] = str(href) if href else normalize_url(url)
|
||||
else:
|
||||
og_url = soup.find("meta", property="og:url")
|
||||
if og_url and og_url.get("content"):
|
||||
content = og_url["content"]
|
||||
meta["canonical_url"] = str(content) if content else normalize_url(url)
|
||||
else:
|
||||
meta["canonical_url"] = normalize_url(url)
|
||||
|
||||
# Language
|
||||
html_tag = soup.find("html")
|
||||
if html_tag and html_tag.get("lang"):
|
||||
lang = html_tag["lang"]
|
||||
meta["language"] = str(lang)[:5] if lang else "en"
|
||||
else:
|
||||
meta["language"] = "en"
|
||||
|
||||
# Description for summary
|
||||
desc = soup.find("meta", property="og:description") or soup.find(
|
||||
"meta", attrs={"name": "description"}
|
||||
)
|
||||
if desc and desc.get("content"):
|
||||
content = desc["content"]
|
||||
meta["description"] = content.strip() if isinstance(content, str) else ""
|
||||
else:
|
||||
meta["description"] = ""
|
||||
|
||||
return meta
|
||||
|
||||
|
||||
def extract_body_text(html: str) -> str:
|
||||
"""Extract main body text from HTML, stripping nav/footer/ads."""
|
||||
soup = BeautifulSoup(html, "html.parser")
|
||||
|
||||
# Remove non-content elements
|
||||
for tag in soup.find_all(
|
||||
["script", "style", "nav", "footer", "header", "aside", "iframe", "noscript"]
|
||||
):
|
||||
tag.decompose()
|
||||
|
||||
# Try to find article body
|
||||
article = soup.find("article")
|
||||
if not article:
|
||||
for div in soup.find_all("div"):
|
||||
cls = div.get("class", [])
|
||||
cls_str = " ".join(cls) if isinstance(cls, list) else str(cls) if cls else ""
|
||||
if any(kw in cls_str for kw in ["article-body", "post-content", "entry-content", "story-body"]):
|
||||
article = div
|
||||
break
|
||||
|
||||
if article:
|
||||
text = article.get_text(separator="\n", strip=True)
|
||||
else:
|
||||
# Fallback: use body
|
||||
body = soup.find("body")
|
||||
text = body.get_text(separator="\n", strip=True) if body else soup.get_text(separator="\n", strip=True)
|
||||
|
||||
# Collapse whitespace
|
||||
lines = [line.strip() for line in text.splitlines() if line.strip()]
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class WebScrapeAdapter(BaseAdapter):
|
||||
"""Adapter for fetching curated web pages and article URLs.
|
||||
|
||||
Config options (from source config):
|
||||
urls: List of URLs to scrape for this company
|
||||
url: Single URL to scrape (alternative to urls)
|
||||
timeout: Request timeout in seconds (default 30)
|
||||
user_agent: Custom user agent string
|
||||
follow_links: Whether to follow article links from index pages (default False)
|
||||
max_pages: Max pages to fetch per cycle (default 5)
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def source_type(self) -> str:
|
||||
return "web_scrape"
|
||||
|
||||
def bucket_name(self) -> str:
|
||||
"""Web scrape artifacts go to the news raw bucket."""
|
||||
return "stonks-raw-news"
|
||||
|
||||
async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult:
|
||||
"""Fetch HTML from curated URLs for a given ticker.
|
||||
|
||||
Supports both single URL and multi-URL configs. Each URL is fetched,
|
||||
HTML is preserved as raw payload, and metadata is extracted.
|
||||
"""
|
||||
urls = config.get("urls", [])
|
||||
if not urls and config.get("url"):
|
||||
urls = [config["url"]]
|
||||
|
||||
if not urls:
|
||||
return self._error_result(ticker, "No URLs configured for web_scrape source", 0)
|
||||
|
||||
timeout = config.get("timeout", DEFAULT_TIMEOUT)
|
||||
user_agent = config.get("user_agent", DEFAULT_USER_AGENT)
|
||||
max_pages = min(config.get("max_pages", 5), 20)
|
||||
|
||||
items: list[dict[str, Any]] = []
|
||||
all_raw: list[bytes] = []
|
||||
total_elapsed = 0.0
|
||||
errors: list[str] = []
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
timeout=timeout,
|
||||
follow_redirects=True,
|
||||
headers={"User-Agent": user_agent},
|
||||
) as client:
|
||||
for url in urls[:max_pages]:
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
resp = await client.get(url)
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
total_elapsed += elapsed_ms
|
||||
resp.raise_for_status()
|
||||
|
||||
# Content length guard
|
||||
if len(resp.content) > MAX_CONTENT_LENGTH:
|
||||
errors.append(f"Content too large for {url}: {len(resp.content)} bytes")
|
||||
continue
|
||||
|
||||
html = resp.text
|
||||
raw_bytes = resp.content
|
||||
all_raw.append(raw_bytes)
|
||||
|
||||
item_content_hash = content_hash(raw_bytes)
|
||||
meta = extract_metadata_from_html(html, url)
|
||||
body_text = extract_body_text(html)
|
||||
|
||||
item: dict[str, Any] = {
|
||||
"url": url,
|
||||
"canonical_url": meta.get("canonical_url", normalize_url(url)),
|
||||
"title": meta.get("title", ""),
|
||||
"author": meta.get("author", ""),
|
||||
"publisher": meta.get("publisher", ""),
|
||||
"published_at": meta.get("published_at"),
|
||||
"language": meta.get("language", "en"),
|
||||
"description": meta.get("description", ""),
|
||||
"content_hash": item_content_hash,
|
||||
"body_text": body_text,
|
||||
"body_length": len(body_text),
|
||||
"html_length": len(html),
|
||||
"http_status": resp.status_code,
|
||||
"response_time_ms": round(elapsed_ms, 1),
|
||||
}
|
||||
items.append(item)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
total_elapsed += elapsed_ms
|
||||
status = e.response.status_code if e.response else None
|
||||
errors.append(f"HTTP {status} for {url}: {e}")
|
||||
logger.warning("Scrape HTTP error for %s/%s: %s", ticker, url, e)
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
total_elapsed += elapsed_ms
|
||||
errors.append(f"Timeout for {url}: {e}")
|
||||
logger.warning("Scrape timeout for %s/%s: %s", ticker, url, e)
|
||||
|
||||
except Exception as e:
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
total_elapsed += elapsed_ms
|
||||
errors.append(f"Error for {url}: {e}")
|
||||
logger.warning("Scrape error for %s/%s: %s", ticker, url, e)
|
||||
|
||||
if not items:
|
||||
error_msg = "; ".join(errors) if errors else "No pages fetched"
|
||||
return self._error_result(ticker, error_msg, total_elapsed)
|
||||
|
||||
# Combine all raw payloads into a single artifact
|
||||
combined_raw = json.dumps({
|
||||
"ticker": ticker,
|
||||
"fetched_at": datetime.now(timezone.utc).isoformat(),
|
||||
"pages": [
|
||||
{
|
||||
"url": item["url"],
|
||||
"content_hash": item["content_hash"],
|
||||
"html_length": item["html_length"],
|
||||
"body_length": item["body_length"],
|
||||
}
|
||||
for item in items
|
||||
],
|
||||
"errors": errors,
|
||||
}).encode("utf-8")
|
||||
|
||||
combined_hash = content_hash(
|
||||
b"".join(item["content_hash"].encode() for item in items)
|
||||
)
|
||||
|
||||
return AdapterResult(
|
||||
source_type="web_scrape",
|
||||
ticker=ticker,
|
||||
items=items,
|
||||
raw_payload=combined_raw,
|
||||
content_hash=combined_hash,
|
||||
fetched_at=datetime.now(timezone.utc),
|
||||
http_status=200,
|
||||
response_time_ms=round(total_elapsed, 1),
|
||||
metadata={
|
||||
"provider": "web_scrape",
|
||||
"pages_fetched": len(items),
|
||||
"pages_failed": len(errors),
|
||||
"errors": errors,
|
||||
},
|
||||
)
|
||||
|
||||
def _error_result(
|
||||
self,
|
||||
ticker: str,
|
||||
error: str,
|
||||
elapsed_ms: float,
|
||||
) -> AdapterResult:
|
||||
"""Build an error AdapterResult for scrape fetches."""
|
||||
return AdapterResult(
|
||||
source_type="web_scrape",
|
||||
ticker=ticker,
|
||||
items=[],
|
||||
raw_payload=b"",
|
||||
content_hash="",
|
||||
fetched_at=datetime.now(timezone.utc),
|
||||
error=error,
|
||||
http_status=None,
|
||||
response_time_ms=round(elapsed_ms, 1),
|
||||
metadata={"provider": "web_scrape"},
|
||||
)
|
||||
@@ -0,0 +1,169 @@
|
||||
"""Contradiction detection and disagreement representation.
|
||||
|
||||
Analyses weighted signals to detect and represent disagreement explicitly,
|
||||
rather than collapsing contradictory evidence into a single unsupported
|
||||
conclusion.
|
||||
|
||||
Requirements: 6.4, 6.5
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from services.aggregation.scoring import WeightedSignal
|
||||
from services.shared.schemas import DisagreementDetail
|
||||
|
||||
|
||||
@dataclass
|
||||
class CatalystEntry:
|
||||
"""Lightweight carrier for per-document catalyst info needed by
|
||||
contradiction detection. Avoids importing ImpactRow and creating
|
||||
a circular dependency with worker.py."""
|
||||
|
||||
document_id: str
|
||||
catalyst_type: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContradictionResult:
|
||||
"""Full contradiction analysis output."""
|
||||
|
||||
score: float # 0-1, same semantics as existing compute_contradiction_score
|
||||
details: list[DisagreementDetail]
|
||||
|
||||
|
||||
def detect_contradictions(
|
||||
signals: list[WeightedSignal],
|
||||
catalyst_entries: list[CatalystEntry] | None = None,
|
||||
) -> ContradictionResult:
|
||||
"""Run contradiction detection across multiple dimensions.
|
||||
|
||||
Analyses:
|
||||
1. Sentiment disagreement — the core positive-vs-negative split
|
||||
2. Catalyst disagreement — same catalyst type with opposing sentiment
|
||||
|
||||
Returns a ContradictionResult with an overall score and per-dimension
|
||||
disagreement details.
|
||||
"""
|
||||
details: list[DisagreementDetail] = []
|
||||
|
||||
sentiment_detail = _detect_sentiment_disagreement(signals)
|
||||
if sentiment_detail is not None:
|
||||
details.append(sentiment_detail)
|
||||
|
||||
if catalyst_entries:
|
||||
catalyst_details = _detect_catalyst_disagreement(signals, catalyst_entries)
|
||||
details.extend(catalyst_details)
|
||||
|
||||
score = _compute_overall_score(signals)
|
||||
|
||||
return ContradictionResult(score=score, details=details)
|
||||
|
||||
|
||||
def _compute_overall_score(signals: list[WeightedSignal]) -> float:
|
||||
"""Minority/majority weight ratio — backward-compatible formula."""
|
||||
if not signals:
|
||||
return 0.0
|
||||
|
||||
pos_weight = 0.0
|
||||
neg_weight = 0.0
|
||||
for sig in signals:
|
||||
w = sig.weight.combined * sig.impact_score
|
||||
if sig.sentiment_value > 0:
|
||||
pos_weight += w
|
||||
elif sig.sentiment_value < 0:
|
||||
neg_weight += w
|
||||
|
||||
total = pos_weight + neg_weight
|
||||
if total == 0.0:
|
||||
return 0.0
|
||||
|
||||
minority = min(pos_weight, neg_weight)
|
||||
return round(minority / total, 4)
|
||||
|
||||
|
||||
def _detect_sentiment_disagreement(
|
||||
signals: list[WeightedSignal],
|
||||
) -> DisagreementDetail | None:
|
||||
"""Detect when both positive and negative sentiment signals exist."""
|
||||
pos_ids: list[str] = []
|
||||
neg_ids: list[str] = []
|
||||
pos_weight = 0.0
|
||||
neg_weight = 0.0
|
||||
|
||||
for sig in signals:
|
||||
w = sig.weight.combined * sig.impact_score
|
||||
if w <= 0:
|
||||
continue
|
||||
if sig.sentiment_value > 0:
|
||||
pos_ids.append(sig.document_id)
|
||||
pos_weight += w
|
||||
elif sig.sentiment_value < 0:
|
||||
neg_ids.append(sig.document_id)
|
||||
neg_weight += w
|
||||
|
||||
if not pos_ids or not neg_ids:
|
||||
return None
|
||||
|
||||
total = pos_weight + neg_weight
|
||||
minority_pct = min(pos_weight, neg_weight) / total if total > 0 else 0.0
|
||||
|
||||
return DisagreementDetail(
|
||||
dimension="sentiment",
|
||||
positive_doc_ids=pos_ids,
|
||||
negative_doc_ids=neg_ids,
|
||||
positive_weight=round(pos_weight, 4),
|
||||
negative_weight=round(neg_weight, 4),
|
||||
description=(
|
||||
f"Sentiment split: {len(pos_ids)} positive vs {len(neg_ids)} negative signals "
|
||||
f"(minority weight ratio {minority_pct:.0%})"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _detect_catalyst_disagreement(
|
||||
signals: list[WeightedSignal],
|
||||
catalyst_entries: list[CatalystEntry],
|
||||
) -> list[DisagreementDetail]:
|
||||
"""Detect when the same catalyst type has both positive and negative signals."""
|
||||
# Build lookup: document_id → (sentiment_value, combined_weight)
|
||||
sig_lookup: dict[str, tuple[float, float]] = {}
|
||||
for sig in signals:
|
||||
w = sig.weight.combined * sig.impact_score
|
||||
if w > 0:
|
||||
sig_lookup[sig.document_id] = (sig.sentiment_value, w)
|
||||
|
||||
# Group by catalyst type
|
||||
from collections import defaultdict
|
||||
catalyst_groups: dict[str, list[tuple[str, float, float]]] = defaultdict(list)
|
||||
for entry in catalyst_entries:
|
||||
if entry.document_id in sig_lookup:
|
||||
sent_val, weight = sig_lookup[entry.document_id]
|
||||
if sent_val != 0.0:
|
||||
catalyst_groups[entry.catalyst_type].append(
|
||||
(entry.document_id, sent_val, weight)
|
||||
)
|
||||
|
||||
details: list[DisagreementDetail] = []
|
||||
for catalyst, entries in catalyst_groups.items():
|
||||
pos_ids = [doc_id for doc_id, sv, _ in entries if sv > 0]
|
||||
neg_ids = [doc_id for doc_id, sv, _ in entries if sv < 0]
|
||||
if not pos_ids or not neg_ids:
|
||||
continue
|
||||
|
||||
pos_w = sum(w for _, sv, w in entries if sv > 0)
|
||||
neg_w = sum(w for _, sv, w in entries if sv < 0)
|
||||
|
||||
details.append(DisagreementDetail(
|
||||
dimension=f"catalyst:{catalyst}",
|
||||
positive_doc_ids=pos_ids,
|
||||
negative_doc_ids=neg_ids,
|
||||
positive_weight=round(pos_w, 4),
|
||||
negative_weight=round(neg_w, 4),
|
||||
description=(
|
||||
f"Catalyst '{catalyst}' has {len(pos_ids)} positive and "
|
||||
f"{len(neg_ids)} negative signals"
|
||||
),
|
||||
))
|
||||
|
||||
return details
|
||||
@@ -0,0 +1,141 @@
|
||||
"""Evidence ranking for supporting and opposing documents.
|
||||
|
||||
Ranks document signals by a composite score that considers multiple
|
||||
factors beyond raw weight, producing explainable evidence lists for
|
||||
trend summaries.
|
||||
|
||||
Requirements: 6.5
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from services.aggregation.scoring import WeightedSignal
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EvidenceRankConfig:
|
||||
"""Weights for the composite evidence ranking score."""
|
||||
|
||||
# How much the combined signal weight matters (recency * credibility * novelty * market)
|
||||
weight_factor: float = 0.40
|
||||
# How much the document's impact score matters
|
||||
impact_factor: float = 0.30
|
||||
# How much recency alone matters (favours fresh evidence in the ranking)
|
||||
recency_factor: float = 0.20
|
||||
# How much extraction confidence matters
|
||||
confidence_factor: float = 0.10
|
||||
# Maximum evidence refs per side (supporting / opposing)
|
||||
max_refs: int = 10
|
||||
|
||||
|
||||
DEFAULT_RANK_CONFIG = EvidenceRankConfig()
|
||||
|
||||
|
||||
@dataclass
|
||||
class RankedEvidence:
|
||||
"""A document with its composite ranking score and breakdown."""
|
||||
|
||||
document_id: str
|
||||
rank_score: float
|
||||
weight_component: float
|
||||
impact_component: float
|
||||
recency_component: float
|
||||
confidence_component: float
|
||||
sentiment_value: float # +1 / -1 / 0
|
||||
|
||||
|
||||
def compute_evidence_rank(
|
||||
signal: WeightedSignal,
|
||||
config: EvidenceRankConfig = DEFAULT_RANK_CONFIG,
|
||||
) -> RankedEvidence:
|
||||
"""Compute a composite ranking score for a single signal.
|
||||
|
||||
The score blends:
|
||||
- combined signal weight (captures recency decay, credibility, novelty, market ctx)
|
||||
- raw impact score
|
||||
- recency weight alone (extra boost for freshness in the ranking)
|
||||
- extraction confidence (via the credibility component of the weight)
|
||||
|
||||
All components are in [0, 1] so the composite is bounded by the sum
|
||||
of the factor weights.
|
||||
"""
|
||||
w = signal.weight
|
||||
|
||||
weight_component = w.combined * config.weight_factor
|
||||
impact_component = signal.impact_score * config.impact_factor
|
||||
recency_component = w.recency * config.recency_factor
|
||||
confidence_component = w.credibility * config.confidence_factor
|
||||
|
||||
rank_score = weight_component + impact_component + recency_component + confidence_component
|
||||
|
||||
return RankedEvidence(
|
||||
document_id=signal.document_id,
|
||||
rank_score=round(rank_score, 6),
|
||||
weight_component=round(weight_component, 6),
|
||||
impact_component=round(impact_component, 6),
|
||||
recency_component=round(recency_component, 6),
|
||||
confidence_component=round(confidence_component, 6),
|
||||
sentiment_value=signal.sentiment_value,
|
||||
)
|
||||
|
||||
|
||||
def rank_evidence(
|
||||
signals: list[WeightedSignal],
|
||||
config: EvidenceRankConfig = DEFAULT_RANK_CONFIG,
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Rank signals into top supporting and opposing document ID lists.
|
||||
|
||||
Supporting = positive sentiment, Opposing = negative sentiment.
|
||||
Neutral/mixed signals are excluded.
|
||||
|
||||
Returns (supporting_ids, opposing_ids) each capped at config.max_refs.
|
||||
"""
|
||||
supporting: list[RankedEvidence] = []
|
||||
opposing: list[RankedEvidence] = []
|
||||
|
||||
for sig in signals:
|
||||
if sig.sentiment_value == 0.0:
|
||||
continue
|
||||
ranked = compute_evidence_rank(sig, config)
|
||||
if sig.sentiment_value > 0:
|
||||
supporting.append(ranked)
|
||||
else:
|
||||
opposing.append(ranked)
|
||||
|
||||
supporting.sort(key=lambda r: r.rank_score, reverse=True)
|
||||
opposing.sort(key=lambda r: r.rank_score, reverse=True)
|
||||
|
||||
return (
|
||||
[r.document_id for r in supporting[: config.max_refs]],
|
||||
[r.document_id for r in opposing[: config.max_refs]],
|
||||
)
|
||||
|
||||
|
||||
def rank_evidence_detailed(
|
||||
signals: list[WeightedSignal],
|
||||
config: EvidenceRankConfig = DEFAULT_RANK_CONFIG,
|
||||
) -> tuple[list[RankedEvidence], list[RankedEvidence]]:
|
||||
"""Like rank_evidence but returns full RankedEvidence objects.
|
||||
|
||||
Useful when callers need the score breakdown for explainability.
|
||||
"""
|
||||
supporting: list[RankedEvidence] = []
|
||||
opposing: list[RankedEvidence] = []
|
||||
|
||||
for sig in signals:
|
||||
if sig.sentiment_value == 0.0:
|
||||
continue
|
||||
ranked = compute_evidence_rank(sig, config)
|
||||
if sig.sentiment_value > 0:
|
||||
supporting.append(ranked)
|
||||
else:
|
||||
opposing.append(ranked)
|
||||
|
||||
supporting.sort(key=lambda r: r.rank_score, reverse=True)
|
||||
opposing.sort(key=lambda r: r.rank_score, reverse=True)
|
||||
|
||||
return (
|
||||
supporting[: config.max_refs],
|
||||
opposing[: config.max_refs],
|
||||
)
|
||||
@@ -0,0 +1,57 @@
|
||||
"""Aggregation worker entrypoint - polls Redis for aggregation jobs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
|
||||
import asyncpg
|
||||
|
||||
from services.aggregation.worker import aggregate_company
|
||||
from services.shared.config import load_config
|
||||
from services.shared.logging import setup_logging
|
||||
from services.shared.redis_keys import QUEUE_AGGREGATION, queue_key
|
||||
|
||||
logger = logging.getLogger("aggregation_main")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
config = load_config()
|
||||
setup_logging("aggregation", level=config.log_level, json_output=config.json_logs)
|
||||
|
||||
pool = await asyncpg.create_pool(dsn=config.postgres.dsn, min_size=2, max_size=8)
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
redis_client = aioredis.from_url(config.redis.url)
|
||||
queue = queue_key(QUEUE_AGGREGATION)
|
||||
logger.info("Aggregation worker started, polling %s", queue)
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw = await redis_client.lpop(queue)
|
||||
if raw is None:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
payload = raw
|
||||
job = json.loads(payload)
|
||||
ticker = job.get("ticker", "")
|
||||
|
||||
logger.info("Processing aggregation job for %s", ticker)
|
||||
|
||||
try:
|
||||
summaries = await aggregate_company(pool, ticker)
|
||||
logger.info(
|
||||
"Aggregation complete for %s: %d windows",
|
||||
ticker, len(summaries),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Aggregation failed for %s", ticker)
|
||||
finally:
|
||||
await pool.close()
|
||||
await redis_client.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,150 @@
|
||||
"""Market context feature computation for aggregation windows.
|
||||
|
||||
Fetches recent market snapshots from PostgreSQL and computes context
|
||||
features (price change, volume trend, volatility) that enrich trend
|
||||
summaries and modulate signal weighting.
|
||||
|
||||
Requirements: 6.1, 6.2
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import asyncpg
|
||||
|
||||
from services.shared.schemas import MarketContext, TrendWindow
|
||||
|
||||
# Map TrendWindow values to lookback durations in days.
|
||||
WINDOW_LOOKBACK_DAYS: dict[str, int] = {
|
||||
TrendWindow.INTRADAY.value: 1,
|
||||
TrendWindow.ONE_DAY.value: 2,
|
||||
TrendWindow.SEVEN_DAY.value: 8,
|
||||
TrendWindow.THIRTY_DAY.value: 35,
|
||||
TrendWindow.NINETY_DAY.value: 95,
|
||||
}
|
||||
|
||||
|
||||
async def fetch_market_context(
|
||||
pool: asyncpg.Pool,
|
||||
ticker: str,
|
||||
window: str,
|
||||
reference_time: datetime | None = None,
|
||||
) -> MarketContext:
|
||||
"""Build a MarketContext for *ticker* over the given trend *window*.
|
||||
|
||||
Queries the ``market_snapshots`` table for recent bars and computes:
|
||||
- price_change_pct: (last_close - first_close) / first_close
|
||||
- avg_volume: mean volume across bars
|
||||
- volume_change_pct: second-half avg volume vs first-half avg volume
|
||||
- volatility: std-dev of close prices
|
||||
- latest_close / latest_bar_at
|
||||
|
||||
Returns a MarketContext with ``bars_available == 0`` when no data exists.
|
||||
"""
|
||||
if reference_time is None:
|
||||
reference_time = datetime.now(timezone.utc)
|
||||
|
||||
lookback_days = WINDOW_LOOKBACK_DAYS.get(window, 8)
|
||||
start = reference_time - timedelta(days=lookback_days)
|
||||
|
||||
rows = await pool.fetch(
|
||||
"""
|
||||
SELECT data, captured_at
|
||||
FROM market_snapshots
|
||||
WHERE ticker = $1
|
||||
AND captured_at >= $2
|
||||
AND captured_at <= $3
|
||||
ORDER BY captured_at ASC
|
||||
""",
|
||||
ticker,
|
||||
start,
|
||||
reference_time,
|
||||
)
|
||||
|
||||
if not rows:
|
||||
return MarketContext(ticker=ticker)
|
||||
|
||||
bars = _extract_bars(rows)
|
||||
if not bars:
|
||||
return MarketContext(ticker=ticker)
|
||||
|
||||
return _compute_context(ticker, bars)
|
||||
|
||||
|
||||
def _extract_bars(rows: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Extract OHLCV bar dicts from market_snapshot rows.
|
||||
|
||||
The ``data`` column is JSONB. Polygon prev-day bars store fields like
|
||||
``o``, ``h``, ``l``, ``c``, ``v``, ``t``. We normalise to a common
|
||||
dict with ``close``, ``volume``, ``captured_at``.
|
||||
"""
|
||||
bars: list[dict[str, Any]] = []
|
||||
for row in rows:
|
||||
data = row["data"]
|
||||
if isinstance(data, str):
|
||||
import json
|
||||
data = json.loads(data)
|
||||
|
||||
# Polygon-style single bar or list of bars
|
||||
items = data if isinstance(data, list) else [data]
|
||||
for item in items:
|
||||
close = item.get("c") or item.get("close")
|
||||
volume = item.get("v") or item.get("volume")
|
||||
if close is not None:
|
||||
bars.append({
|
||||
"close": float(close),
|
||||
"volume": float(volume) if volume is not None else 0.0,
|
||||
"captured_at": row["captured_at"],
|
||||
})
|
||||
return bars
|
||||
|
||||
|
||||
def _compute_context(ticker: str, bars: list[dict[str, Any]]) -> MarketContext:
|
||||
"""Derive market context features from a sorted list of bar dicts."""
|
||||
closes = [b["close"] for b in bars]
|
||||
volumes = [b["volume"] for b in bars]
|
||||
|
||||
first_close = closes[0]
|
||||
last_close = closes[-1]
|
||||
|
||||
price_change_pct = (
|
||||
((last_close - first_close) / first_close * 100.0)
|
||||
if first_close != 0
|
||||
else 0.0
|
||||
)
|
||||
|
||||
avg_volume = sum(volumes) / len(volumes) if volumes else 0.0
|
||||
|
||||
# Volume trend: compare second half to first half
|
||||
mid = len(volumes) // 2
|
||||
if mid > 0:
|
||||
first_half_avg = sum(volumes[:mid]) / mid
|
||||
second_half_avg = sum(volumes[mid:]) / len(volumes[mid:])
|
||||
volume_change_pct = (
|
||||
((second_half_avg - first_half_avg) / first_half_avg * 100.0)
|
||||
if first_half_avg > 0
|
||||
else 0.0
|
||||
)
|
||||
else:
|
||||
volume_change_pct = 0.0
|
||||
|
||||
# Volatility: std dev of closes
|
||||
if len(closes) > 1:
|
||||
mean_close = sum(closes) / len(closes)
|
||||
variance = sum((c - mean_close) ** 2 for c in closes) / len(closes)
|
||||
volatility = math.sqrt(variance)
|
||||
else:
|
||||
volatility = 0.0
|
||||
|
||||
return MarketContext(
|
||||
ticker=ticker,
|
||||
price_change_pct=round(price_change_pct, 4),
|
||||
avg_volume=round(avg_volume, 2),
|
||||
volume_change_pct=round(volume_change_pct, 4),
|
||||
volatility=round(volatility, 6),
|
||||
latest_close=last_close,
|
||||
latest_bar_at=bars[-1]["captured_at"],
|
||||
bars_available=len(bars),
|
||||
)
|
||||
@@ -0,0 +1,439 @@
|
||||
"""Sector and market-level rollup aggregation.
|
||||
|
||||
Aggregates company-level trend summaries into sector and market-level
|
||||
summaries, enabling top-down views of sentiment and risk across the
|
||||
portfolio.
|
||||
|
||||
Requirements: 6.3, 6.4, 6.5
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import asyncpg
|
||||
|
||||
from services.shared.schemas import (
|
||||
DisagreementDetail,
|
||||
TrendDirection,
|
||||
TrendSummary,
|
||||
TrendWindow,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompanyTrendRow:
|
||||
"""A company-level trend summary fetched from the DB for rollup."""
|
||||
|
||||
entity_id: str # ticker
|
||||
sector: str
|
||||
window: str
|
||||
trend_direction: str
|
||||
trend_strength: float
|
||||
confidence: float
|
||||
contradiction_score: float
|
||||
dominant_catalysts: list[str]
|
||||
material_risks: list[str]
|
||||
top_supporting_evidence: list[str]
|
||||
top_opposing_evidence: list[str]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fetch latest company trends for a given window
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_LATEST_COMPANY_TRENDS_QUERY = """
|
||||
SELECT DISTINCT ON (tw.entity_id)
|
||||
tw.entity_id,
|
||||
c.sector,
|
||||
tw.window,
|
||||
tw.trend_direction,
|
||||
tw.trend_strength,
|
||||
tw.confidence,
|
||||
tw.contradiction_score,
|
||||
tw.dominant_catalysts,
|
||||
tw.material_risks,
|
||||
tw.top_supporting_evidence,
|
||||
tw.top_opposing_evidence
|
||||
FROM trend_windows tw
|
||||
JOIN companies c ON c.ticker = tw.entity_id AND c.active = TRUE
|
||||
WHERE tw.entity_type = 'company'
|
||||
AND tw.window = $1
|
||||
AND tw.generated_at >= $2
|
||||
ORDER BY tw.entity_id, tw.generated_at DESC
|
||||
"""
|
||||
|
||||
|
||||
def _parse_jsonb_list(val: object) -> list[str]:
|
||||
"""Safely parse a JSONB column that should be a list of strings."""
|
||||
if isinstance(val, list):
|
||||
return [str(v) for v in val]
|
||||
if isinstance(val, str):
|
||||
parsed = json.loads(val)
|
||||
if isinstance(parsed, list):
|
||||
return [str(v) for v in parsed]
|
||||
return []
|
||||
|
||||
|
||||
def _parse_company_trend_row(row: object) -> CompanyTrendRow:
|
||||
"""Convert an asyncpg Record to a CompanyTrendRow."""
|
||||
# asyncpg Records support dict() but aren't typed; use getattr-style access
|
||||
get = getattr(row, "__getitem__", None)
|
||||
if get is None:
|
||||
raise TypeError(f"Expected a mapping-like row, got {type(row)}")
|
||||
|
||||
def _str(key: str, default: str = "") -> str:
|
||||
val = get(key)
|
||||
return str(val) if val is not None else default
|
||||
|
||||
def _float(key: str) -> float:
|
||||
val = get(key)
|
||||
return float(val) if val is not None else 0.0
|
||||
|
||||
return CompanyTrendRow(
|
||||
entity_id=_str("entity_id"),
|
||||
sector=_str("sector", "Unknown") or "Unknown",
|
||||
window=_str("window"),
|
||||
trend_direction=_str("trend_direction"),
|
||||
trend_strength=_float("trend_strength"),
|
||||
confidence=_float("confidence"),
|
||||
contradiction_score=_float("contradiction_score"),
|
||||
dominant_catalysts=_parse_jsonb_list(get("dominant_catalysts")),
|
||||
material_risks=_parse_jsonb_list(get("material_risks")),
|
||||
top_supporting_evidence=_parse_jsonb_list(get("top_supporting_evidence")),
|
||||
top_opposing_evidence=_parse_jsonb_list(get("top_opposing_evidence")),
|
||||
)
|
||||
|
||||
|
||||
async def fetch_latest_company_trends(
|
||||
pool: asyncpg.Pool,
|
||||
window: str,
|
||||
since: datetime,
|
||||
) -> list[CompanyTrendRow]:
|
||||
"""Fetch the most recent company-level trend for each ticker in a window."""
|
||||
rows = await pool.fetch(_LATEST_COMPANY_TRENDS_QUERY, window, since)
|
||||
return [_parse_company_trend_row(r) for r in rows]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure rollup logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Direction mapping for numeric aggregation
|
||||
_DIRECTION_VALUES = {
|
||||
TrendDirection.BULLISH.value: 1.0,
|
||||
TrendDirection.BEARISH.value: -1.0,
|
||||
TrendDirection.MIXED.value: 0.0,
|
||||
TrendDirection.NEUTRAL.value: 0.0,
|
||||
}
|
||||
|
||||
BULLISH_THRESHOLD = 0.15
|
||||
BEARISH_THRESHOLD = -0.15
|
||||
|
||||
|
||||
def rollup_trends(
|
||||
trends: list[CompanyTrendRow],
|
||||
entity_type: str,
|
||||
entity_id: str,
|
||||
window: str,
|
||||
reference_time: datetime,
|
||||
) -> TrendSummary:
|
||||
"""Aggregate a list of company-level trends into a single rollup summary.
|
||||
|
||||
Each company trend is weighted by its confidence to produce a
|
||||
confidence-weighted average of direction, strength, and contradiction.
|
||||
"""
|
||||
if not trends:
|
||||
return TrendSummary(
|
||||
entity_type=entity_type,
|
||||
entity_id=entity_id,
|
||||
window=TrendWindow(window),
|
||||
trend_direction=TrendDirection.NEUTRAL,
|
||||
trend_strength=0.0,
|
||||
confidence=0.0,
|
||||
generated_at=reference_time,
|
||||
)
|
||||
|
||||
total_weight = 0.0
|
||||
weighted_direction = 0.0
|
||||
weighted_strength = 0.0
|
||||
weighted_contradiction = 0.0
|
||||
catalyst_weights: dict[str, float] = {}
|
||||
risk_set: dict[str, float] = {}
|
||||
all_supporting: list[str] = []
|
||||
all_opposing: list[str] = []
|
||||
|
||||
for t in trends:
|
||||
w = t.confidence
|
||||
total_weight += w
|
||||
dir_val = _DIRECTION_VALUES.get(t.trend_direction, 0.0)
|
||||
weighted_direction += w * dir_val
|
||||
weighted_strength += w * t.trend_strength
|
||||
weighted_contradiction += w * t.contradiction_score
|
||||
|
||||
for cat in t.dominant_catalysts:
|
||||
catalyst_weights[cat] = catalyst_weights.get(cat, 0.0) + w
|
||||
|
||||
for risk in t.material_risks:
|
||||
norm = risk.strip().lower()
|
||||
if norm not in risk_set:
|
||||
risk_set[norm] = w
|
||||
else:
|
||||
risk_set[norm] = max(risk_set[norm], w)
|
||||
|
||||
all_supporting.extend(t.top_supporting_evidence)
|
||||
all_opposing.extend(t.top_opposing_evidence)
|
||||
|
||||
if total_weight == 0.0:
|
||||
return TrendSummary(
|
||||
entity_type=entity_type,
|
||||
entity_id=entity_id,
|
||||
window=TrendWindow(window),
|
||||
trend_direction=TrendDirection.NEUTRAL,
|
||||
trend_strength=0.0,
|
||||
confidence=0.0,
|
||||
generated_at=reference_time,
|
||||
)
|
||||
|
||||
avg_direction = weighted_direction / total_weight
|
||||
avg_strength = weighted_strength / total_weight
|
||||
avg_contradiction = weighted_contradiction / total_weight
|
||||
avg_confidence = total_weight / len(trends)
|
||||
|
||||
# Derive direction
|
||||
direction = _derive_rollup_direction(avg_direction, avg_contradiction)
|
||||
|
||||
# Top catalysts
|
||||
sorted_catalysts = sorted(catalyst_weights.items(), key=lambda x: x[1], reverse=True)
|
||||
catalysts = [c for c, _ in sorted_catalysts[:5]]
|
||||
|
||||
# Top risks (deduplicated, by weight)
|
||||
sorted_risks = sorted(risk_set.items(), key=lambda x: x[1], reverse=True)
|
||||
risks = [r for r, _ in sorted_risks[:5]]
|
||||
|
||||
# Disagreement details
|
||||
disagreement = _build_rollup_disagreement(trends, entity_id)
|
||||
|
||||
return TrendSummary(
|
||||
entity_type=entity_type,
|
||||
entity_id=entity_id,
|
||||
window=TrendWindow(window),
|
||||
trend_direction=direction,
|
||||
trend_strength=round(min(abs(avg_strength), 1.0), 4),
|
||||
confidence=round(max(0.0, min(avg_confidence, 1.0)), 4),
|
||||
top_supporting_evidence=list(dict.fromkeys(all_supporting))[:10],
|
||||
top_opposing_evidence=list(dict.fromkeys(all_opposing))[:10],
|
||||
dominant_catalysts=catalysts,
|
||||
material_risks=risks,
|
||||
contradiction_score=round(max(0.0, min(avg_contradiction, 1.0)), 4),
|
||||
disagreement_details=disagreement,
|
||||
generated_at=reference_time,
|
||||
)
|
||||
|
||||
|
||||
def _derive_rollup_direction(
|
||||
avg_direction: float,
|
||||
avg_contradiction: float,
|
||||
) -> TrendDirection:
|
||||
"""Map averaged direction value to a TrendDirection."""
|
||||
if avg_contradiction > 0.10 and abs(avg_direction) < 0.3:
|
||||
return TrendDirection.MIXED
|
||||
if avg_direction >= BULLISH_THRESHOLD:
|
||||
return TrendDirection.BULLISH
|
||||
if avg_direction <= BEARISH_THRESHOLD:
|
||||
return TrendDirection.BEARISH
|
||||
return TrendDirection.NEUTRAL
|
||||
|
||||
|
||||
def _build_rollup_disagreement(
|
||||
trends: list[CompanyTrendRow],
|
||||
entity_id: str,
|
||||
) -> list[DisagreementDetail]:
|
||||
"""Build disagreement details showing which companies are bullish vs bearish."""
|
||||
bullish_ids: list[str] = []
|
||||
bearish_ids: list[str] = []
|
||||
bullish_weight = 0.0
|
||||
bearish_weight = 0.0
|
||||
|
||||
for t in trends:
|
||||
if t.trend_direction == TrendDirection.BULLISH.value:
|
||||
bullish_ids.append(t.entity_id)
|
||||
bullish_weight += t.confidence
|
||||
elif t.trend_direction == TrendDirection.BEARISH.value:
|
||||
bearish_ids.append(t.entity_id)
|
||||
bearish_weight += t.confidence
|
||||
|
||||
if not bullish_ids or not bearish_ids:
|
||||
return []
|
||||
|
||||
return [
|
||||
DisagreementDetail(
|
||||
dimension="company_direction",
|
||||
positive_doc_ids=bullish_ids,
|
||||
negative_doc_ids=bearish_ids,
|
||||
positive_weight=round(bullish_weight, 4),
|
||||
negative_weight=round(bearish_weight, 4),
|
||||
description=(
|
||||
f"{entity_id}: {len(bullish_ids)} bullish vs "
|
||||
f"{len(bearish_ids)} bearish companies"
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Persist rollup (reuses the same trend_windows table)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_UPSERT_TREND = """
|
||||
INSERT INTO trend_windows (
|
||||
entity_type, entity_id, window, trend_direction, trend_strength,
|
||||
confidence, top_supporting_evidence, top_opposing_evidence,
|
||||
dominant_catalysts, material_risks, contradiction_score,
|
||||
disagreement_details, market_context, generated_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5,
|
||||
$6, $7::jsonb, $8::jsonb,
|
||||
$9::jsonb, $10::jsonb, $11,
|
||||
$12::jsonb, $13::jsonb, $14
|
||||
)
|
||||
RETURNING id
|
||||
"""
|
||||
|
||||
|
||||
async def persist_rollup(
|
||||
pool: asyncpg.Pool,
|
||||
summary: TrendSummary,
|
||||
) -> str:
|
||||
"""Insert a rollup trend summary and return its UUID."""
|
||||
row = await pool.fetchrow(
|
||||
_UPSERT_TREND,
|
||||
summary.entity_type,
|
||||
summary.entity_id,
|
||||
summary.window.value,
|
||||
summary.trend_direction.value,
|
||||
summary.trend_strength,
|
||||
summary.confidence,
|
||||
json.dumps(summary.top_supporting_evidence),
|
||||
json.dumps(summary.top_opposing_evidence),
|
||||
json.dumps(summary.dominant_catalysts),
|
||||
json.dumps(summary.material_risks),
|
||||
summary.contradiction_score,
|
||||
json.dumps([d.model_dump() for d in summary.disagreement_details]),
|
||||
json.dumps({}),
|
||||
summary.generated_at,
|
||||
)
|
||||
return str(row["id"]) # type: ignore[index]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# High-level rollup entry points
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def aggregate_sector(
|
||||
pool: asyncpg.Pool,
|
||||
sector: str,
|
||||
window: str,
|
||||
reference_time: datetime | None = None,
|
||||
since: datetime | None = None,
|
||||
) -> TrendSummary:
|
||||
"""Compute and persist a sector-level rollup for one window.
|
||||
|
||||
Fetches the latest company trends, filters to the given sector,
|
||||
and rolls them up into a single sector summary.
|
||||
"""
|
||||
if reference_time is None:
|
||||
reference_time = datetime.now(timezone.utc)
|
||||
if since is None:
|
||||
since = reference_time - _window_lookback(window)
|
||||
|
||||
all_trends = await fetch_latest_company_trends(pool, window, since)
|
||||
sector_trends = [t for t in all_trends if t.sector == sector]
|
||||
|
||||
summary = rollup_trends(sector_trends, "sector", sector, window, reference_time)
|
||||
|
||||
if sector_trends:
|
||||
rollup_id = await persist_rollup(pool, summary)
|
||||
logger.info(
|
||||
"Persisted sector rollup %s for %s/%s: direction=%s strength=%.3f companies=%d",
|
||||
rollup_id, sector, window, summary.trend_direction.value,
|
||||
summary.trend_strength, len(sector_trends),
|
||||
)
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
async def aggregate_market(
|
||||
pool: asyncpg.Pool,
|
||||
window: str,
|
||||
reference_time: datetime | None = None,
|
||||
since: datetime | None = None,
|
||||
) -> TrendSummary:
|
||||
"""Compute and persist a market-wide rollup for one window.
|
||||
|
||||
Aggregates all company trends regardless of sector.
|
||||
"""
|
||||
if reference_time is None:
|
||||
reference_time = datetime.now(timezone.utc)
|
||||
if since is None:
|
||||
since = reference_time - _window_lookback(window)
|
||||
|
||||
all_trends = await fetch_latest_company_trends(pool, window, since)
|
||||
|
||||
summary = rollup_trends(all_trends, "market", "all", window, reference_time)
|
||||
|
||||
if all_trends:
|
||||
rollup_id = await persist_rollup(pool, summary)
|
||||
logger.info(
|
||||
"Persisted market rollup %s for %s: direction=%s strength=%.3f companies=%d",
|
||||
rollup_id, window, summary.trend_direction.value,
|
||||
summary.trend_strength, len(all_trends),
|
||||
)
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
async def aggregate_all_sectors(
|
||||
pool: asyncpg.Pool,
|
||||
window: str,
|
||||
reference_time: datetime | None = None,
|
||||
since: datetime | None = None,
|
||||
) -> list[TrendSummary]:
|
||||
"""Compute sector rollups for every sector that has company trends."""
|
||||
if reference_time is None:
|
||||
reference_time = datetime.now(timezone.utc)
|
||||
if since is None:
|
||||
since = reference_time - _window_lookback(window)
|
||||
|
||||
all_trends = await fetch_latest_company_trends(pool, window, since)
|
||||
|
||||
# Group by sector
|
||||
sectors: dict[str, list[CompanyTrendRow]] = {}
|
||||
for t in all_trends:
|
||||
sectors.setdefault(t.sector, []).append(t)
|
||||
|
||||
summaries: list[TrendSummary] = []
|
||||
for sector, trends in sectors.items():
|
||||
summary = rollup_trends(trends, "sector", sector, window, reference_time)
|
||||
if trends:
|
||||
_id = await persist_rollup(pool, summary)
|
||||
summaries.append(summary)
|
||||
|
||||
return summaries
|
||||
|
||||
|
||||
def _window_lookback(window: str) -> timedelta:
|
||||
"""Return a reasonable lookback for finding recent company trends."""
|
||||
mapping = {
|
||||
TrendWindow.INTRADAY.value: timedelta(hours=24),
|
||||
TrendWindow.ONE_DAY.value: timedelta(days=2),
|
||||
TrendWindow.SEVEN_DAY.value: timedelta(days=8),
|
||||
TrendWindow.THIRTY_DAY.value: timedelta(days=35),
|
||||
TrendWindow.NINETY_DAY.value: timedelta(days=95),
|
||||
}
|
||||
return mapping.get(window, timedelta(days=8))
|
||||
@@ -0,0 +1,285 @@
|
||||
"""Recency decay, source credibility weighting, and market context
|
||||
integration for aggregation.
|
||||
|
||||
Provides scoring functions used by the aggregation engine to weight
|
||||
document intelligence signals when computing trend summaries.
|
||||
|
||||
Requirements: 6.1, 6.2, 6.5
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from services.shared.schemas import MarketContext
|
||||
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScoringConfig:
|
||||
"""Tunable parameters for signal scoring."""
|
||||
|
||||
# Recency decay: exponential half-life in hours per window.
|
||||
# After one half-life, a document's recency weight drops to 0.5.
|
||||
half_life_hours: dict[str, float] = field(default_factory=lambda: {
|
||||
"intraday": 2.0,
|
||||
"1d": 12.0,
|
||||
"7d": 72.0,
|
||||
"30d": 240.0,
|
||||
"90d": 720.0,
|
||||
})
|
||||
|
||||
# Minimum recency weight — prevents very old docs from being zeroed out
|
||||
# entirely so they can still contribute trace-level signal.
|
||||
min_recency_weight: float = 0.01
|
||||
|
||||
# Source credibility bounds — credibility scores outside this range
|
||||
# are clamped before weighting.
|
||||
credibility_floor: float = 0.1
|
||||
credibility_ceiling: float = 1.0
|
||||
|
||||
# Exponent applied to credibility score. >1 penalises low-credibility
|
||||
# sources more aggressively; <1 flattens the curve.
|
||||
credibility_exponent: float = 1.0
|
||||
|
||||
# Novelty bonus: multiplier range applied on top of base weight.
|
||||
# A novelty_score of 1.0 gets the full bonus; 0.0 gets none.
|
||||
novelty_bonus_max: float = 0.25
|
||||
|
||||
# Confidence floor — documents below this extraction confidence
|
||||
# receive zero weight (they are too unreliable to aggregate).
|
||||
confidence_floor: float = 0.2
|
||||
|
||||
# Market context modulation ---
|
||||
# When volatility exceeds this threshold (in price units), recency
|
||||
# signals are amplified because fast-moving markets make fresh data
|
||||
# more important.
|
||||
volatility_recency_boost_threshold: float = 1.0
|
||||
volatility_recency_boost_max: float = 0.30 # max extra multiplier
|
||||
|
||||
# When volume surges above this % change, signals get a small boost
|
||||
# because high-volume moves carry more conviction.
|
||||
volume_surge_threshold_pct: float = 50.0
|
||||
volume_surge_boost: float = 0.15
|
||||
|
||||
|
||||
# Singleton default config
|
||||
DEFAULT_CONFIG = ScoringConfig()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Recency decay
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def recency_weight(
|
||||
published_at: datetime,
|
||||
reference_time: datetime,
|
||||
window: str,
|
||||
config: ScoringConfig = DEFAULT_CONFIG,
|
||||
) -> float:
|
||||
"""Compute an exponential recency decay weight for a document.
|
||||
|
||||
Uses the formula: w = 2^(-age_hours / half_life)
|
||||
|
||||
Args:
|
||||
published_at: When the document was published (tz-aware).
|
||||
reference_time: The "now" anchor for the aggregation window (tz-aware).
|
||||
window: One of the TrendWindow values (e.g. "7d").
|
||||
config: Scoring parameters.
|
||||
|
||||
Returns:
|
||||
A weight in [config.min_recency_weight, 1.0].
|
||||
"""
|
||||
# Ensure both are tz-aware; treat naive as UTC.
|
||||
if published_at.tzinfo is None:
|
||||
published_at = published_at.replace(tzinfo=timezone.utc)
|
||||
if reference_time.tzinfo is None:
|
||||
reference_time = reference_time.replace(tzinfo=timezone.utc)
|
||||
|
||||
age_seconds = (reference_time - published_at).total_seconds()
|
||||
if age_seconds <= 0:
|
||||
return 1.0
|
||||
|
||||
age_hours = age_seconds / 3600.0
|
||||
half_life = config.half_life_hours.get(window, 72.0)
|
||||
|
||||
weight = math.pow(2.0, -age_hours / half_life)
|
||||
return max(weight, config.min_recency_weight)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Source credibility weighting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def credibility_weight(
|
||||
source_credibility: float,
|
||||
config: ScoringConfig = DEFAULT_CONFIG,
|
||||
) -> float:
|
||||
"""Compute a weight from a source's credibility score.
|
||||
|
||||
The raw credibility (0-1) is clamped to [floor, ceiling] then raised
|
||||
to ``credibility_exponent``.
|
||||
|
||||
Args:
|
||||
source_credibility: The credibility score from the source or
|
||||
document intelligence record (0-1).
|
||||
config: Scoring parameters.
|
||||
|
||||
Returns:
|
||||
A weight in [floor^exp, ceiling^exp].
|
||||
"""
|
||||
clamped = max(config.credibility_floor, min(source_credibility, config.credibility_ceiling))
|
||||
return math.pow(clamped, config.credibility_exponent)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Market context adjustment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def market_context_multiplier(
|
||||
market_ctx: MarketContext | None,
|
||||
config: ScoringConfig = DEFAULT_CONFIG,
|
||||
) -> float:
|
||||
"""Compute a multiplicative adjustment from market context features.
|
||||
|
||||
Returns a value >= 1.0 that amplifies signal weights when market
|
||||
conditions suggest heightened importance (high volatility or volume
|
||||
surges). Returns 1.0 when no market context is available.
|
||||
"""
|
||||
if market_ctx is None or not market_ctx.has_data:
|
||||
return 1.0
|
||||
|
||||
boost = 0.0
|
||||
|
||||
# Volatility boost — more volatile markets make recent signals more valuable
|
||||
if market_ctx.volatility is not None and market_ctx.volatility > config.volatility_recency_boost_threshold:
|
||||
excess = market_ctx.volatility - config.volatility_recency_boost_threshold
|
||||
# Logarithmic scaling so extreme volatility doesn't blow up the weight
|
||||
boost += min(
|
||||
math.log1p(excess) * 0.15,
|
||||
config.volatility_recency_boost_max,
|
||||
)
|
||||
|
||||
# Volume surge boost
|
||||
if market_ctx.volume_change_pct is not None and market_ctx.volume_change_pct > config.volume_surge_threshold_pct:
|
||||
boost += config.volume_surge_boost
|
||||
|
||||
return 1.0 + boost
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Combined document signal weight
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class SignalWeight:
|
||||
"""Breakdown of a document's aggregation weight."""
|
||||
|
||||
recency: float
|
||||
credibility: float
|
||||
novelty_bonus: float
|
||||
confidence_gate: float # 0.0 or 1.0
|
||||
market_ctx_multiplier: float # >= 1.0
|
||||
combined: float
|
||||
|
||||
|
||||
def compute_signal_weight(
|
||||
published_at: datetime,
|
||||
reference_time: datetime,
|
||||
window: str,
|
||||
source_credibility: float,
|
||||
novelty_score: float = 0.5,
|
||||
extraction_confidence: float = 0.5,
|
||||
market_ctx: MarketContext | None = None,
|
||||
config: ScoringConfig = DEFAULT_CONFIG,
|
||||
) -> SignalWeight:
|
||||
"""Compute the combined aggregation weight for a single document signal.
|
||||
|
||||
The formula is:
|
||||
combined = confidence_gate * recency * credibility
|
||||
* (1 + novelty_bonus) * market_ctx_multiplier
|
||||
|
||||
where novelty_bonus = novelty_score * config.novelty_bonus_max
|
||||
and market_ctx_multiplier >= 1.0 based on volatility/volume features.
|
||||
|
||||
Documents with extraction_confidence below config.confidence_floor
|
||||
receive a combined weight of 0.0 (gated out).
|
||||
|
||||
Args:
|
||||
published_at: Document publication time.
|
||||
reference_time: Aggregation anchor time.
|
||||
window: Trend window identifier.
|
||||
source_credibility: Source credibility score (0-1).
|
||||
novelty_score: Document novelty score (0-1).
|
||||
extraction_confidence: Extraction confidence from the model (0-1).
|
||||
market_ctx: Optional market context features for the symbol.
|
||||
config: Scoring parameters.
|
||||
|
||||
Returns:
|
||||
A ``SignalWeight`` with the component breakdown and combined score.
|
||||
"""
|
||||
# Confidence gate
|
||||
gate = 1.0 if extraction_confidence >= config.confidence_floor else 0.0
|
||||
|
||||
rec = recency_weight(published_at, reference_time, window, config)
|
||||
cred = credibility_weight(source_credibility, config)
|
||||
bonus = novelty_score * config.novelty_bonus_max
|
||||
mkt_mult = market_context_multiplier(market_ctx, config)
|
||||
|
||||
combined = gate * rec * cred * (1.0 + bonus) * mkt_mult
|
||||
|
||||
return SignalWeight(
|
||||
recency=rec,
|
||||
credibility=cred,
|
||||
novelty_bonus=bonus,
|
||||
confidence_gate=gate,
|
||||
market_ctx_multiplier=mkt_mult,
|
||||
combined=combined,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Batch helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class WeightedSignal:
|
||||
"""A document intelligence reference paired with its computed weight."""
|
||||
|
||||
document_id: str
|
||||
weight: SignalWeight
|
||||
sentiment_value: float # numeric sentiment: +1 positive, -1 negative, 0 neutral/mixed
|
||||
impact_score: float
|
||||
|
||||
|
||||
def sentiment_to_numeric(sentiment: str) -> float:
|
||||
"""Map a sentiment label to a signed numeric value."""
|
||||
mapping = {
|
||||
"positive": 1.0,
|
||||
"negative": -1.0,
|
||||
"neutral": 0.0,
|
||||
"mixed": 0.0,
|
||||
}
|
||||
return mapping.get(sentiment.lower(), 0.0)
|
||||
|
||||
|
||||
def weighted_sentiment_average(signals: list[WeightedSignal]) -> float:
|
||||
"""Compute a weight-adjusted average sentiment across signals.
|
||||
|
||||
Returns a value in [-1, 1]. Returns 0.0 when total weight is zero.
|
||||
"""
|
||||
total_weight = 0.0
|
||||
weighted_sum = 0.0
|
||||
for sig in signals:
|
||||
w = sig.weight.combined * sig.impact_score
|
||||
weighted_sum += w * sig.sentiment_value
|
||||
total_weight += w
|
||||
if total_weight == 0.0:
|
||||
return 0.0
|
||||
return weighted_sum / total_weight
|
||||
@@ -1 +1,650 @@
|
||||
"""Aggregation worker - rolling trend summaries, contradiction detection, evidence ranking."""
|
||||
"""Aggregation worker - company-level rolling window trend summaries.
|
||||
|
||||
Queries document intelligence and market context for a given ticker,
|
||||
computes weighted signal scores, and produces TrendSummary objects
|
||||
persisted to the trend_windows table.
|
||||
|
||||
Requirements: 6.1, 6.2, 6.5
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import asyncpg
|
||||
|
||||
from services.aggregation.contradiction import CatalystEntry, detect_contradictions
|
||||
from services.aggregation.evidence import (
|
||||
EvidenceRankConfig,
|
||||
RankedEvidence,
|
||||
rank_evidence as _rank_evidence_composite,
|
||||
rank_evidence_detailed,
|
||||
)
|
||||
from services.aggregation.market_context import fetch_market_context
|
||||
from services.aggregation.scoring import (
|
||||
ScoringConfig,
|
||||
WeightedSignal,
|
||||
compute_signal_weight,
|
||||
sentiment_to_numeric,
|
||||
weighted_sentiment_average,
|
||||
)
|
||||
from services.shared.schemas import TrendDirection, TrendSummary, TrendWindow
|
||||
from services.shared.metrics import (
|
||||
AGGREGATION_CONTRADICTION_SCORE,
|
||||
AGGREGATION_DURATION,
|
||||
AGGREGATION_SIGNALS_PROCESSED,
|
||||
AGGREGATION_WINDOWS_COMPUTED,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Map TrendWindow values to lookback durations.
|
||||
WINDOW_DURATIONS: dict[str, timedelta] = {
|
||||
TrendWindow.INTRADAY.value: timedelta(hours=12),
|
||||
TrendWindow.ONE_DAY.value: timedelta(days=1),
|
||||
TrendWindow.SEVEN_DAY.value: timedelta(days=7),
|
||||
TrendWindow.THIRTY_DAY.value: timedelta(days=30),
|
||||
TrendWindow.NINETY_DAY.value: timedelta(days=90),
|
||||
}
|
||||
|
||||
# How many evidence document IDs to keep in supporting/opposing lists.
|
||||
MAX_EVIDENCE_REFS = 10
|
||||
|
||||
|
||||
@dataclass
|
||||
class AggregationConfig:
|
||||
"""Controls which windows to compute and scoring parameters."""
|
||||
|
||||
windows: list[str] | None = None # None = all windows
|
||||
scoring: ScoringConfig | None = None
|
||||
max_evidence: int = MAX_EVIDENCE_REFS
|
||||
|
||||
def effective_windows(self) -> list[str]:
|
||||
if self.windows:
|
||||
return self.windows
|
||||
return [w.value for w in TrendWindow]
|
||||
|
||||
def effective_scoring(self) -> ScoringConfig:
|
||||
return self.scoring or ScoringConfig()
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fetch impact records for a ticker within a time window
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_IMPACT_QUERY = """
|
||||
SELECT
|
||||
di.document_id,
|
||||
di.confidence,
|
||||
di.novelty_score,
|
||||
di.source_credibility,
|
||||
dir.sentiment,
|
||||
dir.impact_score,
|
||||
dir.catalyst_type,
|
||||
dir.key_facts,
|
||||
dir.risks,
|
||||
d.published_at
|
||||
FROM document_impact_records dir
|
||||
JOIN document_intelligence di ON di.id = dir.intelligence_id
|
||||
JOIN documents d ON d.id = di.document_id
|
||||
WHERE dir.ticker = $1
|
||||
AND d.published_at >= $2
|
||||
AND d.published_at <= $3
|
||||
AND di.validation_status = 'valid'
|
||||
AND d.status != 'rejected'
|
||||
ORDER BY d.published_at DESC
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImpactRow:
|
||||
"""Parsed row from the impact query."""
|
||||
|
||||
document_id: str
|
||||
confidence: float
|
||||
novelty_score: float
|
||||
source_credibility: float
|
||||
sentiment: str
|
||||
impact_score: float
|
||||
catalyst_type: str
|
||||
key_facts: list[str]
|
||||
risks: list[str]
|
||||
published_at: datetime
|
||||
|
||||
|
||||
def _parse_impact_row(row: Any) -> ImpactRow:
|
||||
"""Convert an asyncpg Record to an ImpactRow."""
|
||||
key_facts = row["key_facts"]
|
||||
if isinstance(key_facts, str):
|
||||
key_facts = json.loads(key_facts)
|
||||
risks = row["risks"]
|
||||
if isinstance(risks, str):
|
||||
risks = json.loads(risks)
|
||||
|
||||
return ImpactRow(
|
||||
document_id=str(row["document_id"]),
|
||||
confidence=float(row["confidence"] or 0.5),
|
||||
novelty_score=float(row["novelty_score"] or 0.5),
|
||||
source_credibility=float(row["source_credibility"] or 0.5),
|
||||
sentiment=row["sentiment"] or "neutral",
|
||||
impact_score=float(row["impact_score"] or 0.0),
|
||||
catalyst_type=row["catalyst_type"] or "other",
|
||||
key_facts=key_facts if isinstance(key_facts, list) else [],
|
||||
risks=risks if isinstance(risks, list) else [],
|
||||
published_at=row["published_at"],
|
||||
)
|
||||
|
||||
|
||||
async def fetch_impact_records(
|
||||
pool: asyncpg.Pool,
|
||||
ticker: str,
|
||||
window_start: datetime,
|
||||
window_end: datetime,
|
||||
) -> list[ImpactRow]:
|
||||
"""Fetch validated document impact records for a ticker in a time range."""
|
||||
rows = await pool.fetch(_IMPACT_QUERY, ticker, window_start, window_end)
|
||||
return [_parse_impact_row(r) for r in rows]
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build weighted signals from impact records
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def build_weighted_signals(
|
||||
impacts: list[ImpactRow],
|
||||
reference_time: datetime,
|
||||
window: str,
|
||||
market_ctx: Any | None = None,
|
||||
config: ScoringConfig | None = None,
|
||||
) -> list[WeightedSignal]:
|
||||
"""Convert impact records into WeightedSignal objects using the scoring module."""
|
||||
cfg = config or ScoringConfig()
|
||||
signals: list[WeightedSignal] = []
|
||||
for imp in impacts:
|
||||
sw = compute_signal_weight(
|
||||
published_at=imp.published_at,
|
||||
reference_time=reference_time,
|
||||
window=window,
|
||||
source_credibility=imp.source_credibility,
|
||||
novelty_score=imp.novelty_score,
|
||||
extraction_confidence=imp.confidence,
|
||||
market_ctx=market_ctx,
|
||||
config=cfg,
|
||||
)
|
||||
signals.append(
|
||||
WeightedSignal(
|
||||
document_id=imp.document_id,
|
||||
weight=sw,
|
||||
sentiment_value=sentiment_to_numeric(imp.sentiment),
|
||||
impact_score=imp.impact_score,
|
||||
)
|
||||
)
|
||||
return signals
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Derive trend direction from weighted sentiment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Thresholds for mapping numeric sentiment to direction.
|
||||
BULLISH_THRESHOLD = 0.15
|
||||
BEARISH_THRESHOLD = -0.15
|
||||
MIXED_THRESHOLD = 0.10 # contradiction score above this → mixed
|
||||
|
||||
|
||||
def derive_trend_direction(
|
||||
avg_sentiment: float,
|
||||
contradiction_score: float = 0.0,
|
||||
) -> TrendDirection:
|
||||
"""Map a weighted average sentiment to a TrendDirection.
|
||||
|
||||
If contradiction is high, the direction is MIXED regardless of
|
||||
the average sentiment value.
|
||||
"""
|
||||
if contradiction_score > MIXED_THRESHOLD and abs(avg_sentiment) < 0.3:
|
||||
return TrendDirection.MIXED
|
||||
if avg_sentiment >= BULLISH_THRESHOLD:
|
||||
return TrendDirection.BULLISH
|
||||
if avg_sentiment <= BEARISH_THRESHOLD:
|
||||
return TrendDirection.BEARISH
|
||||
return TrendDirection.NEUTRAL
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Compute contradiction score
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compute_contradiction_score(signals: list[WeightedSignal]) -> float:
|
||||
"""Measure how much disagreement exists among weighted signals.
|
||||
|
||||
Returns a value in [0, 1] where 0 means full agreement and 1 means
|
||||
equal-weight positive and negative signals.
|
||||
|
||||
The formula computes the ratio of the minority-side total weight to
|
||||
the majority-side total weight.
|
||||
"""
|
||||
if not signals:
|
||||
return 0.0
|
||||
|
||||
pos_weight = 0.0
|
||||
neg_weight = 0.0
|
||||
for sig in signals:
|
||||
w = sig.weight.combined * sig.impact_score
|
||||
if sig.sentiment_value > 0:
|
||||
pos_weight += w
|
||||
elif sig.sentiment_value < 0:
|
||||
neg_weight += w
|
||||
|
||||
total = pos_weight + neg_weight
|
||||
if total == 0.0:
|
||||
return 0.0
|
||||
|
||||
minority = min(pos_weight, neg_weight)
|
||||
return round(minority / total, 4)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rank evidence (supporting vs opposing)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def rank_evidence(
|
||||
signals: list[WeightedSignal],
|
||||
max_refs: int = MAX_EVIDENCE_REFS,
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Return top supporting and opposing document IDs ranked by composite score.
|
||||
|
||||
Delegates to the evidence ranking module which considers multiple
|
||||
factors (weight, impact, recency, confidence) rather than raw weight alone.
|
||||
|
||||
Supporting = positive sentiment, Opposing = negative sentiment.
|
||||
Neutral/mixed signals are excluded from evidence lists.
|
||||
"""
|
||||
config = EvidenceRankConfig(max_refs=max_refs)
|
||||
return _rank_evidence_composite(signals, config)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Extract dominant catalysts and material risks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def extract_catalysts_and_risks(
|
||||
impacts: list[ImpactRow],
|
||||
signals: list[WeightedSignal],
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Return dominant catalyst types and material risks weighted by signal strength.
|
||||
|
||||
Catalysts are ranked by cumulative weight. Risks are deduplicated and
|
||||
ordered by the weight of the signal that surfaced them.
|
||||
"""
|
||||
catalyst_weights: dict[str, float] = {}
|
||||
risk_entries: list[tuple[float, str]] = []
|
||||
|
||||
# Build a lookup from document_id to combined weight
|
||||
weight_by_doc = {s.document_id: s.weight.combined * s.impact_score for s in signals}
|
||||
|
||||
for imp in impacts:
|
||||
w = weight_by_doc.get(imp.document_id, 0.0)
|
||||
if w <= 0.0:
|
||||
continue
|
||||
catalyst_weights[imp.catalyst_type] = catalyst_weights.get(imp.catalyst_type, 0.0) + w
|
||||
for risk in imp.risks:
|
||||
risk_entries.append((w, risk))
|
||||
|
||||
# Top catalysts by cumulative weight
|
||||
sorted_catalysts = sorted(catalyst_weights.items(), key=lambda x: x[1], reverse=True)
|
||||
catalysts = [cat for cat, _ in sorted_catalysts[:5]]
|
||||
|
||||
# Deduplicated risks ordered by weight
|
||||
seen_risks: set[str] = set()
|
||||
risks: list[str] = []
|
||||
risk_entries.sort(key=lambda x: x[0], reverse=True)
|
||||
for _, risk_text in risk_entries:
|
||||
normalized = risk_text.strip().lower()
|
||||
if normalized not in seen_risks:
|
||||
seen_risks.add(normalized)
|
||||
risks.append(risk_text.strip())
|
||||
if len(risks) >= 5:
|
||||
break
|
||||
|
||||
return catalysts, risks
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Compute trend confidence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compute_trend_confidence(
|
||||
signals: list[WeightedSignal],
|
||||
contradiction_score: float,
|
||||
) -> float:
|
||||
"""Derive an overall confidence for the trend summary.
|
||||
|
||||
Confidence is based on:
|
||||
- Number of contributing signals (more = higher base)
|
||||
- Average extraction confidence of contributing signals
|
||||
- Contradiction penalty (high contradiction lowers confidence)
|
||||
|
||||
Returns a value in [0, 1].
|
||||
"""
|
||||
if not signals:
|
||||
return 0.0
|
||||
|
||||
active = [s for s in signals if s.weight.combined > 0]
|
||||
if not active:
|
||||
return 0.0
|
||||
|
||||
# Base confidence from signal count (diminishing returns)
|
||||
count_factor = min(len(active) / 20.0, 1.0)
|
||||
|
||||
# Average extraction confidence (from the confidence_gate — if gated,
|
||||
# the signal wouldn't be in active list, so we use the raw confidence
|
||||
# from the weight breakdown).
|
||||
avg_conf = sum(s.weight.credibility for s in active) / len(active)
|
||||
|
||||
# Contradiction penalty
|
||||
contradiction_penalty = contradiction_score * 0.4
|
||||
|
||||
confidence = (0.4 * count_factor + 0.6 * avg_conf) - contradiction_penalty
|
||||
return round(max(0.0, min(1.0, confidence)), 4)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Assemble a TrendSummary from components
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class AssembledTrend:
|
||||
"""A trend summary paired with its detailed evidence rankings."""
|
||||
|
||||
summary: TrendSummary
|
||||
supporting_evidence: list[RankedEvidence]
|
||||
opposing_evidence: list[RankedEvidence]
|
||||
|
||||
|
||||
def assemble_trend_summary(
|
||||
ticker: str,
|
||||
window: str,
|
||||
signals: list[WeightedSignal],
|
||||
impacts: list[ImpactRow],
|
||||
market_ctx: Any | None = None,
|
||||
max_evidence: int = MAX_EVIDENCE_REFS,
|
||||
reference_time: datetime | None = None,
|
||||
) -> TrendSummary:
|
||||
"""Build a complete TrendSummary from weighted signals and impact records."""
|
||||
result = assemble_trend_with_evidence(
|
||||
ticker, window, signals, impacts, market_ctx, max_evidence, reference_time,
|
||||
)
|
||||
return result.summary
|
||||
|
||||
|
||||
def assemble_trend_with_evidence(
|
||||
ticker: str,
|
||||
window: str,
|
||||
signals: list[WeightedSignal],
|
||||
impacts: list[ImpactRow],
|
||||
market_ctx: Any | None = None,
|
||||
max_evidence: int = MAX_EVIDENCE_REFS,
|
||||
reference_time: datetime | None = None,
|
||||
) -> AssembledTrend:
|
||||
"""Build a TrendSummary and return detailed evidence rankings for persistence."""
|
||||
if reference_time is None:
|
||||
reference_time = datetime.now(timezone.utc)
|
||||
|
||||
avg_sentiment = weighted_sentiment_average(signals)
|
||||
|
||||
# Run full contradiction detection (Requirement 6.4)
|
||||
catalyst_entries = [
|
||||
CatalystEntry(document_id=imp.document_id, catalyst_type=imp.catalyst_type)
|
||||
for imp in impacts
|
||||
]
|
||||
contradiction_result = detect_contradictions(signals, catalyst_entries)
|
||||
contradiction = contradiction_result.score
|
||||
|
||||
direction = derive_trend_direction(avg_sentiment, contradiction)
|
||||
confidence = compute_trend_confidence(signals, contradiction)
|
||||
|
||||
# Get detailed evidence rankings for persistence
|
||||
config = EvidenceRankConfig(max_refs=max_evidence)
|
||||
supporting_ranked, opposing_ranked = rank_evidence_detailed(signals, config)
|
||||
|
||||
supporting = [r.document_id for r in supporting_ranked]
|
||||
opposing = [r.document_id for r in opposing_ranked]
|
||||
|
||||
catalysts, risks = extract_catalysts_and_risks(impacts, signals)
|
||||
|
||||
# Trend strength: absolute value of weighted sentiment, clamped to [0, 1]
|
||||
strength = round(min(abs(avg_sentiment), 1.0), 4)
|
||||
|
||||
summary = TrendSummary(
|
||||
entity_type="company",
|
||||
entity_id=ticker,
|
||||
window=TrendWindow(window),
|
||||
trend_direction=direction,
|
||||
trend_strength=strength,
|
||||
confidence=confidence,
|
||||
top_supporting_evidence=supporting,
|
||||
top_opposing_evidence=opposing,
|
||||
dominant_catalysts=catalysts,
|
||||
material_risks=risks,
|
||||
contradiction_score=contradiction,
|
||||
disagreement_details=contradiction_result.details,
|
||||
market_context=market_ctx,
|
||||
generated_at=reference_time,
|
||||
)
|
||||
|
||||
return AssembledTrend(
|
||||
summary=summary,
|
||||
supporting_evidence=supporting_ranked,
|
||||
opposing_evidence=opposing_ranked,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Persist trend summary to PostgreSQL
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_UPSERT_TREND = """
|
||||
INSERT INTO trend_windows (
|
||||
entity_type, entity_id, window, trend_direction, trend_strength,
|
||||
confidence, top_supporting_evidence, top_opposing_evidence,
|
||||
dominant_catalysts, material_risks, contradiction_score,
|
||||
disagreement_details, market_context, generated_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5,
|
||||
$6, $7::jsonb, $8::jsonb,
|
||||
$9::jsonb, $10::jsonb, $11,
|
||||
$12::jsonb, $13::jsonb, $14
|
||||
)
|
||||
RETURNING id
|
||||
"""
|
||||
|
||||
|
||||
async def persist_trend_summary(
|
||||
pool: asyncpg.Pool,
|
||||
summary: TrendSummary,
|
||||
) -> str:
|
||||
"""Insert a trend summary row and return its UUID."""
|
||||
row = await pool.fetchrow(
|
||||
_UPSERT_TREND,
|
||||
summary.entity_type,
|
||||
summary.entity_id,
|
||||
summary.window.value,
|
||||
summary.trend_direction.value,
|
||||
summary.trend_strength,
|
||||
summary.confidence,
|
||||
json.dumps(summary.top_supporting_evidence),
|
||||
json.dumps(summary.top_opposing_evidence),
|
||||
json.dumps(summary.dominant_catalysts),
|
||||
json.dumps(summary.material_risks),
|
||||
summary.contradiction_score,
|
||||
json.dumps([d.model_dump() for d in summary.disagreement_details]),
|
||||
json.dumps(summary.market_context.model_dump() if summary.market_context else {}),
|
||||
summary.generated_at,
|
||||
)
|
||||
return str(row["id"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Persist evidence mappings to trend_evidence table
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_INSERT_EVIDENCE = """
|
||||
INSERT INTO trend_evidence (
|
||||
trend_window_id, document_id, evidence_type,
|
||||
rank_score, weight_component, impact_component,
|
||||
recency_component, confidence_component, sentiment_value
|
||||
) VALUES (
|
||||
$1, $2::uuid, $3,
|
||||
$4, $5, $6,
|
||||
$7, $8, $9
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
async def persist_trend_evidence(
|
||||
pool: asyncpg.Pool,
|
||||
trend_window_id: str,
|
||||
supporting: list[RankedEvidence],
|
||||
opposing: list[RankedEvidence],
|
||||
) -> int:
|
||||
"""Insert evidence mapping rows for a trend window. Returns count inserted."""
|
||||
rows: list[tuple[str, str, str, float, float, float, float, float, float]] = []
|
||||
for ev in supporting:
|
||||
rows.append((
|
||||
trend_window_id, ev.document_id, "supporting",
|
||||
ev.rank_score, ev.weight_component, ev.impact_component,
|
||||
ev.recency_component, ev.confidence_component, ev.sentiment_value,
|
||||
))
|
||||
for ev in opposing:
|
||||
rows.append((
|
||||
trend_window_id, ev.document_id, "opposing",
|
||||
ev.rank_score, ev.weight_component, ev.impact_component,
|
||||
ev.recency_component, ev.confidence_component, ev.sentiment_value,
|
||||
))
|
||||
|
||||
if not rows:
|
||||
return 0
|
||||
|
||||
await pool.executemany(_INSERT_EVIDENCE, rows)
|
||||
return len(rows)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main aggregation entry point for a single ticker + window
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def aggregate_company_window(
|
||||
pool: asyncpg.Pool,
|
||||
ticker: str,
|
||||
window: str,
|
||||
reference_time: datetime | None = None,
|
||||
config: AggregationConfig | None = None,
|
||||
) -> TrendSummary:
|
||||
"""Compute and persist a trend summary for one ticker and one window.
|
||||
|
||||
Steps:
|
||||
1. Determine the time range for the window.
|
||||
2. Fetch document impact records from PostgreSQL.
|
||||
3. Fetch market context for the ticker.
|
||||
4. Build weighted signals using the scoring module.
|
||||
5. Assemble the TrendSummary.
|
||||
6. Persist to trend_windows table.
|
||||
|
||||
Returns the assembled TrendSummary.
|
||||
"""
|
||||
cfg = config or AggregationConfig()
|
||||
scoring_cfg = cfg.effective_scoring()
|
||||
|
||||
if reference_time is None:
|
||||
reference_time = datetime.now(timezone.utc)
|
||||
|
||||
_agg_start = time.monotonic()
|
||||
duration = WINDOW_DURATIONS.get(window, timedelta(days=7))
|
||||
window_start = reference_time - duration
|
||||
|
||||
# 1. Fetch impact records
|
||||
impacts = await fetch_impact_records(pool, ticker, window_start, reference_time)
|
||||
|
||||
# 2. Fetch market context
|
||||
market_ctx = await fetch_market_context(pool, ticker, window, reference_time)
|
||||
|
||||
# 3. Build weighted signals
|
||||
signals = build_weighted_signals(
|
||||
impacts, reference_time, window, market_ctx, scoring_cfg,
|
||||
)
|
||||
|
||||
# 4. Assemble trend summary with evidence details
|
||||
assembled = assemble_trend_with_evidence(
|
||||
ticker=ticker,
|
||||
window=window,
|
||||
signals=signals,
|
||||
impacts=impacts,
|
||||
market_ctx=market_ctx if market_ctx.has_data else None,
|
||||
max_evidence=cfg.max_evidence,
|
||||
reference_time=reference_time,
|
||||
)
|
||||
summary = assembled.summary
|
||||
|
||||
# 5. Persist trend window
|
||||
trend_id = await persist_trend_summary(pool, summary)
|
||||
|
||||
# 6. Persist evidence mappings
|
||||
evidence_count = await persist_trend_evidence(
|
||||
pool, trend_id,
|
||||
assembled.supporting_evidence,
|
||||
assembled.opposing_evidence,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Persisted trend %s for %s/%s: direction=%s strength=%.3f confidence=%.3f signals=%d evidence=%d",
|
||||
trend_id, ticker, window, summary.trend_direction.value,
|
||||
summary.trend_strength, summary.confidence, len(signals), evidence_count,
|
||||
)
|
||||
|
||||
# Prometheus metrics
|
||||
AGGREGATION_WINDOWS_COMPUTED.labels(window=window).inc()
|
||||
AGGREGATION_SIGNALS_PROCESSED.labels(window=window).inc(len(signals))
|
||||
AGGREGATION_CONTRADICTION_SCORE.observe(summary.contradiction_score)
|
||||
AGGREGATION_DURATION.labels(window=window).observe(time.monotonic() - _agg_start)
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Aggregate all windows for a single ticker
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def aggregate_company(
|
||||
pool: asyncpg.Pool,
|
||||
ticker: str,
|
||||
reference_time: datetime | None = None,
|
||||
config: AggregationConfig | None = None,
|
||||
) -> list[TrendSummary]:
|
||||
"""Compute trend summaries for all configured windows for a ticker."""
|
||||
cfg = config or AggregationConfig()
|
||||
if reference_time is None:
|
||||
reference_time = datetime.now(timezone.utc)
|
||||
|
||||
summaries: list[TrendSummary] = []
|
||||
for window in cfg.effective_windows():
|
||||
summary = await aggregate_company_window(
|
||||
pool, ticker, window, reference_time, cfg,
|
||||
)
|
||||
summaries.append(summary)
|
||||
|
||||
return summaries
|
||||
|
||||
+1507
-1
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,268 @@
|
||||
"""Ollama client wrapper using structured output format.
|
||||
|
||||
Sends documents to a local Ollama instance via the /api/chat endpoint
|
||||
with the ``format`` parameter set to the extraction JSON schema, ensuring
|
||||
the model returns schema-compliant JSON.
|
||||
|
||||
Includes retry logic for invalid or incomplete model responses with
|
||||
exponential backoff, error classification, and full audit preservation.
|
||||
|
||||
Requirements: 5.1, 5.2, 5.4
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import httpx
|
||||
|
||||
from services.extractor.prompts import (
|
||||
build_extraction_prompt,
|
||||
get_json_schema,
|
||||
get_prompt_metadata,
|
||||
)
|
||||
from services.extractor.schemas import ExtractionResult, ValidationReport, validate_extraction
|
||||
from services.shared.config import OllamaConfig
|
||||
|
||||
logger = logging.getLogger("ollama_client")
|
||||
|
||||
# Errors that should NOT be retried — the request itself is bad.
|
||||
_NON_RETRYABLE_ERRORS = frozenset({
|
||||
"http_400",
|
||||
"http_401",
|
||||
"http_403",
|
||||
"http_404",
|
||||
"http_422",
|
||||
})
|
||||
|
||||
|
||||
def _is_retryable(error: str | None) -> bool:
|
||||
"""Determine whether an extraction error warrants a retry."""
|
||||
if error is None:
|
||||
return False
|
||||
return error not in _NON_RETRYABLE_ERRORS
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractionAttempt:
|
||||
"""Record of a single extraction attempt for audit."""
|
||||
|
||||
raw_output: str = ""
|
||||
validation: ValidationReport | None = None
|
||||
error: str | None = None
|
||||
duration_ms: int = 0
|
||||
model: str = ""
|
||||
retryable: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractionResponse:
|
||||
"""Full response from an extraction call, including all attempts."""
|
||||
|
||||
success: bool = False
|
||||
result: ExtractionResult | None = None
|
||||
attempts: list[ExtractionAttempt] = field(default_factory=list)
|
||||
prompt_metadata: dict[str, str] = field(default_factory=dict)
|
||||
model: str = ""
|
||||
total_duration_ms: int = 0
|
||||
|
||||
|
||||
def _compute_backoff(
|
||||
attempt_num: int,
|
||||
base_delay: float,
|
||||
max_delay: float,
|
||||
multiplier: float,
|
||||
) -> float:
|
||||
"""Compute exponential backoff delay for a given attempt number."""
|
||||
delay = base_delay * (multiplier ** attempt_num)
|
||||
return min(delay, max_delay)
|
||||
|
||||
|
||||
class OllamaClient:
|
||||
"""Async client for Ollama structured extraction.
|
||||
|
||||
Usage::
|
||||
|
||||
config = OllamaConfig(base_url="http://localhost:11434", model="llama3.1:8b")
|
||||
client = OllamaClient(config)
|
||||
response = await client.extract(
|
||||
document_text="Apple reported record earnings...",
|
||||
document_type="article",
|
||||
document_id="abc-123",
|
||||
)
|
||||
if response.success:
|
||||
print(response.result)
|
||||
"""
|
||||
|
||||
_config: OllamaConfig
|
||||
_max_retries: int
|
||||
_base_delay: float
|
||||
_max_delay: float
|
||||
_backoff_multiplier: float
|
||||
_owns_client: bool
|
||||
_http: httpx.AsyncClient
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: OllamaConfig,
|
||||
max_retries: int | None = None,
|
||||
http_client: httpx.AsyncClient | None = None,
|
||||
) -> None:
|
||||
self._config = config
|
||||
self._max_retries = max_retries if max_retries is not None else config.max_retries
|
||||
self._base_delay = config.retry_base_delay
|
||||
self._max_delay = config.retry_max_delay
|
||||
self._backoff_multiplier = config.retry_backoff_multiplier
|
||||
self._owns_client = http_client is None
|
||||
self._http = http_client or httpx.AsyncClient(timeout=config.timeout)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the underlying HTTP client if we own it."""
|
||||
if self._owns_client:
|
||||
await self._http.aclose()
|
||||
|
||||
async def extract(
|
||||
self,
|
||||
document_text: str,
|
||||
document_type: str = "article",
|
||||
document_id: str = "",
|
||||
known_tickers: list[str] | None = None,
|
||||
) -> ExtractionResponse:
|
||||
"""Send a document to Ollama for structured intelligence extraction.
|
||||
|
||||
Retries up to ``max_retries`` times when the model returns invalid
|
||||
or incomplete JSON. Uses exponential backoff between retries.
|
||||
Non-retryable errors (e.g. HTTP 400) stop retries immediately.
|
||||
Each attempt and its validation result are preserved for audit.
|
||||
|
||||
Args:
|
||||
document_text: Normalized text content of the document.
|
||||
document_type: One of article, filing, transcript, press_release.
|
||||
document_id: Optional document ID for traceability.
|
||||
known_tickers: Optional ticker hints for the model.
|
||||
|
||||
Returns:
|
||||
An ``ExtractionResponse`` with the parsed result on success.
|
||||
"""
|
||||
prompts = build_extraction_prompt(
|
||||
document_text=document_text,
|
||||
document_type=document_type,
|
||||
document_id=document_id,
|
||||
known_tickers=known_tickers,
|
||||
)
|
||||
json_schema = get_json_schema()
|
||||
prompt_meta = get_prompt_metadata()
|
||||
|
||||
response = ExtractionResponse(
|
||||
prompt_metadata=prompt_meta,
|
||||
model=self._config.model,
|
||||
)
|
||||
|
||||
total_start = time.monotonic()
|
||||
|
||||
for attempt_num in range(self._max_retries + 1):
|
||||
attempt = await self._call_ollama(prompts, json_schema, document_text)
|
||||
response.attempts.append(attempt)
|
||||
|
||||
if attempt.error is None and attempt.validation and attempt.validation.valid:
|
||||
response.success = True
|
||||
response.result = attempt.validation.parsed
|
||||
break
|
||||
|
||||
# Check if the error is non-retryable — stop immediately
|
||||
if not _is_retryable(attempt.error):
|
||||
attempt.retryable = False
|
||||
logger.warning(
|
||||
"Non-retryable error for doc %s: %s — stopping retries",
|
||||
document_id or "unknown",
|
||||
attempt.error,
|
||||
)
|
||||
break
|
||||
|
||||
if attempt_num < self._max_retries:
|
||||
delay = _compute_backoff(
|
||||
attempt_num,
|
||||
self._base_delay,
|
||||
self._max_delay,
|
||||
self._backoff_multiplier,
|
||||
)
|
||||
logger.warning(
|
||||
"Extraction attempt %d/%d failed for doc %s: %s — retrying in %.1fs",
|
||||
attempt_num + 1,
|
||||
self._max_retries + 1,
|
||||
document_id or "unknown",
|
||||
attempt.error or "validation failed",
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
response.total_duration_ms = int((time.monotonic() - total_start) * 1000)
|
||||
return response
|
||||
|
||||
async def _call_ollama(
|
||||
self,
|
||||
prompts: dict[str, str],
|
||||
json_schema: dict[str, object],
|
||||
document_text: str = "",
|
||||
) -> ExtractionAttempt:
|
||||
"""Make a single call to the Ollama /api/chat endpoint."""
|
||||
attempt = ExtractionAttempt(model=self._config.model)
|
||||
start = time.monotonic()
|
||||
|
||||
payload = {
|
||||
"model": self._config.model,
|
||||
"messages": [
|
||||
{"role": "system", "content": prompts["system"]},
|
||||
{"role": "user", "content": prompts["user"]},
|
||||
],
|
||||
"format": json_schema,
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
try:
|
||||
resp = await self._http.post(
|
||||
f"{self._config.base_url}/api/chat",
|
||||
json=payload,
|
||||
)
|
||||
_ = resp.raise_for_status()
|
||||
except httpx.TimeoutException:
|
||||
attempt.error = "timeout"
|
||||
attempt.duration_ms = int((time.monotonic() - start) * 1000)
|
||||
return attempt
|
||||
except httpx.HTTPStatusError as exc:
|
||||
attempt.error = f"http_{exc.response.status_code}"
|
||||
attempt.retryable = _is_retryable(attempt.error)
|
||||
attempt.duration_ms = int((time.monotonic() - start) * 1000)
|
||||
return attempt
|
||||
except httpx.HTTPError as exc:
|
||||
attempt.error = f"connection_error: {exc}"
|
||||
attempt.duration_ms = int((time.monotonic() - start) * 1000)
|
||||
return attempt
|
||||
|
||||
attempt.duration_ms = int((time.monotonic() - start) * 1000)
|
||||
|
||||
# Parse the Ollama response envelope
|
||||
try:
|
||||
body: dict[str, object] = resp.json()
|
||||
except json.JSONDecodeError:
|
||||
attempt.error = "invalid_response_json"
|
||||
attempt.raw_output = resp.text
|
||||
return attempt
|
||||
|
||||
msg = body.get("message")
|
||||
content: str = msg.get("content", "") if isinstance(msg, dict) else ""
|
||||
attempt.raw_output = content
|
||||
|
||||
if not content:
|
||||
attempt.error = "empty_model_response"
|
||||
return attempt
|
||||
|
||||
# Validate against extraction schema
|
||||
attempt.validation = validate_extraction(content, document_text=document_text)
|
||||
if not attempt.validation.valid:
|
||||
attempt.error = "; ".join(attempt.validation.errors)
|
||||
|
||||
return attempt
|
||||
@@ -0,0 +1,72 @@
|
||||
"""Extractor worker entrypoint - polls Redis for extraction jobs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import asyncpg
|
||||
from minio import Minio
|
||||
|
||||
from services.extractor.client import OllamaClient
|
||||
from services.extractor.worker import persist_extraction
|
||||
from services.shared.config import load_config
|
||||
from services.shared.logging import setup_logging
|
||||
from services.shared.redis_keys import QUEUE_EXTRACTION, queue_key
|
||||
|
||||
logger = logging.getLogger("extractor_main")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
config = load_config()
|
||||
setup_logging("extractor", level=config.log_level, json_output=config.json_logs)
|
||||
|
||||
pool = await asyncpg.create_pool(dsn=config.postgres.dsn, min_size=2, max_size=8)
|
||||
minio_client = Minio(
|
||||
config.minio.endpoint,
|
||||
access_key=config.minio.access_key,
|
||||
secret_key=config.minio.secret_key,
|
||||
secure=config.minio.secure,
|
||||
)
|
||||
ollama = OllamaClient(config.ollama)
|
||||
|
||||
import json
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
redis_client = aioredis.from_url(config.redis.url)
|
||||
queue = queue_key(QUEUE_EXTRACTION)
|
||||
logger.info("Extractor worker started, polling %s", queue)
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw = await redis_client.lpop(queue)
|
||||
if raw is None:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
payload = raw
|
||||
job = json.loads(payload)
|
||||
document_id = job.get("document_id", "")
|
||||
ticker = job.get("ticker", "")
|
||||
text = job.get("text", "")
|
||||
|
||||
logger.info("Processing extraction job for doc %s / %s", document_id, ticker)
|
||||
|
||||
try:
|
||||
extraction_response = await ollama.extract(text)
|
||||
await persist_extraction(
|
||||
pool=pool,
|
||||
minio_client=minio_client,
|
||||
document_id=document_id,
|
||||
ticker=ticker,
|
||||
extraction_response=extraction_response,
|
||||
document_text_length=len(text),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Extraction failed for doc %s", document_id)
|
||||
finally:
|
||||
await pool.close()
|
||||
await redis_client.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,250 @@
|
||||
"""Model performance metrics collection and persistence.
|
||||
|
||||
Tracks extraction success/failure rates, latency percentiles, retry counts,
|
||||
validation error distributions, confidence scores, and token usage estimates.
|
||||
Metrics are persisted to PostgreSQL for operational dashboards and published
|
||||
to the analytical lake for Trino/Superset queries.
|
||||
|
||||
Requirements: 5.2, 5.4, 12.1, 12.2
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import asyncpg
|
||||
|
||||
from services.extractor.client import ExtractionResponse
|
||||
|
||||
logger = logging.getLogger("extractor_metrics")
|
||||
|
||||
# Rough token estimate: ~4 chars per token for English text
|
||||
_CHARS_PER_TOKEN = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractionMetrics:
|
||||
"""Metrics extracted from a single extraction run."""
|
||||
|
||||
document_id: str = ""
|
||||
ticker: str = ""
|
||||
model_name: str = ""
|
||||
prompt_version: str = ""
|
||||
schema_version: str = ""
|
||||
success: bool = False
|
||||
attempt_count: int = 0
|
||||
total_duration_ms: int = 0
|
||||
first_attempt_duration_ms: int = 0
|
||||
final_attempt_duration_ms: int = 0
|
||||
confidence: float = 0.0
|
||||
validation_status: str = "unknown"
|
||||
validation_error_count: int = 0
|
||||
validation_warning_count: int = 0
|
||||
validation_errors: list[str] = field(default_factory=list)
|
||||
retry_count: int = 0
|
||||
input_token_estimate: int = 0
|
||||
output_token_estimate: int = 0
|
||||
company_count: int = 0
|
||||
recorded_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
def collect_metrics(
|
||||
extraction_response: ExtractionResponse,
|
||||
*,
|
||||
document_id: str = "",
|
||||
ticker: str = "",
|
||||
document_text_length: int = 0,
|
||||
) -> ExtractionMetrics:
|
||||
"""Collect metrics from an ExtractionResponse.
|
||||
|
||||
Args:
|
||||
extraction_response: The full response from OllamaClient.extract().
|
||||
document_id: UUID of the source document.
|
||||
ticker: Primary ticker symbol.
|
||||
document_text_length: Length of the input document text in characters.
|
||||
|
||||
Returns:
|
||||
An ExtractionMetrics dataclass with all computed fields.
|
||||
"""
|
||||
attempts = extraction_response.attempts
|
||||
first_dur = attempts[0].duration_ms if attempts else 0
|
||||
final_dur = attempts[-1].duration_ms if attempts else 0
|
||||
|
||||
# Gather validation info from the final attempt
|
||||
final_attempt = attempts[-1] if attempts else None
|
||||
val_errors: list[str] = []
|
||||
val_warnings: list[str] = []
|
||||
if final_attempt and final_attempt.validation:
|
||||
val_errors = final_attempt.validation.errors
|
||||
val_warnings = final_attempt.validation.warnings
|
||||
|
||||
# Determine validation status
|
||||
if extraction_response.success:
|
||||
validation_status = "valid"
|
||||
elif attempts:
|
||||
validation_status = "failed"
|
||||
else:
|
||||
validation_status = "unknown"
|
||||
|
||||
# Confidence from the result, or 0 if failed
|
||||
confidence = 0.0
|
||||
company_count = 0
|
||||
if extraction_response.result:
|
||||
confidence = extraction_response.result.confidence
|
||||
company_count = len(extraction_response.result.companies)
|
||||
|
||||
# Token estimates
|
||||
input_tokens = document_text_length // _CHARS_PER_TOKEN if document_text_length > 0 else 0
|
||||
output_tokens = 0
|
||||
if final_attempt and final_attempt.raw_output:
|
||||
output_tokens = len(final_attempt.raw_output) // _CHARS_PER_TOKEN
|
||||
|
||||
return ExtractionMetrics(
|
||||
document_id=document_id,
|
||||
ticker=ticker,
|
||||
model_name=extraction_response.model,
|
||||
prompt_version=extraction_response.prompt_metadata.get("prompt_version", ""),
|
||||
schema_version=extraction_response.prompt_metadata.get("schema_version", ""),
|
||||
success=extraction_response.success,
|
||||
attempt_count=len(attempts),
|
||||
total_duration_ms=extraction_response.total_duration_ms,
|
||||
first_attempt_duration_ms=first_dur,
|
||||
final_attempt_duration_ms=final_dur,
|
||||
confidence=confidence,
|
||||
validation_status=validation_status,
|
||||
validation_error_count=len(val_errors),
|
||||
validation_warning_count=len(val_warnings),
|
||||
validation_errors=val_errors,
|
||||
retry_count=max(0, len(attempts) - 1),
|
||||
input_token_estimate=input_tokens,
|
||||
output_token_estimate=output_tokens,
|
||||
company_count=company_count,
|
||||
)
|
||||
|
||||
|
||||
async def persist_metrics(
|
||||
pool: asyncpg.Pool,
|
||||
metrics: ExtractionMetrics,
|
||||
) -> str:
|
||||
"""Persist extraction metrics to the model_performance_metrics table.
|
||||
|
||||
Args:
|
||||
pool: PostgreSQL connection pool.
|
||||
metrics: Collected metrics from an extraction run.
|
||||
|
||||
Returns:
|
||||
The UUID of the inserted metrics row.
|
||||
"""
|
||||
row_id = await pool.fetchval(
|
||||
"""INSERT INTO model_performance_metrics
|
||||
(document_id, ticker, model_name, prompt_version, schema_version,
|
||||
success, attempt_count, total_duration_ms,
|
||||
first_attempt_duration_ms, final_attempt_duration_ms,
|
||||
confidence, validation_status, validation_error_count,
|
||||
validation_warning_count, validation_errors, retry_count,
|
||||
input_token_estimate, output_token_estimate, company_count,
|
||||
recorded_at)
|
||||
VALUES ($1::uuid, $2, $3, $4, $5, $6, $7, $8, $9, $10,
|
||||
$11, $12, $13, $14, $15::jsonb, $16, $17, $18, $19, $20)
|
||||
RETURNING id""",
|
||||
metrics.document_id,
|
||||
metrics.ticker,
|
||||
metrics.model_name,
|
||||
metrics.prompt_version,
|
||||
metrics.schema_version,
|
||||
metrics.success,
|
||||
metrics.attempt_count,
|
||||
metrics.total_duration_ms,
|
||||
metrics.first_attempt_duration_ms,
|
||||
metrics.final_attempt_duration_ms,
|
||||
metrics.confidence,
|
||||
metrics.validation_status,
|
||||
metrics.validation_error_count,
|
||||
metrics.validation_warning_count,
|
||||
json.dumps(metrics.validation_errors),
|
||||
metrics.retry_count,
|
||||
metrics.input_token_estimate,
|
||||
metrics.output_token_estimate,
|
||||
metrics.company_count,
|
||||
metrics.recorded_at,
|
||||
)
|
||||
logger.info(
|
||||
"Persisted extraction metrics %s for doc %s: success=%s duration=%dms retries=%d",
|
||||
row_id, metrics.document_id, metrics.success,
|
||||
metrics.total_duration_ms, metrics.retry_count,
|
||||
)
|
||||
return str(row_id)
|
||||
|
||||
|
||||
async def get_model_performance_summary(
|
||||
pool: asyncpg.Pool,
|
||||
*,
|
||||
model_name: str | None = None,
|
||||
hours: int = 24,
|
||||
) -> dict[str, object]:
|
||||
"""Query aggregated model performance metrics for dashboards.
|
||||
|
||||
Returns a summary dict with success rate, avg latency, retry rate,
|
||||
confidence distribution, and error breakdown for the given time window.
|
||||
|
||||
Args:
|
||||
pool: PostgreSQL connection pool.
|
||||
model_name: Optional filter by model name.
|
||||
hours: Lookback window in hours (default 24).
|
||||
|
||||
Returns:
|
||||
Dict with aggregated performance metrics.
|
||||
"""
|
||||
model_filter = "AND model_name = $2" if model_name else ""
|
||||
params: list[object] = [hours]
|
||||
if model_name:
|
||||
params.append(model_name)
|
||||
|
||||
row = await pool.fetchrow(
|
||||
f"""SELECT
|
||||
COUNT(*) AS total_extractions,
|
||||
COUNT(*) FILTER (WHERE success) AS successful,
|
||||
COUNT(*) FILTER (WHERE NOT success) AS failed,
|
||||
ROUND(AVG(total_duration_ms)::numeric, 1) AS avg_duration_ms,
|
||||
ROUND(PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 1) AS p50_duration_ms,
|
||||
ROUND(PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 1) AS p95_duration_ms,
|
||||
ROUND(PERCENTILE_CONT(0.99) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 1) AS p99_duration_ms,
|
||||
ROUND(AVG(retry_count)::numeric, 2) AS avg_retries,
|
||||
ROUND(AVG(confidence)::numeric, 3) AS avg_confidence,
|
||||
SUM(input_token_estimate) AS total_input_tokens,
|
||||
SUM(output_token_estimate) AS total_output_tokens,
|
||||
ROUND(AVG(company_count)::numeric, 2) AS avg_companies_per_doc,
|
||||
ROUND(AVG(validation_error_count)::numeric, 2) AS avg_validation_errors,
|
||||
ROUND(AVG(validation_warning_count)::numeric, 2) AS avg_validation_warnings
|
||||
FROM model_performance_metrics
|
||||
WHERE recorded_at >= NOW() - INTERVAL '1 hour' * $1
|
||||
{model_filter}""",
|
||||
*params,
|
||||
)
|
||||
|
||||
if not row or row["total_extractions"] == 0:
|
||||
return {"total_extractions": 0, "success_rate": 0.0}
|
||||
|
||||
total = row["total_extractions"]
|
||||
successful = row["successful"]
|
||||
|
||||
return {
|
||||
"total_extractions": total,
|
||||
"successful": successful,
|
||||
"failed": row["failed"],
|
||||
"success_rate": round(successful / total, 4) if total > 0 else 0.0,
|
||||
"avg_duration_ms": float(row["avg_duration_ms"] or 0),
|
||||
"p50_duration_ms": float(row["p50_duration_ms"] or 0),
|
||||
"p95_duration_ms": float(row["p95_duration_ms"] or 0),
|
||||
"p99_duration_ms": float(row["p99_duration_ms"] or 0),
|
||||
"avg_retries": float(row["avg_retries"] or 0),
|
||||
"avg_confidence": float(row["avg_confidence"] or 0),
|
||||
"total_input_tokens": int(row["total_input_tokens"] or 0),
|
||||
"total_output_tokens": int(row["total_output_tokens"] or 0),
|
||||
"avg_companies_per_doc": float(row["avg_companies_per_doc"] or 0),
|
||||
"avg_validation_errors": float(row["avg_validation_errors"] or 0),
|
||||
"avg_validation_warnings": float(row["avg_validation_warnings"] or 0),
|
||||
"hours": hours,
|
||||
}
|
||||
@@ -0,0 +1,149 @@
|
||||
"""Extraction prompt templates with anti-hallucination instructions.
|
||||
|
||||
Builds structured prompts for Ollama document intelligence extraction.
|
||||
Each prompt includes the target JSON schema, anti-hallucination rules,
|
||||
and document-type-specific guidance.
|
||||
|
||||
Requirements: 5.1, 5.2, 5.3, 5.4, 5.5
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from services.extractor.schemas import generate_json_schema, SCHEMA_VERSION
|
||||
from services.shared.schemas import (
|
||||
DocumentType,
|
||||
)
|
||||
|
||||
PROMPT_VERSION = "document-intel-v1"
|
||||
|
||||
# --- JSON schema for structured output (generated from Pydantic models) ---
|
||||
|
||||
EXTRACTION_JSON_SCHEMA: dict[str, Any] = generate_json_schema()
|
||||
|
||||
# --- Anti-hallucination system prompt ---
|
||||
|
||||
SYSTEM_PROMPT = """\
|
||||
You are a financial document analysis system. You extract structured intelligence \
|
||||
from financial documents into JSON.
|
||||
|
||||
STRICT RULES — VIOLATIONS WILL INVALIDATE YOUR OUTPUT:
|
||||
|
||||
1. ONLY extract information explicitly stated in the document text provided.
|
||||
2. NEVER fabricate facts, quotes, numbers, dates, or company names.
|
||||
3. NEVER infer information that is not directly supported by the text.
|
||||
4. If the document does not mention a company, do NOT include that company.
|
||||
5. If the document is ambiguous about sentiment or impact, use "neutral" or "mixed" \
|
||||
and set confidence lower.
|
||||
6. evidence_spans MUST be short verbatim quotes copied from the document. \
|
||||
Do NOT paraphrase or invent quotes.
|
||||
7. key_facts MUST be directly stated in the document. Do NOT add external knowledge.
|
||||
8. If you are uncertain about any field, lower the confidence score and add a warning \
|
||||
to extraction_warnings.
|
||||
9. If the document text is too short, garbled, or uninformative, return an empty \
|
||||
companies array, set confidence below 0.3, and add "insufficient_content" to warnings.
|
||||
10. Return ONLY valid JSON matching the provided schema. No commentary, no markdown fences."""
|
||||
|
||||
# --- Document-type-specific guidance ---
|
||||
|
||||
_DOCTYPE_GUIDANCE: dict[str, str] = {
|
||||
DocumentType.ARTICLE: (
|
||||
"This is a news article. Focus on reported facts, quoted sources, and stated "
|
||||
"analyst opinions. Distinguish between the journalist's framing and actual "
|
||||
"company developments. Do not treat speculative language as confirmed fact."
|
||||
),
|
||||
DocumentType.FILING: (
|
||||
"This is a regulatory filing (e.g. SEC 10-K, 10-Q, 8-K). Extract concrete "
|
||||
"financial figures, risk factors, and material events as stated. Filings use "
|
||||
"precise legal language — preserve that precision in your extraction."
|
||||
),
|
||||
DocumentType.TRANSCRIPT: (
|
||||
"This is an earnings call or event transcript. Distinguish between management "
|
||||
"forward-looking statements and reported results. Flag forward-looking language "
|
||||
"as lower confidence. Extract specific guidance numbers when stated."
|
||||
),
|
||||
DocumentType.PRESS_RELEASE: (
|
||||
"This is a company press release. Be aware that press releases are promotional. "
|
||||
"Extract stated facts and figures but note that sentiment may be biased positive. "
|
||||
"Look for concrete metrics rather than marketing language."
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _get_doctype_guidance(document_type: str) -> str:
|
||||
"""Return document-type-specific extraction guidance."""
|
||||
return _DOCTYPE_GUIDANCE.get(document_type, _DOCTYPE_GUIDANCE[DocumentType.ARTICLE])
|
||||
|
||||
|
||||
# --- Prompt builder ---
|
||||
|
||||
def build_extraction_prompt(
|
||||
document_text: str,
|
||||
document_type: str = DocumentType.ARTICLE,
|
||||
known_tickers: list[str] | None = None,
|
||||
document_id: str = "",
|
||||
) -> dict[str, str]:
|
||||
"""Build system and user prompts for Ollama structured extraction.
|
||||
|
||||
Args:
|
||||
document_text: Normalized text content of the document.
|
||||
document_type: One of the DocumentType enum values.
|
||||
known_tickers: Optional list of tickers the document may reference.
|
||||
Helps the model focus but does NOT mean all tickers are relevant.
|
||||
document_id: Optional document ID for traceability.
|
||||
|
||||
Returns:
|
||||
Dict with 'system' and 'user' prompt strings.
|
||||
"""
|
||||
doctype_guidance = _get_doctype_guidance(document_type)
|
||||
|
||||
ticker_hint = ""
|
||||
if known_tickers:
|
||||
tickers_str = ", ".join(known_tickers)
|
||||
ticker_hint = (
|
||||
f"\nThe following tickers may be referenced in this document: {tickers_str}\n"
|
||||
"Only include a ticker in your output if the document actually discusses that company. "
|
||||
"Do NOT include a ticker just because it appears in this hint."
|
||||
)
|
||||
|
||||
schema_str = json.dumps(EXTRACTION_JSON_SCHEMA, indent=2)
|
||||
|
||||
doc_id_line = f"Document ID: {document_id}\n" if document_id else ""
|
||||
|
||||
user_prompt = f"""\
|
||||
Extract structured intelligence from the following document.
|
||||
|
||||
{doc_id_line}Document type: {document_type}
|
||||
{doctype_guidance}
|
||||
{ticker_hint}
|
||||
Your output MUST be a single JSON object conforming to this schema:
|
||||
{schema_str}
|
||||
|
||||
REMEMBER:
|
||||
- Only extract what is explicitly in the text below.
|
||||
- evidence_spans must be verbatim quotes from the text.
|
||||
- If the text is insufficient, return empty companies and low confidence.
|
||||
- Return ONLY the JSON object. No other text.
|
||||
|
||||
--- DOCUMENT TEXT ---
|
||||
{document_text}
|
||||
--- END DOCUMENT TEXT ---"""
|
||||
|
||||
return {
|
||||
"system": SYSTEM_PROMPT,
|
||||
"user": user_prompt,
|
||||
}
|
||||
|
||||
|
||||
def get_prompt_metadata() -> dict[str, str]:
|
||||
"""Return metadata about the current prompt version for audit trails."""
|
||||
return {
|
||||
"prompt_version": PROMPT_VERSION,
|
||||
"schema_version": SCHEMA_VERSION,
|
||||
}
|
||||
|
||||
|
||||
def get_json_schema() -> dict[str, Any]:
|
||||
"""Return the extraction JSON schema for Ollama structured output format parameter."""
|
||||
return EXTRACTION_JSON_SCHEMA
|
||||
@@ -0,0 +1,250 @@
|
||||
"""Replay dataset loader and runner for deterministic extraction testing.
|
||||
|
||||
Loads archived document fixtures from JSON files, validates their expected
|
||||
extraction outputs against the current schema, and provides a runner that
|
||||
can compare live Ollama extraction results against expected baselines.
|
||||
|
||||
This enables:
|
||||
- Schema regression testing: verify expected outputs still pass validation
|
||||
- Prompt regression testing: detect drift when prompts or schemas change
|
||||
- End-to-end replay: run fixtures through a live Ollama and compare
|
||||
|
||||
Requirements: 5.1, 5.2, 5.3, 5.4, 5.5
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from services.extractor.schemas import (
|
||||
ExtractionResult,
|
||||
ValidationReport,
|
||||
get_schema_version,
|
||||
validate_extraction,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("extractor_replay")
|
||||
|
||||
FIXTURES_DIR = Path(__file__).resolve().parent.parent.parent / "tests" / "replay_fixtures"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReplayFixture:
|
||||
"""A single replay fixture loaded from disk."""
|
||||
|
||||
document_id: str
|
||||
document_type: str
|
||||
document_text: str
|
||||
known_tickers: list[str]
|
||||
expected_extraction: dict[str, Any]
|
||||
metadata: dict[str, str]
|
||||
source_path: str = ""
|
||||
|
||||
@property
|
||||
def expected_result(self) -> ExtractionResult:
|
||||
"""Parse expected_extraction into a validated ExtractionResult."""
|
||||
return ExtractionResult.model_validate(self.expected_extraction)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReplayValidationResult:
|
||||
"""Result of validating a single fixture against the current schema."""
|
||||
|
||||
fixture_id: str
|
||||
schema_valid: bool = False
|
||||
validation_report: ValidationReport | None = None
|
||||
schema_version: str = ""
|
||||
error: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReplayComparisonResult:
|
||||
"""Result of comparing a live extraction against the expected baseline."""
|
||||
|
||||
fixture_id: str
|
||||
expected_companies: list[str] = field(default_factory=list)
|
||||
actual_companies: list[str] = field(default_factory=list)
|
||||
companies_match: bool = False
|
||||
expected_sentiment_map: dict[str, str] = field(default_factory=dict)
|
||||
actual_sentiment_map: dict[str, str] = field(default_factory=dict)
|
||||
sentiment_match: bool = False
|
||||
expected_catalyst_map: dict[str, str] = field(default_factory=dict)
|
||||
actual_catalyst_map: dict[str, str] = field(default_factory=dict)
|
||||
catalyst_match: bool = False
|
||||
actual_schema_valid: bool = False
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
def load_fixture(path: Path) -> ReplayFixture:
|
||||
"""Load a single replay fixture from a JSON file.
|
||||
|
||||
Args:
|
||||
path: Path to the fixture JSON file.
|
||||
|
||||
Returns:
|
||||
A ReplayFixture with all fields populated.
|
||||
|
||||
Raises:
|
||||
ValueError: If the fixture is missing required fields.
|
||||
json.JSONDecodeError: If the file is not valid JSON.
|
||||
"""
|
||||
with open(path) as f:
|
||||
data = json.load(f)
|
||||
|
||||
required = {"document_id", "document_type", "document_text", "expected_extraction"}
|
||||
missing = required - set(data.keys())
|
||||
if missing:
|
||||
raise ValueError(f"Fixture {path.name} missing required fields: {missing}")
|
||||
|
||||
return ReplayFixture(
|
||||
document_id=data["document_id"],
|
||||
document_type=data["document_type"],
|
||||
document_text=data["document_text"],
|
||||
known_tickers=data.get("known_tickers", []),
|
||||
expected_extraction=data["expected_extraction"],
|
||||
metadata=data.get("metadata", {}),
|
||||
source_path=str(path),
|
||||
)
|
||||
|
||||
|
||||
def load_all_fixtures(fixtures_dir: Path | None = None) -> list[ReplayFixture]:
|
||||
"""Load all replay fixtures from the fixtures directory.
|
||||
|
||||
Args:
|
||||
fixtures_dir: Override path to fixtures directory.
|
||||
Defaults to tests/replay_fixtures/.
|
||||
|
||||
Returns:
|
||||
List of loaded ReplayFixture objects, sorted by document_id.
|
||||
"""
|
||||
directory = fixtures_dir or FIXTURES_DIR
|
||||
if not directory.is_dir():
|
||||
logger.warning("Fixtures directory not found: %s", directory)
|
||||
return []
|
||||
|
||||
fixtures: list[ReplayFixture] = []
|
||||
for path in sorted(directory.glob("*.json")):
|
||||
try:
|
||||
fixture = load_fixture(path)
|
||||
fixtures.append(fixture)
|
||||
except (ValueError, json.JSONDecodeError) as exc:
|
||||
logger.warning("Skipping invalid fixture %s: %s", path.name, exc)
|
||||
|
||||
logger.info("Loaded %d replay fixtures from %s", len(fixtures), directory)
|
||||
return fixtures
|
||||
|
||||
|
||||
def validate_fixture(fixture: ReplayFixture) -> ReplayValidationResult:
|
||||
"""Validate a fixture's expected extraction against the current schema.
|
||||
|
||||
This is the core deterministic test: the expected output must still
|
||||
pass schema and semantic validation with the current code. If it
|
||||
doesn't, either the fixture is stale or the schema has regressed.
|
||||
|
||||
Args:
|
||||
fixture: The replay fixture to validate.
|
||||
|
||||
Returns:
|
||||
A ReplayValidationResult indicating pass/fail.
|
||||
"""
|
||||
result = ReplayValidationResult(
|
||||
fixture_id=fixture.document_id,
|
||||
schema_version=get_schema_version(),
|
||||
)
|
||||
|
||||
try:
|
||||
report = validate_extraction(
|
||||
fixture.expected_extraction,
|
||||
document_text=fixture.document_text,
|
||||
)
|
||||
result.validation_report = report
|
||||
result.schema_valid = report.valid
|
||||
except Exception as exc: # noqa: BLE001
|
||||
result.error = str(exc)
|
||||
result.schema_valid = False
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def validate_all_fixtures(
|
||||
fixtures_dir: Path | None = None,
|
||||
) -> list[ReplayValidationResult]:
|
||||
"""Load and validate all fixtures against the current schema.
|
||||
|
||||
Args:
|
||||
fixtures_dir: Override path to fixtures directory.
|
||||
|
||||
Returns:
|
||||
List of validation results, one per fixture.
|
||||
"""
|
||||
fixtures = load_all_fixtures(fixtures_dir)
|
||||
return [validate_fixture(f) for f in fixtures]
|
||||
|
||||
|
||||
def compare_extraction(
|
||||
fixture: ReplayFixture,
|
||||
actual_result: ExtractionResult,
|
||||
) -> ReplayComparisonResult:
|
||||
"""Compare a live extraction result against the fixture's expected output.
|
||||
|
||||
Checks structural alignment (same companies detected, same sentiments,
|
||||
same catalyst types) rather than exact string equality, since LLM
|
||||
outputs vary in wording across runs.
|
||||
|
||||
Args:
|
||||
fixture: The replay fixture with expected output.
|
||||
actual_result: The ExtractionResult from a live extraction.
|
||||
|
||||
Returns:
|
||||
A ReplayComparisonResult with match details.
|
||||
"""
|
||||
expected = fixture.expected_result
|
||||
comparison = ReplayComparisonResult(fixture_id=fixture.document_id)
|
||||
|
||||
# Company ticker sets
|
||||
comparison.expected_companies = sorted(c.ticker for c in expected.companies)
|
||||
comparison.actual_companies = sorted(c.ticker for c in actual_result.companies)
|
||||
comparison.companies_match = (
|
||||
set(comparison.expected_companies) == set(comparison.actual_companies)
|
||||
)
|
||||
|
||||
# Sentiment by ticker
|
||||
comparison.expected_sentiment_map = {
|
||||
c.ticker: c.sentiment for c in expected.companies
|
||||
}
|
||||
comparison.actual_sentiment_map = {
|
||||
c.ticker: c.sentiment for c in actual_result.companies
|
||||
}
|
||||
comparison.sentiment_match = (
|
||||
comparison.expected_sentiment_map == comparison.actual_sentiment_map
|
||||
)
|
||||
|
||||
# Catalyst type by ticker
|
||||
comparison.expected_catalyst_map = {
|
||||
c.ticker: c.catalyst_type for c in expected.companies
|
||||
}
|
||||
comparison.actual_catalyst_map = {
|
||||
c.ticker: c.catalyst_type for c in actual_result.companies
|
||||
}
|
||||
comparison.catalyst_match = (
|
||||
comparison.expected_catalyst_map == comparison.actual_catalyst_map
|
||||
)
|
||||
|
||||
# Schema validity of actual result
|
||||
actual_report = validate_extraction(
|
||||
actual_result.model_dump(mode="json"),
|
||||
document_text=fixture.document_text,
|
||||
)
|
||||
comparison.actual_schema_valid = actual_report.valid
|
||||
if actual_report.warnings:
|
||||
comparison.warnings = actual_report.warnings
|
||||
|
||||
if not comparison.companies_match:
|
||||
comparison.warnings.append(
|
||||
f"company_mismatch: expected={comparison.expected_companies} actual={comparison.actual_companies}"
|
||||
)
|
||||
|
||||
return comparison
|
||||
@@ -0,0 +1,316 @@
|
||||
"""JSON schema definitions for document intelligence extraction.
|
||||
|
||||
Generates Ollama-compatible JSON schemas from Pydantic models so the
|
||||
extraction contract stays in sync with the shared data models. Also
|
||||
provides schema validation and semantic validation helpers.
|
||||
|
||||
Requirements: 5.1, 5.2, 5.3, 5.4, 5.5
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from services.shared.schemas import (
|
||||
CatalystType,
|
||||
Sentiment,
|
||||
)
|
||||
|
||||
SCHEMA_VERSION = "2.0.0"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic model that mirrors the Ollama extraction output contract.
|
||||
# This is the *response* shape we ask the model to produce — it intentionally
|
||||
# omits server-side fields like document_id, source_credibility, and model
|
||||
# metadata that are attached after extraction.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CompanyExtractionItem(BaseModel):
|
||||
"""Per-company extraction output expected from the model.
|
||||
|
||||
All fields are required (no defaults) so the generated JSON schema
|
||||
forces the model to produce every field explicitly.
|
||||
"""
|
||||
|
||||
ticker: str = Field(description="Stock ticker symbol mentioned in the document.")
|
||||
company_name: str = Field(description="Full company name as referenced in the document.")
|
||||
relevance: float = Field(
|
||||
ge=0,
|
||||
le=1,
|
||||
description="How relevant the document is to this company. 0=tangential, 1=primary subject.",
|
||||
)
|
||||
sentiment: Sentiment = Field(description="Overall sentiment toward this company in the document.")
|
||||
impact_score: float = Field(
|
||||
ge=0,
|
||||
le=1,
|
||||
description="Estimated magnitude of impact. 0=negligible, 1=highly material.",
|
||||
)
|
||||
impact_horizon: str = Field(
|
||||
description="One of: intraday, 1d, 1d_7d, 1d_30d, 30d_90d, 90d_plus",
|
||||
)
|
||||
catalyst_type: CatalystType = Field(description="Primary catalyst category.")
|
||||
key_facts: list[str] = Field(
|
||||
description="Facts explicitly stated in the document. Do NOT infer or fabricate.",
|
||||
)
|
||||
risks: list[str] = Field(
|
||||
description="Risks explicitly mentioned in the document.",
|
||||
)
|
||||
evidence_spans: list[str] = Field(
|
||||
description="Short verbatim quotes from the document supporting the analysis.",
|
||||
)
|
||||
|
||||
|
||||
class ExtractionResult(BaseModel):
|
||||
"""Top-level structured output the model must return.
|
||||
|
||||
All fields are required (no defaults) so the generated JSON schema
|
||||
forces the model to produce every field explicitly.
|
||||
"""
|
||||
|
||||
summary: str = Field(
|
||||
description="A concise 1-3 sentence summary of the document's main point.",
|
||||
)
|
||||
companies: list[CompanyExtractionItem] = Field(
|
||||
description="Per-company intelligence extracted from the document.",
|
||||
)
|
||||
macro_themes: list[str] = Field(
|
||||
description="Broad economic or market themes mentioned (e.g. rates, inflation, ai_capex).",
|
||||
)
|
||||
novelty_score: float = Field(
|
||||
ge=0,
|
||||
le=1,
|
||||
description="How novel or surprising the information is. 0=routine, 1=highly novel.",
|
||||
)
|
||||
confidence: float = Field(
|
||||
ge=0,
|
||||
le=1,
|
||||
description="Model confidence in the accuracy of this extraction. Lower if text is ambiguous.",
|
||||
)
|
||||
extraction_warnings: list[str] = Field(
|
||||
description="Any issues encountered: ambiguous_ticker, incomplete_text, low_confidence, etc.",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def generate_json_schema() -> dict[str, Any]:
|
||||
"""Generate the JSON schema from the Pydantic model.
|
||||
|
||||
Returns a plain JSON Schema dict suitable for Ollama's ``format``
|
||||
parameter. Pydantic ``$defs`` are inlined so the schema is
|
||||
self-contained.
|
||||
"""
|
||||
raw = ExtractionResult.model_json_schema()
|
||||
# Inline $defs so the schema is flat and Ollama-friendly
|
||||
return _inline_defs(raw)
|
||||
|
||||
|
||||
def get_schema_version() -> str:
|
||||
"""Return the current schema version string."""
|
||||
return SCHEMA_VERSION
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Validation helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ValidationReport(BaseModel):
|
||||
"""Result of validating a raw model response."""
|
||||
|
||||
valid: bool = False
|
||||
errors: list[str] = Field(default_factory=list)
|
||||
warnings: list[str] = Field(default_factory=list)
|
||||
parsed: ExtractionResult | None = None
|
||||
|
||||
|
||||
def validate_extraction(
|
||||
raw_json: str | dict[str, Any],
|
||||
*,
|
||||
document_text: str = "",
|
||||
) -> ValidationReport:
|
||||
"""Validate raw model output against the extraction schema.
|
||||
|
||||
Performs structural (JSON / Pydantic) validation followed by semantic
|
||||
checks that catch hallucination indicators, cross-field inconsistencies,
|
||||
and data-quality issues.
|
||||
|
||||
Args:
|
||||
raw_json: Either a JSON string or an already-parsed dict.
|
||||
document_text: Optional original document text used for evidence
|
||||
span verification.
|
||||
|
||||
Returns:
|
||||
A ``ValidationReport`` with parsed result on success.
|
||||
"""
|
||||
errors: list[str] = []
|
||||
warnings: list[str] = []
|
||||
|
||||
# --- Parse JSON string if needed ---
|
||||
if isinstance(raw_json, str):
|
||||
try:
|
||||
data = json.loads(raw_json)
|
||||
except json.JSONDecodeError as exc:
|
||||
return ValidationReport(valid=False, errors=[f"Invalid JSON: {exc}"])
|
||||
else:
|
||||
data = raw_json
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return ValidationReport(valid=False, errors=["Expected a JSON object at top level."])
|
||||
|
||||
# --- Pydantic structural validation ---
|
||||
try:
|
||||
result = ExtractionResult.model_validate(data)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
return ValidationReport(valid=False, errors=[f"Schema validation failed: {exc}"])
|
||||
|
||||
# --- Semantic checks ---
|
||||
sem_errors, sem_warnings = _semantic_checks(result, document_text)
|
||||
errors.extend(sem_errors)
|
||||
warnings.extend(sem_warnings)
|
||||
|
||||
# Semantic errors make the report invalid — the caller should retry.
|
||||
valid = len(errors) == 0
|
||||
return ValidationReport(
|
||||
valid=valid,
|
||||
errors=errors,
|
||||
warnings=warnings,
|
||||
parsed=result,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Known valid impact horizons
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
VALID_IMPACT_HORIZONS = frozenset({
|
||||
"intraday",
|
||||
"1d",
|
||||
"1d_7d",
|
||||
"1d_30d",
|
||||
"30d_90d",
|
||||
"90d_plus",
|
||||
})
|
||||
|
||||
# Ticker: 1-5 uppercase letters (covers NYSE, NASDAQ, etc.)
|
||||
_TICKER_RE = re.compile(r"^[A-Z]{1,5}$")
|
||||
|
||||
# Evidence span length bounds (characters)
|
||||
_MIN_EVIDENCE_LEN = 8
|
||||
_MAX_EVIDENCE_LEN = 500
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Semantic validation rules
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _semantic_checks(
|
||||
result: ExtractionResult,
|
||||
document_text: str = "",
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Run semantic checks on a parsed extraction.
|
||||
|
||||
Returns a tuple of (errors, warnings). Errors are issues severe enough
|
||||
to warrant a retry; warnings are informational.
|
||||
"""
|
||||
errors: list[str] = []
|
||||
warnings: list[str] = []
|
||||
|
||||
# --- Top-level checks ---
|
||||
if not result.summary:
|
||||
warnings.append("empty_summary")
|
||||
|
||||
if result.confidence < 0.3 and len(result.companies) > 0:
|
||||
warnings.append("low_confidence_with_companies")
|
||||
|
||||
# Duplicate tickers across company entries
|
||||
tickers_seen: list[str] = []
|
||||
for comp in result.companies:
|
||||
if comp.ticker in tickers_seen:
|
||||
errors.append(f"duplicate_ticker_{comp.ticker}")
|
||||
tickers_seen.append(comp.ticker)
|
||||
|
||||
# --- Per-company checks ---
|
||||
for comp in result.companies:
|
||||
tag = comp.ticker or "unknown"
|
||||
|
||||
# Ticker format
|
||||
if not comp.ticker:
|
||||
errors.append("company_missing_ticker")
|
||||
elif not _TICKER_RE.match(comp.ticker):
|
||||
warnings.append(f"invalid_ticker_format_{tag}")
|
||||
|
||||
# Impact horizon must be a known value
|
||||
if comp.impact_horizon not in VALID_IMPACT_HORIZONS:
|
||||
errors.append(f"invalid_impact_horizon_{comp.impact_horizon}_for_{tag}")
|
||||
|
||||
# Evidence spans
|
||||
if not comp.evidence_spans:
|
||||
warnings.append(f"no_evidence_spans_for_{tag}")
|
||||
else:
|
||||
for idx, span in enumerate(comp.evidence_spans):
|
||||
if len(span) < _MIN_EVIDENCE_LEN:
|
||||
warnings.append(f"evidence_span_too_short_for_{tag}_{idx}")
|
||||
if len(span) > _MAX_EVIDENCE_LEN:
|
||||
warnings.append(f"evidence_span_too_long_for_{tag}_{idx}")
|
||||
|
||||
# Cross-field: high impact but no facts
|
||||
if not comp.key_facts and comp.impact_score > 0.5:
|
||||
warnings.append(f"high_impact_no_facts_for_{tag}")
|
||||
|
||||
# Cross-field: very low relevance
|
||||
if comp.relevance < 0.2:
|
||||
warnings.append(f"very_low_relevance_for_{tag}")
|
||||
|
||||
# Cross-field: strong sentiment but low impact
|
||||
if comp.sentiment in (Sentiment.POSITIVE, Sentiment.NEGATIVE) and comp.impact_score < 0.1:
|
||||
warnings.append(f"strong_sentiment_low_impact_for_{tag}")
|
||||
|
||||
# --- Evidence grounding check (when source text is available) ---
|
||||
if document_text:
|
||||
doc_lower = document_text.lower()
|
||||
for comp in result.companies:
|
||||
for idx, span in enumerate(comp.evidence_spans):
|
||||
if span.lower() not in doc_lower:
|
||||
warnings.append(
|
||||
f"evidence_span_not_found_in_document_for_{comp.ticker or 'unknown'}_{idx}"
|
||||
)
|
||||
|
||||
return errors, warnings
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _inline_defs(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Recursively inline ``$defs`` / ``$ref`` so the schema is self-contained."""
|
||||
defs = schema.pop("$defs", {})
|
||||
return _resolve_refs(schema, defs)
|
||||
|
||||
|
||||
def _resolve_refs(node: Any, defs: dict[str, Any]) -> Any:
|
||||
"""Walk the schema tree and replace ``$ref`` pointers with their definitions."""
|
||||
if isinstance(node, dict):
|
||||
if "$ref" in node:
|
||||
ref_path = node["$ref"] # e.g. "#/$defs/CompanyExtractionItem"
|
||||
ref_name = ref_path.rsplit("/", 1)[-1]
|
||||
if ref_name in defs:
|
||||
resolved = defs[ref_name].copy()
|
||||
# The resolved def may itself contain refs
|
||||
return _resolve_refs(resolved, defs)
|
||||
return node # unresolvable ref, leave as-is
|
||||
return {k: _resolve_refs(v, defs) for k, v in node.items()}
|
||||
if isinstance(node, list):
|
||||
return [_resolve_refs(item, defs) for item in node]
|
||||
return node
|
||||
@@ -1 +1,291 @@
|
||||
"""Extraction worker - sends documents to Ollama for structured intelligence extraction."""
|
||||
"""Extraction worker - sends documents to Ollama for structured intelligence extraction.
|
||||
|
||||
Orchestrates the full extraction pipeline for a single document:
|
||||
1. Calls OllamaClient to get structured extraction
|
||||
2. Uploads prompts, raw outputs, and validation reports to MinIO
|
||||
3. Persists the final intelligence object and per-company impact records to PostgreSQL
|
||||
4. Updates document status
|
||||
|
||||
Requirements: 5.1, 5.2, 5.3, 5.4, 5.5, 9.1, 9.2
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import asyncpg
|
||||
from minio import Minio
|
||||
|
||||
from services.extractor.client import ExtractionResponse
|
||||
from services.extractor.metrics import collect_metrics, persist_metrics
|
||||
from services.shared.metadata import (
|
||||
persist_document_impact,
|
||||
persist_document_intelligence,
|
||||
update_document_status,
|
||||
)
|
||||
from services.shared.storage import (
|
||||
upload_extraction_intelligence,
|
||||
upload_extraction_prompt,
|
||||
upload_extraction_raw_output,
|
||||
upload_extraction_validation,
|
||||
)
|
||||
from services.shared.logging import Span
|
||||
from services.shared.metrics import (
|
||||
EXTRACTION_ATTEMPTS,
|
||||
EXTRACTION_CONFIDENCE,
|
||||
EXTRACTION_DURATION,
|
||||
EXTRACTION_JOBS_TOTAL,
|
||||
EXTRACTION_RETRIES,
|
||||
EXTRACTION_TOKEN_ESTIMATE,
|
||||
EXTRACTION_VALIDATION_ERRORS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("extractor_worker")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractionPersistResult:
|
||||
"""Result of persisting an extraction to storage and database."""
|
||||
|
||||
intelligence_id: str | None = None
|
||||
prompt_ref: str | None = None
|
||||
raw_output_ref: str | None = None
|
||||
validation_ref: str | None = None
|
||||
intelligence_ref: str | None = None
|
||||
impact_ids: list[str] | None = None
|
||||
metrics_id: str | None = None
|
||||
success: bool = False
|
||||
|
||||
|
||||
async def persist_extraction(
|
||||
*,
|
||||
pool: asyncpg.Pool,
|
||||
minio_client: Minio,
|
||||
document_id: str,
|
||||
ticker: str,
|
||||
extraction_response: ExtractionResponse,
|
||||
company_id_map: dict[str, str] | None = None,
|
||||
source_credibility: float = 0.5,
|
||||
timestamp: datetime | None = None,
|
||||
document_text_length: int = 0,
|
||||
) -> ExtractionPersistResult:
|
||||
"""Persist all extraction artifacts to MinIO and PostgreSQL.
|
||||
|
||||
Uploads prompts, raw model outputs, validation reports, and the final
|
||||
intelligence object to MinIO. Persists the intelligence record and
|
||||
per-company impact records to PostgreSQL. Updates document status.
|
||||
Also collects and persists model performance metrics.
|
||||
|
||||
Args:
|
||||
pool: PostgreSQL connection pool.
|
||||
minio_client: MinIO client.
|
||||
document_id: UUID of the source document.
|
||||
ticker: Primary ticker for path construction.
|
||||
extraction_response: Full response from OllamaClient.extract().
|
||||
company_id_map: Optional mapping of ticker -> company UUID for impact records.
|
||||
source_credibility: Credibility score to attach to the intelligence record.
|
||||
timestamp: Override timestamp for MinIO paths (defaults to UTC now).
|
||||
document_text_length: Length of the input document text for token estimation.
|
||||
|
||||
Returns:
|
||||
ExtractionPersistResult with references to all persisted artifacts.
|
||||
"""
|
||||
ts = timestamp or datetime.now(timezone.utc)
|
||||
result = ExtractionPersistResult()
|
||||
company_id_map = company_id_map or {}
|
||||
|
||||
# 1. Upload prompt metadata to MinIO
|
||||
prompt_payload = json.dumps({
|
||||
"prompt_metadata": extraction_response.prompt_metadata,
|
||||
"model": extraction_response.model,
|
||||
}, indent=2).encode()
|
||||
result.prompt_ref = upload_extraction_prompt(
|
||||
minio_client, ticker, document_id, prompt_payload, timestamp=ts,
|
||||
)
|
||||
|
||||
# 2. Upload raw outputs for each attempt
|
||||
attempts_data: list[dict[str, object]] = []
|
||||
for idx, attempt in enumerate(extraction_response.attempts):
|
||||
attempt_record: dict[str, object] = {
|
||||
"attempt_index": idx,
|
||||
"raw_output": attempt.raw_output,
|
||||
"error": attempt.error,
|
||||
"duration_ms": attempt.duration_ms,
|
||||
"model": attempt.model,
|
||||
"retryable": attempt.retryable,
|
||||
}
|
||||
if attempt.validation:
|
||||
attempt_record["validation"] = {
|
||||
"valid": attempt.validation.valid,
|
||||
"errors": attempt.validation.errors,
|
||||
"warnings": attempt.validation.warnings,
|
||||
}
|
||||
attempts_data.append(attempt_record)
|
||||
|
||||
raw_output_payload = json.dumps({
|
||||
"document_id": document_id,
|
||||
"attempts": attempts_data,
|
||||
"total_duration_ms": extraction_response.total_duration_ms,
|
||||
"success": extraction_response.success,
|
||||
}, indent=2).encode()
|
||||
result.raw_output_ref = upload_extraction_raw_output(
|
||||
minio_client, ticker, document_id, raw_output_payload, timestamp=ts,
|
||||
)
|
||||
|
||||
# 3. Upload validation report
|
||||
final_attempt = extraction_response.attempts[-1] if extraction_response.attempts else None
|
||||
validation_payload = json.dumps({
|
||||
"document_id": document_id,
|
||||
"success": extraction_response.success,
|
||||
"attempt_count": len(extraction_response.attempts),
|
||||
"final_validation": {
|
||||
"valid": final_attempt.validation.valid if final_attempt and final_attempt.validation else False,
|
||||
"errors": final_attempt.validation.errors if final_attempt and final_attempt.validation else [],
|
||||
"warnings": final_attempt.validation.warnings if final_attempt and final_attempt.validation else [],
|
||||
} if final_attempt else None,
|
||||
}, indent=2).encode()
|
||||
result.validation_ref = upload_extraction_validation(
|
||||
minio_client, ticker, document_id, validation_payload, timestamp=ts,
|
||||
)
|
||||
|
||||
# 4. Determine validation status and persist intelligence
|
||||
if extraction_response.success and extraction_response.result:
|
||||
extraction = extraction_response.result
|
||||
validation_status = "valid"
|
||||
validation_errors: list[str] = []
|
||||
|
||||
# Upload final intelligence object to MinIO
|
||||
intelligence_payload = json.dumps(
|
||||
extraction.model_dump(mode="json"), indent=2,
|
||||
).encode()
|
||||
result.intelligence_ref = upload_extraction_intelligence(
|
||||
minio_client, ticker, document_id, intelligence_payload, timestamp=ts,
|
||||
)
|
||||
|
||||
# Persist to PostgreSQL
|
||||
intel_id = await persist_document_intelligence(
|
||||
pool,
|
||||
document_id=document_id,
|
||||
summary=extraction.summary,
|
||||
macro_themes=extraction.macro_themes,
|
||||
novelty_score=extraction.novelty_score,
|
||||
source_credibility=source_credibility,
|
||||
extraction_warnings=extraction.extraction_warnings,
|
||||
confidence=extraction.confidence,
|
||||
model_provider="ollama",
|
||||
model_name=extraction_response.model,
|
||||
prompt_version=extraction_response.prompt_metadata.get("prompt_version", ""),
|
||||
schema_version=extraction_response.prompt_metadata.get("schema_version", ""),
|
||||
raw_output_ref=result.raw_output_ref,
|
||||
prompt_ref=result.prompt_ref,
|
||||
validation_status=validation_status,
|
||||
validation_errors=validation_errors,
|
||||
retry_count=len(extraction_response.attempts) - 1,
|
||||
)
|
||||
result.intelligence_id = intel_id
|
||||
|
||||
# Persist per-company impact records
|
||||
result.impact_ids = []
|
||||
for company in extraction.companies:
|
||||
cid = company_id_map.get(company.ticker)
|
||||
if not cid:
|
||||
logger.warning(
|
||||
"No company_id for ticker %s in doc %s, skipping impact record",
|
||||
company.ticker, document_id,
|
||||
)
|
||||
continue
|
||||
impact_id = await persist_document_impact(
|
||||
pool,
|
||||
intelligence_id=intel_id,
|
||||
company_id=cid,
|
||||
ticker=company.ticker,
|
||||
relevance=company.relevance,
|
||||
sentiment=company.sentiment,
|
||||
impact_score=company.impact_score,
|
||||
impact_horizon=company.impact_horizon,
|
||||
catalyst_type=company.catalyst_type,
|
||||
key_facts=company.key_facts,
|
||||
risks=company.risks,
|
||||
evidence_spans=company.evidence_spans,
|
||||
)
|
||||
result.impact_ids.append(impact_id)
|
||||
|
||||
await update_document_status(pool, document_id=document_id, status="extracted")
|
||||
result.success = True
|
||||
logger.info(
|
||||
"Extraction persisted for doc %s: intel=%s, impacts=%d",
|
||||
document_id, intel_id, len(result.impact_ids),
|
||||
)
|
||||
else:
|
||||
# Failed extraction — still persist the attempt data
|
||||
all_errors: list[str] = []
|
||||
for attempt in extraction_response.attempts:
|
||||
if attempt.error:
|
||||
all_errors.append(attempt.error)
|
||||
|
||||
intel_id = await persist_document_intelligence(
|
||||
pool,
|
||||
document_id=document_id,
|
||||
summary="",
|
||||
macro_themes=[],
|
||||
novelty_score=0.0,
|
||||
source_credibility=source_credibility,
|
||||
extraction_warnings=["extraction_failed"],
|
||||
confidence=0.0,
|
||||
model_provider="ollama",
|
||||
model_name=extraction_response.model,
|
||||
prompt_version=extraction_response.prompt_metadata.get("prompt_version", ""),
|
||||
schema_version=extraction_response.prompt_metadata.get("schema_version", ""),
|
||||
raw_output_ref=result.raw_output_ref,
|
||||
prompt_ref=result.prompt_ref,
|
||||
validation_status="failed",
|
||||
validation_errors=all_errors,
|
||||
retry_count=len(extraction_response.attempts),
|
||||
)
|
||||
result.intelligence_id = intel_id
|
||||
|
||||
await update_document_status(pool, document_id=document_id, status="extraction_failed")
|
||||
logger.warning(
|
||||
"Extraction failed for doc %s after %d attempts: %s",
|
||||
document_id, len(extraction_response.attempts), "; ".join(all_errors),
|
||||
)
|
||||
|
||||
# Collect and persist model performance metrics
|
||||
try:
|
||||
metrics = collect_metrics(
|
||||
extraction_response,
|
||||
document_id=document_id,
|
||||
ticker=ticker,
|
||||
document_text_length=document_text_length,
|
||||
)
|
||||
metrics.recorded_at = ts
|
||||
metrics_id = await persist_metrics(pool, metrics)
|
||||
result.metrics_id = metrics_id
|
||||
except Exception:
|
||||
logger.exception("Failed to persist extraction metrics for doc %s", document_id)
|
||||
|
||||
# Prometheus metrics
|
||||
EXTRACTION_ATTEMPTS.inc(len(extraction_response.attempts))
|
||||
EXTRACTION_DURATION.observe(extraction_response.total_duration_ms / 1000.0)
|
||||
retry_count = max(0, len(extraction_response.attempts) - 1)
|
||||
if retry_count > 0:
|
||||
EXTRACTION_RETRIES.inc(retry_count)
|
||||
if extraction_response.success:
|
||||
EXTRACTION_JOBS_TOTAL.labels(status="success").inc()
|
||||
if extraction_response.result:
|
||||
EXTRACTION_CONFIDENCE.observe(extraction_response.result.confidence)
|
||||
else:
|
||||
EXTRACTION_JOBS_TOTAL.labels(status="failed").inc()
|
||||
# Count validation errors from final attempt
|
||||
final = extraction_response.attempts[-1] if extraction_response.attempts else None
|
||||
if final and final.validation and final.validation.errors:
|
||||
EXTRACTION_VALIDATION_ERRORS.inc(len(final.validation.errors))
|
||||
# Token estimates
|
||||
if document_text_length > 0:
|
||||
EXTRACTION_TOKEN_ESTIMATE.labels(direction="input").inc(document_text_length // 4)
|
||||
if final and final.raw_output:
|
||||
EXTRACTION_TOKEN_ESTIMATE.labels(direction="output").inc(len(final.raw_output) // 4)
|
||||
|
||||
return result
|
||||
|
||||
+151
-80
@@ -1,47 +1,50 @@
|
||||
"""Ingestion worker - processes jobs from the ingestion queue."""
|
||||
import asyncio
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
import asyncpg
|
||||
import redis.asyncio as aioredis
|
||||
from minio import Minio
|
||||
|
||||
from services.adapters.base import AdapterResult
|
||||
from services.adapters.filings_adapter import FilingsAdapter
|
||||
from services.adapters.market_adapter import MarketDataAdapter
|
||||
from services.adapters.news_adapter import NewsApiAdapter
|
||||
from services.adapters.broker_adapter import AlpacaBrokerAdapter, TradingMode
|
||||
from services.adapters.filings_adapter import SECEdgarAdapter
|
||||
from services.adapters.market_adapter import PolygonMarketAdapter
|
||||
from services.adapters.news_adapter import PolygonNewsAdapter
|
||||
from services.adapters.web_scrape_adapter import WebScrapeAdapter
|
||||
from services.shared.config import load_config
|
||||
from services.shared.db import get_minio, get_pg_pool, get_redis
|
||||
from services.shared.dedupe import dedupe_items, mark_as_seen
|
||||
from services.shared.metadata import (
|
||||
persist_ingestion_items,
|
||||
record_retrieval_failure,
|
||||
reset_source_retry_state,
|
||||
)
|
||||
from services.shared.redis_keys import (
|
||||
QUEUE_INGESTION,
|
||||
QUEUE_PARSING,
|
||||
dedupe_key,
|
||||
queue_key,
|
||||
)
|
||||
from services.shared.logging import Span, extract_trace_context, inject_trace_context, new_trace_id, set_trace_context, setup_logging
|
||||
from services.shared.metrics import (
|
||||
ACTIVE_JOBS,
|
||||
INGESTION_ADAPTER_DURATION,
|
||||
INGESTION_ERRORS,
|
||||
INGESTION_ITEMS_DEDUPED,
|
||||
INGESTION_ITEMS_FETCHED,
|
||||
INGESTION_ITEMS_NEW,
|
||||
INGESTION_JOBS_TOTAL,
|
||||
)
|
||||
from services.shared.storage import (
|
||||
bucket_for_source,
|
||||
ensure_buckets,
|
||||
upload_raw_artifact,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("ingestion_worker")
|
||||
|
||||
BUCKET_MAP = {
|
||||
"market_api": "stonks-raw-market",
|
||||
"news_api": "stonks-raw-news",
|
||||
"filings_api": "stonks-raw-filings",
|
||||
"broker": "stonks-raw-market",
|
||||
}
|
||||
|
||||
|
||||
def build_storage_path(source_type: str, ticker: str, doc_id: str) -> str:
|
||||
now = datetime.utcnow()
|
||||
return f"{source_type}/{ticker}/{now.year}/{now.month:02d}/{now.day:02d}/{doc_id}/raw.json"
|
||||
|
||||
|
||||
async def store_raw_artifact(minio_client: Minio, bucket: str, path: str, data: bytes):
|
||||
minio_client.put_object(bucket, path, io.BytesIO(data), len(data), content_type="application/json")
|
||||
|
||||
|
||||
async def process_job(
|
||||
job: dict,
|
||||
@@ -55,9 +58,11 @@ async def process_job(
|
||||
source_id = job["source_id"]
|
||||
config = job.get("config", {})
|
||||
|
||||
set_trace_context(trace_id=job.get("_trace_id") or new_trace_id())
|
||||
|
||||
adapter = adapters.get(source_type)
|
||||
if not adapter:
|
||||
logger.warning(f"No adapter for source_type={source_type}")
|
||||
logger.warning("No adapter for source_type=%s", source_type)
|
||||
return
|
||||
|
||||
# Record ingestion run
|
||||
@@ -68,25 +73,37 @@ async def process_job(
|
||||
)
|
||||
|
||||
try:
|
||||
result: AdapterResult = await adapter.fetch(ticker, config)
|
||||
with Span("adapter_fetch", ticker=ticker, source_type=source_type):
|
||||
with INGESTION_ADAPTER_DURATION.labels(source_type=source_type).time():
|
||||
result: AdapterResult = await adapter.fetch(ticker, config)
|
||||
|
||||
if result.error:
|
||||
await pool.execute(
|
||||
"UPDATE ingestion_runs SET status='failed', error_message=$2, completed_at=NOW() WHERE id=$1",
|
||||
run_id, result.error,
|
||||
INGESTION_JOBS_TOTAL.labels(source_type=source_type, status="error").inc()
|
||||
await record_retrieval_failure(
|
||||
pool,
|
||||
run_id=str(run_id),
|
||||
source_id=source_id,
|
||||
error_message=result.error,
|
||||
)
|
||||
return
|
||||
|
||||
# Store raw payload
|
||||
bucket = BUCKET_MAP.get(source_type, "stonks-raw-market")
|
||||
storage_path = build_storage_path(source_type, ticker, str(run_id))
|
||||
await store_raw_artifact(minio_client, bucket, storage_path, result.raw_payload)
|
||||
# Store raw payload in MinIO
|
||||
bucket = bucket_for_source(source_type)
|
||||
artifact_type = "raw_html" if source_type == "web_scrape" else "raw_json"
|
||||
storage_uri = upload_raw_artifact(
|
||||
minio_client,
|
||||
source_type=source_type,
|
||||
ticker=ticker,
|
||||
document_id=str(run_id),
|
||||
data=result.raw_payload,
|
||||
artifact_type=artifact_type,
|
||||
)
|
||||
|
||||
# Dedupe check
|
||||
# Dedupe check on the overall payload hash
|
||||
if result.content_hash:
|
||||
already_seen = await rds.get(dedupe_key(result.content_hash))
|
||||
if already_seen:
|
||||
logger.info(f"Duplicate content for {ticker}, skipping")
|
||||
logger.info("Duplicate content for %s, skipping", ticker)
|
||||
await pool.execute(
|
||||
"UPDATE ingestion_runs SET status='completed', items_fetched=$2, items_new=0, completed_at=NOW() WHERE id=$1",
|
||||
run_id, len(result.items),
|
||||
@@ -94,72 +111,126 @@ async def process_job(
|
||||
return
|
||||
await rds.set(dedupe_key(result.content_hash), "1", ex=86400)
|
||||
|
||||
new_items = 0
|
||||
for item in result.items:
|
||||
item_json = json.dumps(item)
|
||||
item_hash = hashlib.sha256(item_json.encode()).hexdigest()
|
||||
# Cross-source dedupe on individual document items (news, filings, web_scrape)
|
||||
items_to_persist = result.items
|
||||
deduped_count = 0
|
||||
if source_type not in ("market_api", "broker"):
|
||||
items_to_persist, dup_items = await dedupe_items(pool, rds, result.items)
|
||||
deduped_count = len(dup_items)
|
||||
if deduped_count:
|
||||
INGESTION_ITEMS_DEDUPED.labels(source_type=source_type).inc(deduped_count)
|
||||
logger.info(
|
||||
"Deduped %d/%d items for %s/%s",
|
||||
deduped_count, len(result.items), ticker, source_type,
|
||||
)
|
||||
|
||||
# Check if document already exists
|
||||
exists = await pool.fetchval("SELECT 1 FROM documents WHERE content_hash = $1", item_hash)
|
||||
if exists:
|
||||
continue
|
||||
# Persist metadata via the unified metadata module
|
||||
new_items, new_ids = await persist_ingestion_items(
|
||||
pool,
|
||||
source_type=source_type,
|
||||
ticker=ticker,
|
||||
company_id=job.get("company_id"),
|
||||
items=items_to_persist,
|
||||
storage_ref=storage_uri,
|
||||
adapter_metadata=result.metadata,
|
||||
content_hash=result.content_hash,
|
||||
)
|
||||
|
||||
title = item.get("title", item.get("name", ""))
|
||||
url = item.get("url", item.get("link", ""))
|
||||
published = item.get("publishedAt", item.get("published_at"))
|
||||
# Enqueue new document items for parsing (not market/broker)
|
||||
if source_type not in ("market_api", "broker"):
|
||||
for doc_id in new_ids:
|
||||
await rds.rpush(queue_key(QUEUE_PARSING), json.dumps(inject_trace_context({
|
||||
"document_id": doc_id,
|
||||
"ticker": ticker,
|
||||
"source_type": source_type,
|
||||
})))
|
||||
|
||||
doc_id = await pool.fetchval(
|
||||
"""INSERT INTO documents (document_type, source_type, publisher, url, title, published_at, content_hash, raw_storage_ref, status)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'ingested')
|
||||
RETURNING id""",
|
||||
"article" if source_type == "news_api" else "filing" if source_type == "filings_api" else "article",
|
||||
source_type,
|
||||
item.get("source", {}).get("name", "") if isinstance(item.get("source"), dict) else str(item.get("source", "")),
|
||||
url, title,
|
||||
datetime.fromisoformat(published.replace("Z", "+00:00")) if published else None,
|
||||
item_hash,
|
||||
f"s3://{bucket}/{storage_path}",
|
||||
)
|
||||
# Mark newly persisted documents in Redis for fast future dedupe
|
||||
for item, doc_id in zip(items_to_persist, new_ids):
|
||||
await mark_as_seen(
|
||||
rds,
|
||||
content_hash=item.get("content_hash", ""),
|
||||
canonical_url=item.get("canonical_url"),
|
||||
document_id=doc_id,
|
||||
)
|
||||
|
||||
# Enqueue for parsing
|
||||
await rds.rpush(queue_key(QUEUE_PARSING), json.dumps({
|
||||
"document_id": str(doc_id),
|
||||
"ticker": ticker,
|
||||
"source_type": source_type,
|
||||
"url": url,
|
||||
}))
|
||||
new_items += 1
|
||||
# Link duplicate documents to this company if not already linked
|
||||
company_id = job.get("company_id")
|
||||
if company_id and deduped_count:
|
||||
from services.shared.metadata import persist_document_company_mention
|
||||
for dup in dup_items:
|
||||
existing_id = dup.get("_dedupe_existing_id")
|
||||
if existing_id:
|
||||
try:
|
||||
await persist_document_company_mention(
|
||||
pool,
|
||||
document_id=existing_id,
|
||||
company_id=company_id,
|
||||
ticker=ticker,
|
||||
mention_type="cross_source",
|
||||
)
|
||||
except Exception:
|
||||
# Duplicate mention link — safe to ignore
|
||||
pass
|
||||
|
||||
await pool.execute(
|
||||
"UPDATE ingestion_runs SET status='completed', items_fetched=$2, items_new=$3, completed_at=NOW() WHERE id=$1",
|
||||
run_id, len(result.items), new_items,
|
||||
)
|
||||
logger.info(f"Ingested {ticker}/{source_type}: {len(result.items)} fetched, {new_items} new")
|
||||
# Clear any accumulated retry backoff after success
|
||||
await reset_source_retry_state(pool, source_id)
|
||||
INGESTION_ITEMS_FETCHED.labels(source_type=source_type).inc(len(result.items))
|
||||
INGESTION_ITEMS_NEW.labels(source_type=source_type).inc(new_items)
|
||||
INGESTION_JOBS_TOTAL.labels(source_type=source_type, status="success").inc()
|
||||
logger.info(
|
||||
"Ingested %s/%s: %d fetched, %d new",
|
||||
ticker, source_type, len(result.items), new_items,
|
||||
extra={"ticker": ticker, "source_type": source_type, "count": new_items},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ingestion error for {ticker}: {e}")
|
||||
await pool.execute(
|
||||
"UPDATE ingestion_runs SET status='failed', error_message=$2, completed_at=NOW() WHERE id=$1",
|
||||
run_id, str(e),
|
||||
INGESTION_ERRORS.labels(source_type=source_type).inc()
|
||||
INGESTION_JOBS_TOTAL.labels(source_type=source_type, status="error").inc()
|
||||
logger.error(
|
||||
"Ingestion error for %s: %s", ticker, e,
|
||||
extra={"ticker": ticker, "source_type": source_type, "error": str(e)},
|
||||
)
|
||||
await record_retrieval_failure(
|
||||
pool,
|
||||
run_id=str(run_id),
|
||||
source_id=source_id,
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
config = load_config()
|
||||
pool = await get_pg_pool(config)
|
||||
rds = get_redis(config)
|
||||
minio_client = get_minio(config)
|
||||
cfg = load_config()
|
||||
setup_logging("ingestion_worker", level=cfg.log_level, json_output=cfg.json_logs)
|
||||
|
||||
pool = await get_pg_pool(cfg)
|
||||
rds = get_redis(cfg)
|
||||
minio_client = get_minio(cfg)
|
||||
|
||||
# Ensure all required buckets exist
|
||||
ensure_buckets(minio_client)
|
||||
|
||||
adapters = {
|
||||
"market_api": MarketDataAdapter(
|
||||
api_key=config.broker.api_key or "",
|
||||
"market_api": PolygonMarketAdapter(
|
||||
api_key=cfg.market_data.api_key,
|
||||
base_url=cfg.market_data.base_url,
|
||||
),
|
||||
"news_api": PolygonNewsAdapter(
|
||||
api_key=cfg.market_data.api_key,
|
||||
base_url="https://api.polygon.io",
|
||||
),
|
||||
"news_api": NewsApiAdapter(
|
||||
api_key="",
|
||||
base_url="https://newsapi.org",
|
||||
"filings_api": SECEdgarAdapter(),
|
||||
"web_scrape": WebScrapeAdapter(),
|
||||
"broker": AlpacaBrokerAdapter(
|
||||
api_key=cfg.broker.api_key or "",
|
||||
api_secret=cfg.broker.api_secret or "",
|
||||
mode=TradingMode.LIVE if cfg.broker.mode == "live" else TradingMode.PAPER,
|
||||
base_url=cfg.broker.base_url,
|
||||
),
|
||||
"filings_api": FilingsAdapter(),
|
||||
}
|
||||
|
||||
logger.info("Ingestion worker started")
|
||||
|
||||
@@ -1 +1 @@
|
||||
# Lake Publisher - transforms operational data into analytical fact datasets
|
||||
"""Lake publisher — writes partitioned Parquet facts to MinIO for Trino/Superset."""
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
"""Helpers for enqueuing lake publish jobs from upstream workers.
|
||||
|
||||
Other services import these helpers to push jobs onto the QUEUE_LAKE_PUBLISH
|
||||
Redis queue. The lake publisher worker (jobs.py) consumes them.
|
||||
|
||||
Usage:
|
||||
await enqueue_lake_job(rds, "document", document_id)
|
||||
await enqueue_lake_job(rds, "trade_order", order_id)
|
||||
await enqueue_lake_job(rds, "bulk_documents", since=cutoff.isoformat())
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from services.shared.redis_keys import QUEUE_LAKE_PUBLISH, queue_key
|
||||
|
||||
|
||||
async def enqueue_lake_job(
|
||||
rds: aioredis.Redis,
|
||||
job_type: str,
|
||||
entity_id: str = "",
|
||||
since: str | None = None,
|
||||
) -> None:
|
||||
"""Push a lake publish job onto the Redis queue.
|
||||
|
||||
Args:
|
||||
rds: Async Redis client.
|
||||
job_type: One of the supported job types (document, document_extraction,
|
||||
market_snapshot, trade_order, trade_fill, positions_snapshot,
|
||||
pnl_snapshot, bulk_documents, bulk_extractions).
|
||||
entity_id: UUID or identifier for the entity to publish.
|
||||
since: ISO datetime string for bulk jobs (cutoff timestamp).
|
||||
"""
|
||||
payload: dict[str, str] = {"job_type": job_type, "entity_id": entity_id}
|
||||
if since:
|
||||
payload["since"] = since
|
||||
await rds.rpush(queue_key(QUEUE_LAKE_PUBLISH), json.dumps(payload)) # type: ignore[misc]
|
||||
@@ -0,0 +1,420 @@
|
||||
"""Iceberg table creation and metadata management for analytical datasets.
|
||||
|
||||
Manages Iceberg tables in Trino's Iceberg catalog, providing:
|
||||
- Table creation with proper schemas and partition specs
|
||||
- Schema synchronization between PyArrow definitions and Iceberg tables
|
||||
- Table metadata inspection (existence checks, schema retrieval, partition listing)
|
||||
|
||||
The Iceberg catalog complements the existing Hive-compatible partition layout.
|
||||
Parquet files written by the lake publisher are stored in the same MinIO paths,
|
||||
but Iceberg metadata enables schema evolution, snapshot isolation, and better
|
||||
partition pruning via Trino's Iceberg connector.
|
||||
|
||||
Requirements: 9.4, 9.5, 10.1, N4, N6
|
||||
Design ref: Section 5.3 (Lakehouse model), Section 4.12 (SQL Query Engine)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import pyarrow as pa
|
||||
from trino.dbapi import connect as trino_connect
|
||||
|
||||
from services.lake_publisher.partitions import (
|
||||
LAKEHOUSE_BUCKET,
|
||||
TABLE_PARTITIONS,
|
||||
WAREHOUSE_PREFIX,
|
||||
PartitionSpec,
|
||||
)
|
||||
from services.lake_publisher.worker import (
|
||||
COMPANY_EVENTS_SCHEMA,
|
||||
DOCUMENTS_SCHEMA,
|
||||
DOCUMENT_EXTRACTIONS_SCHEMA,
|
||||
MARKET_BARS_SCHEMA,
|
||||
MARKET_QUOTES_SCHEMA,
|
||||
MODEL_PERFORMANCE_SCHEMA,
|
||||
PNL_DAILY_SCHEMA,
|
||||
POSITIONS_DAILY_SCHEMA,
|
||||
PREDICTION_VS_OUTCOME_SCHEMA,
|
||||
TRADE_FILLS_SCHEMA,
|
||||
TRADE_ORDERS_SCHEMA,
|
||||
TRADE_SIGNALS_SCHEMA,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ICEBERG_CATALOG = "iceberg"
|
||||
ICEBERG_SCHEMA = "stonks"
|
||||
|
||||
|
||||
def _get_iceberg_catalog() -> str:
|
||||
"""Return the Iceberg catalog name from env or default."""
|
||||
import os
|
||||
return os.getenv("TRINO_ICEBERG_CATALOG", ICEBERG_CATALOG)
|
||||
|
||||
# Map PyArrow types to Trino/Iceberg SQL types.
|
||||
_ARROW_TO_TRINO: dict[str, str] = {
|
||||
"string": "VARCHAR",
|
||||
"utf8": "VARCHAR",
|
||||
"large_string": "VARCHAR",
|
||||
"large_utf8": "VARCHAR",
|
||||
"float64": "DOUBLE",
|
||||
"double": "DOUBLE",
|
||||
"float32": "REAL",
|
||||
"float": "REAL",
|
||||
"int8": "TINYINT",
|
||||
"int16": "SMALLINT",
|
||||
"int32": "INTEGER",
|
||||
"int64": "BIGINT",
|
||||
"bool": "BOOLEAN",
|
||||
"date32": "DATE",
|
||||
"date32[day]": "DATE",
|
||||
"date64": "DATE",
|
||||
}
|
||||
|
||||
|
||||
def _arrow_type_to_trino(arrow_type: pa.DataType) -> str:
|
||||
"""Convert a PyArrow data type to a Trino SQL type string."""
|
||||
type_str = str(arrow_type)
|
||||
|
||||
# Handle timestamp types (with or without timezone)
|
||||
if type_str.startswith("timestamp"):
|
||||
if "tz=" in type_str:
|
||||
return "TIMESTAMP(6) WITH TIME ZONE"
|
||||
return "TIMESTAMP(6)"
|
||||
|
||||
# Direct lookup
|
||||
result = _ARROW_TO_TRINO.get(type_str)
|
||||
if result:
|
||||
return result
|
||||
|
||||
# Fallback for type IDs
|
||||
if pa.types.is_string(arrow_type) or pa.types.is_large_string(arrow_type):
|
||||
return "VARCHAR"
|
||||
if pa.types.is_floating(arrow_type):
|
||||
return "DOUBLE"
|
||||
if pa.types.is_integer(arrow_type):
|
||||
return "BIGINT"
|
||||
if pa.types.is_boolean(arrow_type):
|
||||
return "BOOLEAN"
|
||||
if pa.types.is_date(arrow_type):
|
||||
return "DATE"
|
||||
if pa.types.is_timestamp(arrow_type):
|
||||
return "TIMESTAMP(6) WITH TIME ZONE"
|
||||
|
||||
raise ValueError(f"Unsupported PyArrow type for Iceberg DDL: {arrow_type}")
|
||||
|
||||
|
||||
|
||||
# Registry mapping table names to their PyArrow schemas.
|
||||
TABLE_SCHEMAS: dict[str, pa.Schema] = {
|
||||
"market_bars": MARKET_BARS_SCHEMA,
|
||||
"market_quotes": MARKET_QUOTES_SCHEMA,
|
||||
"company_events": COMPANY_EVENTS_SCHEMA,
|
||||
"documents": DOCUMENTS_SCHEMA,
|
||||
"document_extractions": DOCUMENT_EXTRACTIONS_SCHEMA,
|
||||
"trade_signals": TRADE_SIGNALS_SCHEMA,
|
||||
"trade_orders": TRADE_ORDERS_SCHEMA,
|
||||
"trade_fills": TRADE_FILLS_SCHEMA,
|
||||
"positions_daily": POSITIONS_DAILY_SCHEMA,
|
||||
"pnl_daily": PNL_DAILY_SCHEMA,
|
||||
"prediction_vs_outcome": PREDICTION_VS_OUTCOME_SCHEMA,
|
||||
"model_performance": MODEL_PERFORMANCE_SCHEMA,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IcebergTableDef:
|
||||
"""Definition for an Iceberg table derived from PyArrow schema + partition spec."""
|
||||
|
||||
table_name: str
|
||||
schema: pa.Schema
|
||||
partition_spec: PartitionSpec
|
||||
|
||||
@property
|
||||
def qualified_name(self) -> str:
|
||||
return f"{ICEBERG_CATALOG}.{ICEBERG_SCHEMA}.{self.table_name}"
|
||||
|
||||
@property
|
||||
def location(self) -> str:
|
||||
return f"s3a://{LAKEHOUSE_BUCKET}/{WAREHOUSE_PREFIX}/{self.table_name}/"
|
||||
|
||||
def column_defs_sql(self) -> list[str]:
|
||||
"""Generate SQL column definitions from the PyArrow schema.
|
||||
|
||||
Partition columns are included in the column list (Iceberg stores them
|
||||
in the data files, unlike Hive external tables).
|
||||
"""
|
||||
cols: list[str] = []
|
||||
for i in range(len(self.schema)):
|
||||
name = self.schema.field(i).name
|
||||
arrow_type = self.schema.field(i).type
|
||||
trino_type = _arrow_type_to_trino(arrow_type)
|
||||
cols.append(f" {name} {trino_type}")
|
||||
return cols
|
||||
|
||||
def partition_keys_sql(self) -> str:
|
||||
"""Generate the partitioning clause for CREATE TABLE."""
|
||||
keys = list(self.partition_spec.all_keys)
|
||||
if not keys:
|
||||
return ""
|
||||
quoted = ", ".join(f"'{k}'" for k in keys)
|
||||
return f"partitioning = ARRAY[{quoted}]"
|
||||
|
||||
def create_table_sql(self) -> str:
|
||||
"""Generate a CREATE TABLE IF NOT EXISTS statement for Trino's Iceberg catalog."""
|
||||
col_lines = ",\n".join(self.column_defs_sql())
|
||||
with_clauses = [
|
||||
"format = 'PARQUET'",
|
||||
f"location = '{self.location}'",
|
||||
]
|
||||
part_sql = self.partition_keys_sql()
|
||||
if part_sql:
|
||||
with_clauses.append(part_sql)
|
||||
|
||||
with_block = ",\n ".join(with_clauses)
|
||||
|
||||
return (
|
||||
f"CREATE TABLE IF NOT EXISTS {self.qualified_name} (\n"
|
||||
f"{col_lines}\n"
|
||||
f") WITH (\n"
|
||||
f" {with_block}\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
|
||||
def get_all_table_defs() -> list[IcebergTableDef]:
|
||||
"""Build IcebergTableDef for every registered analytical table."""
|
||||
defs: list[IcebergTableDef] = []
|
||||
for table_name, partition_spec in TABLE_PARTITIONS.items():
|
||||
schema = TABLE_SCHEMAS.get(table_name)
|
||||
if schema is None:
|
||||
logger.warning("No PyArrow schema for table %s, skipping", table_name)
|
||||
continue
|
||||
defs.append(IcebergTableDef(
|
||||
table_name=table_name,
|
||||
schema=schema,
|
||||
partition_spec=partition_spec,
|
||||
))
|
||||
return defs
|
||||
|
||||
|
||||
def get_table_def(table_name: str) -> IcebergTableDef:
|
||||
"""Get the IcebergTableDef for a single table by name."""
|
||||
if table_name not in TABLE_PARTITIONS:
|
||||
raise ValueError(f"Unknown table: {table_name}")
|
||||
schema = TABLE_SCHEMAS.get(table_name)
|
||||
if schema is None:
|
||||
raise ValueError(f"No PyArrow schema registered for table: {table_name}")
|
||||
return IcebergTableDef(
|
||||
table_name=table_name,
|
||||
schema=schema,
|
||||
partition_spec=TABLE_PARTITIONS[table_name],
|
||||
)
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class IcebergManager:
|
||||
"""Manages Iceberg tables via Trino's Iceberg catalog.
|
||||
|
||||
Provides table creation, existence checks, schema inspection,
|
||||
and metadata operations against the Trino Iceberg connector.
|
||||
"""
|
||||
|
||||
host: str = "localhost"
|
||||
port: int = 8080
|
||||
user: str = "stonks"
|
||||
catalog: str = ICEBERG_CATALOG
|
||||
schema: str = ICEBERG_SCHEMA
|
||||
|
||||
def _get_connection(self) -> Any:
|
||||
"""Create a Trino DBAPI connection."""
|
||||
return trino_connect(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
user=self.user,
|
||||
catalog=self.catalog,
|
||||
schema=self.schema,
|
||||
)
|
||||
|
||||
def _execute(self, sql: str) -> list[list[Any]]:
|
||||
"""Execute a SQL statement and return all rows."""
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(sql)
|
||||
return cursor.fetchall()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _execute_no_fetch(self, sql: str) -> None:
|
||||
"""Execute a DDL statement that returns no rows."""
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(sql)
|
||||
# DDL statements in Trino still need fetchall to complete
|
||||
try:
|
||||
cursor.fetchall()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def ensure_schema(self) -> None:
|
||||
"""Create the Iceberg schema if it doesn't exist."""
|
||||
sql = f"CREATE SCHEMA IF NOT EXISTS {self.catalog}.{self.schema}"
|
||||
logger.info("Ensuring Iceberg schema: %s.%s", self.catalog, self.schema)
|
||||
self._execute_no_fetch(sql)
|
||||
|
||||
def table_exists(self, table_name: str) -> bool:
|
||||
"""Check if an Iceberg table exists."""
|
||||
sql = (
|
||||
f"SELECT table_name FROM {self.catalog}.information_schema.tables "
|
||||
f"WHERE table_schema = '{self.schema}' AND table_name = '{table_name}'"
|
||||
)
|
||||
rows = self._execute(sql)
|
||||
return len(rows) > 0
|
||||
|
||||
def create_table(self, table_name: str) -> bool:
|
||||
"""Create a single Iceberg table if it doesn't exist.
|
||||
|
||||
Returns True if the table was created, False if it already existed.
|
||||
"""
|
||||
table_def = get_table_def(table_name)
|
||||
ddl = table_def.create_table_sql()
|
||||
logger.info("Creating Iceberg table: %s", table_def.qualified_name)
|
||||
self._execute_no_fetch(ddl)
|
||||
logger.info("Iceberg table ready: %s", table_def.qualified_name)
|
||||
return True
|
||||
|
||||
def create_all_tables(self) -> dict[str, bool]:
|
||||
"""Create all registered Iceberg tables.
|
||||
|
||||
Returns a dict mapping table_name -> True (created) or False (error).
|
||||
"""
|
||||
self.ensure_schema()
|
||||
results: dict[str, bool] = {}
|
||||
for table_def in get_all_table_defs():
|
||||
try:
|
||||
self.create_table(table_def.table_name)
|
||||
results[table_def.table_name] = True
|
||||
except Exception:
|
||||
logger.exception("Failed to create Iceberg table: %s", table_def.table_name)
|
||||
results[table_def.table_name] = False
|
||||
return results
|
||||
|
||||
def get_table_schema(self, table_name: str) -> list[dict[str, str]]:
|
||||
"""Retrieve the column schema of an Iceberg table from Trino.
|
||||
|
||||
Returns a list of dicts with 'column_name', 'data_type', and 'is_nullable'.
|
||||
"""
|
||||
sql = (
|
||||
f"SELECT column_name, data_type, is_nullable "
|
||||
f"FROM {self.catalog}.information_schema.columns "
|
||||
f"WHERE table_schema = '{self.schema}' AND table_name = '{table_name}' "
|
||||
f"ORDER BY ordinal_position"
|
||||
)
|
||||
rows = self._execute(sql)
|
||||
return [
|
||||
{"column_name": r[0], "data_type": r[1], "is_nullable": r[2]}
|
||||
for r in rows
|
||||
]
|
||||
|
||||
def get_table_snapshots(self, table_name: str) -> list[dict[str, Any]]:
|
||||
"""List Iceberg snapshots for a table (useful for auditing and rollback).
|
||||
|
||||
Returns snapshot metadata from Trino's $snapshots metadata table.
|
||||
"""
|
||||
qualified = f"{self.catalog}.{self.schema}.{table_name}"
|
||||
sql = f'SELECT * FROM "{qualified}$snapshots"'
|
||||
try:
|
||||
rows = self._execute(sql)
|
||||
return [{"snapshot_id": r[0], "parent_id": r[1], "operation": r[2],
|
||||
"manifest_list": r[3], "summary": r[4]} for r in rows]
|
||||
except Exception:
|
||||
logger.debug("Could not read snapshots for %s (table may be empty)", table_name)
|
||||
return []
|
||||
|
||||
def get_table_partitions(self, table_name: str) -> list[dict[str, Any]]:
|
||||
"""List partition values for an Iceberg table.
|
||||
|
||||
Returns partition metadata from Trino's $partitions metadata table.
|
||||
"""
|
||||
qualified = f"{self.catalog}.{self.schema}.{table_name}"
|
||||
sql = f'SELECT * FROM "{qualified}$partitions"'
|
||||
try:
|
||||
rows = self._execute(sql)
|
||||
return [{"row": r} for r in rows]
|
||||
except Exception:
|
||||
logger.debug("Could not read partitions for %s (table may be empty)", table_name)
|
||||
return []
|
||||
|
||||
def list_tables(self) -> list[str]:
|
||||
"""List all tables in the Iceberg schema."""
|
||||
sql = (
|
||||
f"SELECT table_name FROM {self.catalog}.information_schema.tables "
|
||||
f"WHERE table_schema = '{self.schema}' ORDER BY table_name"
|
||||
)
|
||||
rows = self._execute(sql)
|
||||
return [r[0] for r in rows]
|
||||
|
||||
def drop_table(self, table_name: str) -> None:
|
||||
"""Drop an Iceberg table (for testing/reset purposes)."""
|
||||
qualified = f"{self.catalog}.{self.schema}.{table_name}"
|
||||
logger.warning("Dropping Iceberg table: %s", qualified)
|
||||
self._execute_no_fetch(f"DROP TABLE IF EXISTS {qualified}")
|
||||
|
||||
def sync_table_schema(self, table_name: str) -> list[str]:
|
||||
"""Compare the expected PyArrow schema with the actual Iceberg table schema.
|
||||
|
||||
If columns are missing from the Iceberg table, adds them via ALTER TABLE.
|
||||
Returns a list of columns that were added.
|
||||
|
||||
This supports forward-only schema evolution — columns are never dropped.
|
||||
"""
|
||||
table_def = get_table_def(table_name)
|
||||
existing = self.get_table_schema(table_name)
|
||||
existing_names = {col["column_name"] for col in existing}
|
||||
|
||||
added: list[str] = []
|
||||
qualified = table_def.qualified_name
|
||||
|
||||
for i in range(len(table_def.schema)):
|
||||
col_name = table_def.schema.field(i).name
|
||||
if col_name not in existing_names:
|
||||
trino_type = _arrow_type_to_trino(table_def.schema.field(i).type)
|
||||
alter_sql = f"ALTER TABLE {qualified} ADD COLUMN {col_name} {trino_type}"
|
||||
logger.info("Adding column %s to %s", col_name, qualified)
|
||||
self._execute_no_fetch(alter_sql)
|
||||
added.append(col_name)
|
||||
|
||||
return added
|
||||
|
||||
def sync_all_schemas(self) -> dict[str, list[str]]:
|
||||
"""Sync schemas for all registered tables. Returns table_name -> added columns."""
|
||||
results: dict[str, list[str]] = {}
|
||||
for table_def in get_all_table_defs():
|
||||
try:
|
||||
if self.table_exists(table_def.table_name):
|
||||
added = self.sync_table_schema(table_def.table_name)
|
||||
results[table_def.table_name] = added
|
||||
else:
|
||||
logger.info("Table %s doesn't exist yet, skipping sync", table_def.table_name)
|
||||
results[table_def.table_name] = []
|
||||
except Exception:
|
||||
logger.exception("Failed to sync schema for %s", table_def.table_name)
|
||||
results[table_def.table_name] = []
|
||||
return results
|
||||
|
||||
|
||||
def create_iceberg_manager_from_config(
|
||||
host: str = "localhost",
|
||||
port: int = 8080,
|
||||
user: str = "stonks",
|
||||
) -> IcebergManager:
|
||||
"""Factory that creates an IcebergManager from explicit connection params."""
|
||||
return IcebergManager(host=host, port=port, user=user)
|
||||
@@ -0,0 +1,673 @@
|
||||
"""Lake publisher async job runner — transforms operational data into analytical facts.
|
||||
|
||||
Reads jobs from the QUEUE_LAKE_PUBLISH Redis queue, queries PostgreSQL for
|
||||
operational records, and publishes them as partitioned Parquet files to MinIO
|
||||
via the existing publish_* functions in worker.py.
|
||||
|
||||
Job message format:
|
||||
{"job_type": "<table_name>", "entity_id": "<uuid or ticker>", "dt": "2026-04-11T..."}
|
||||
|
||||
Supported job types:
|
||||
- document: publish a single document metadata fact
|
||||
- document_extraction: publish extraction facts for a document
|
||||
- market_snapshot: publish market bars/quotes from a snapshot
|
||||
- trade_order: publish an order fact
|
||||
- trade_fill: publish fill facts for an order
|
||||
- positions_snapshot: publish daily position snapshots for a broker account
|
||||
- pnl_snapshot: publish daily PnL for a broker account
|
||||
- company_event: publish a company event fact
|
||||
- bulk_documents: publish all unpublished documents since a cutoff
|
||||
- bulk_extractions: publish all unpublished extractions since a cutoff
|
||||
|
||||
Requirements: 9.4, 9.5, 10.1
|
||||
Design ref: Section 4.10 (Lake Publisher), Section 8.4 (Lake publication flow)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import asyncpg
|
||||
import redis.asyncio as aioredis
|
||||
from minio import Minio
|
||||
|
||||
from services.lake_publisher.worker import (
|
||||
publish_document_extraction,
|
||||
publish_document_fact,
|
||||
publish_market_bar,
|
||||
publish_market_quote,
|
||||
publish_trade_order,
|
||||
publish_trade_fill,
|
||||
publish_pnl_daily,
|
||||
publish_documents_batch,
|
||||
publish_document_extractions_batch,
|
||||
publish_positions_daily_batch,
|
||||
)
|
||||
from services.lake_publisher.partitions import partition_values
|
||||
from services.shared.config import load_config
|
||||
from services.shared.db import get_minio, get_pg_pool, get_redis
|
||||
from services.shared.logging import setup_logging
|
||||
from services.shared.redis_keys import QUEUE_LAKE_PUBLISH, queue_key
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SQL queries for fetching operational data
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_FETCH_DOCUMENT = """
|
||||
SELECT
|
||||
d.id, d.document_type, d.source_type, d.publisher, d.title,
|
||||
d.url, d.canonical_url, d.language, d.published_at, d.retrieved_at,
|
||||
d.content_hash, d.parse_quality_score,
|
||||
COALESCE(
|
||||
(SELECT dcm.ticker FROM document_company_mentions dcm
|
||||
WHERE dcm.document_id = d.id LIMIT 1),
|
||||
''
|
||||
) AS ticker
|
||||
FROM documents d
|
||||
WHERE d.id = $1::uuid
|
||||
"""
|
||||
|
||||
_FETCH_EXTRACTIONS = """
|
||||
SELECT
|
||||
di.document_id, dir.ticker, dir.relevance, dir.sentiment,
|
||||
dir.impact_score, dir.impact_horizon, dir.catalyst_type,
|
||||
di.confidence, di.novelty_score, di.source_credibility,
|
||||
dir.key_facts, dir.risks, di.macro_themes,
|
||||
di.model_name, di.prompt_version, di.schema_version,
|
||||
di.created_at AS extraction_at,
|
||||
COALESCE(c.legal_name, '') AS company_name
|
||||
FROM document_intelligence di
|
||||
JOIN document_impact_records dir ON dir.intelligence_id = di.id
|
||||
LEFT JOIN companies c ON c.id = dir.company_id
|
||||
WHERE di.document_id = $1::uuid
|
||||
AND di.validation_status = 'valid'
|
||||
"""
|
||||
|
||||
_FETCH_MARKET_SNAPSHOT = """
|
||||
SELECT
|
||||
ms.ticker, ms.snapshot_type, ms.data, ms.source_provider, ms.captured_at
|
||||
FROM market_snapshots ms
|
||||
WHERE ms.id = $1::uuid
|
||||
"""
|
||||
|
||||
_FETCH_ORDER = """
|
||||
SELECT
|
||||
o.id, o.recommendation_id, o.ticker, o.side, o.order_type,
|
||||
o.quantity, o.limit_price, o.status, o.submitted_at,
|
||||
o.fill_price, o.fill_quantity, o.filled_at,
|
||||
COALESCE(ba.account_id, '') AS broker_account,
|
||||
COALESCE(ba.mode, 'paper') AS execution_mode
|
||||
FROM orders o
|
||||
LEFT JOIN broker_accounts ba ON ba.id = o.broker_account_id
|
||||
WHERE o.id = $1::uuid
|
||||
"""
|
||||
|
||||
_FETCH_ORDER_FILLS = """
|
||||
SELECT
|
||||
oe.id AS fill_id, oe.order_id, oe.data, oe.broker_timestamp,
|
||||
o.ticker, o.side,
|
||||
COALESCE(ba.account_id, '') AS broker_account
|
||||
FROM order_events oe
|
||||
JOIN orders o ON o.id = oe.order_id
|
||||
LEFT JOIN broker_accounts ba ON ba.id = o.broker_account_id
|
||||
WHERE oe.order_id = $1::uuid AND oe.event_type = 'fill'
|
||||
"""
|
||||
|
||||
_FETCH_POSITIONS = """
|
||||
SELECT
|
||||
p.ticker, p.quantity, p.avg_entry_price, p.current_price,
|
||||
p.unrealized_pnl, p.realized_pnl,
|
||||
COALESCE(ba.account_id, '') AS broker_account,
|
||||
COALESCE(ba.mode, 'paper') AS execution_mode
|
||||
FROM positions p
|
||||
LEFT JOIN broker_accounts ba ON ba.id = p.broker_account_id
|
||||
WHERE p.broker_account_id = $1::uuid AND p.quantity != 0
|
||||
"""
|
||||
|
||||
_FETCH_BULK_DOCUMENTS = """
|
||||
SELECT
|
||||
d.id, d.document_type, d.source_type, d.publisher, d.title,
|
||||
d.url, d.canonical_url, d.language, d.published_at, d.retrieved_at,
|
||||
d.content_hash, d.parse_quality_score,
|
||||
COALESCE(
|
||||
(SELECT dcm.ticker FROM document_company_mentions dcm
|
||||
WHERE dcm.document_id = d.id LIMIT 1),
|
||||
''
|
||||
) AS ticker
|
||||
FROM documents d
|
||||
WHERE d.created_at >= $1
|
||||
AND d.status IN ('parsed', 'extracted')
|
||||
ORDER BY d.created_at
|
||||
LIMIT 500
|
||||
"""
|
||||
|
||||
_FETCH_BULK_EXTRACTIONS = """
|
||||
SELECT
|
||||
di.document_id, dir.ticker, dir.relevance, dir.sentiment,
|
||||
dir.impact_score, dir.impact_horizon, dir.catalyst_type,
|
||||
di.confidence, di.novelty_score, di.source_credibility,
|
||||
dir.key_facts, dir.risks, di.macro_themes,
|
||||
di.model_name, di.prompt_version, di.schema_version,
|
||||
di.created_at AS extraction_at,
|
||||
COALESCE(c.legal_name, '') AS company_name
|
||||
FROM document_intelligence di
|
||||
JOIN document_impact_records dir ON dir.intelligence_id = di.id
|
||||
LEFT JOIN companies c ON c.id = dir.company_id
|
||||
WHERE di.created_at >= $1
|
||||
AND di.validation_status = 'valid'
|
||||
ORDER BY di.created_at
|
||||
LIMIT 500
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Job handlers — each transforms operational rows into lake facts
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _jsonb_to_str(val: object) -> str:
|
||||
"""Convert a JSONB column value (list or str) to a comma-separated string."""
|
||||
if val is None:
|
||||
return ""
|
||||
if isinstance(val, str):
|
||||
try:
|
||||
parsed = json.loads(val)
|
||||
if isinstance(parsed, list):
|
||||
return ", ".join(str(x) for x in parsed)
|
||||
return val
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return val
|
||||
if isinstance(val, list):
|
||||
return ", ".join(str(x) for x in val)
|
||||
return str(val)
|
||||
|
||||
|
||||
async def publish_document_job(
|
||||
pool: asyncpg.Pool,
|
||||
minio_client: Minio,
|
||||
entity_id: str,
|
||||
) -> str:
|
||||
"""Publish a single document metadata fact from PostgreSQL to the lake."""
|
||||
row = await pool.fetchrow(_FETCH_DOCUMENT, entity_id)
|
||||
if row is None:
|
||||
logger.warning("Document %s not found, skipping lake publish", entity_id)
|
||||
return ""
|
||||
|
||||
published_at = row["published_at"] or row["retrieved_at"]
|
||||
return publish_document_fact(
|
||||
client=minio_client,
|
||||
document_id=str(row["id"]),
|
||||
document_type=row["document_type"],
|
||||
source_type=row["source_type"],
|
||||
ticker=row["ticker"] or "",
|
||||
publisher=row["publisher"] or "",
|
||||
title=row["title"] or "",
|
||||
published_at=published_at,
|
||||
content_hash=row["content_hash"],
|
||||
url=row["url"] or "",
|
||||
canonical_url=row["canonical_url"] or "",
|
||||
language=row["language"] or "en",
|
||||
confidence=float(row["parse_quality_score"] or 0.0),
|
||||
retrieved_at=row["retrieved_at"],
|
||||
)
|
||||
|
||||
|
||||
async def publish_extraction_job(
|
||||
pool: asyncpg.Pool,
|
||||
minio_client: Minio,
|
||||
entity_id: str,
|
||||
) -> list[str]:
|
||||
"""Publish document extraction facts for a document from PostgreSQL to the lake."""
|
||||
rows = await pool.fetch(_FETCH_EXTRACTIONS, entity_id)
|
||||
if not rows:
|
||||
logger.info("No valid extractions for document %s", entity_id)
|
||||
return []
|
||||
|
||||
refs: list[str] = []
|
||||
for row in rows:
|
||||
ref = publish_document_extraction(
|
||||
client=minio_client,
|
||||
document_id=str(row["document_id"]),
|
||||
ticker=row["ticker"],
|
||||
sentiment=row["sentiment"] or "neutral",
|
||||
impact_score=float(row["impact_score"] or 0.0),
|
||||
catalyst_type=row["catalyst_type"] or "other",
|
||||
confidence=float(row["confidence"] or 0.0),
|
||||
extraction_at=row["extraction_at"],
|
||||
model_name=row["model_name"] or "",
|
||||
prompt_version=row["prompt_version"] or "",
|
||||
company_name=row["company_name"] or "",
|
||||
relevance=float(row["relevance"] or 0.0),
|
||||
impact_horizon=row["impact_horizon"] or "",
|
||||
novelty_score=float(row["novelty_score"] or 0.0),
|
||||
source_credibility=float(row["source_credibility"] or 0.0),
|
||||
key_facts=_jsonb_to_str(row["key_facts"]),
|
||||
risks=_jsonb_to_str(row["risks"]),
|
||||
macro_themes=_jsonb_to_str(row["macro_themes"]),
|
||||
schema_version=row["schema_version"] or "",
|
||||
)
|
||||
refs.append(ref)
|
||||
return refs
|
||||
|
||||
|
||||
async def publish_market_snapshot_job(
|
||||
pool: asyncpg.Pool,
|
||||
minio_client: Minio,
|
||||
entity_id: str,
|
||||
) -> list[str]:
|
||||
"""Publish market bar/quote facts from a market_snapshots row."""
|
||||
row = await pool.fetchrow(_FETCH_MARKET_SNAPSHOT, entity_id)
|
||||
if row is None:
|
||||
logger.warning("Market snapshot %s not found", entity_id)
|
||||
return []
|
||||
|
||||
ticker = row["ticker"]
|
||||
data = row["data"] if isinstance(row["data"], dict) else json.loads(row["data"])
|
||||
source = row["source_provider"] or ""
|
||||
captured_at = row["captured_at"]
|
||||
snapshot_type = row["snapshot_type"]
|
||||
refs: list[str] = []
|
||||
|
||||
if snapshot_type == "bar" or snapshot_type == "bars":
|
||||
# Single bar or list of bars
|
||||
bars = data.get("bars", [data]) if "bars" in data else [data]
|
||||
for bar in bars:
|
||||
ref = publish_market_bar(
|
||||
client=minio_client,
|
||||
ticker=ticker,
|
||||
open_price=float(bar.get("open", bar.get("o", 0))),
|
||||
high_price=float(bar.get("high", bar.get("h", 0))),
|
||||
low_price=float(bar.get("low", bar.get("l", 0))),
|
||||
close_price=float(bar.get("close", bar.get("c", 0))),
|
||||
volume=int(bar.get("volume", bar.get("v", 0))),
|
||||
bar_timestamp=captured_at,
|
||||
source=source,
|
||||
vwap=float(bar.get("vwap", bar.get("vw", 0))),
|
||||
trade_count=int(bar.get("trade_count", bar.get("n", 0))),
|
||||
bar_interval=bar.get("interval", "1d"),
|
||||
)
|
||||
refs.append(ref)
|
||||
elif snapshot_type == "quote" or snapshot_type == "quotes":
|
||||
ref = publish_market_quote(
|
||||
client=minio_client,
|
||||
ticker=ticker,
|
||||
bid_price=float(data.get("bid_price", data.get("bp", 0))),
|
||||
ask_price=float(data.get("ask_price", data.get("ap", 0))),
|
||||
last_price=float(data.get("last_price", data.get("lp", 0))),
|
||||
quote_at=captured_at,
|
||||
source=source,
|
||||
bid_size=int(data.get("bid_size", data.get("bs", 0))),
|
||||
ask_size=int(data.get("ask_size", data.get("as", 0))),
|
||||
last_size=int(data.get("last_size", data.get("ls", 0))),
|
||||
)
|
||||
refs.append(ref)
|
||||
|
||||
return refs
|
||||
|
||||
|
||||
async def publish_order_job(
|
||||
pool: asyncpg.Pool,
|
||||
minio_client: Minio,
|
||||
entity_id: str,
|
||||
) -> str:
|
||||
"""Publish a trade order fact from PostgreSQL to the lake."""
|
||||
row = await pool.fetchrow(_FETCH_ORDER, entity_id)
|
||||
if row is None:
|
||||
logger.warning("Order %s not found", entity_id)
|
||||
return ""
|
||||
|
||||
submitted_at = row["submitted_at"] or datetime.now(timezone.utc)
|
||||
return publish_trade_order(
|
||||
client=minio_client,
|
||||
order_id=str(row["id"]),
|
||||
ticker=row["ticker"],
|
||||
side=row["side"],
|
||||
order_type=row["order_type"],
|
||||
quantity=float(row["quantity"]),
|
||||
limit_price=float(row["limit_price"]) if row["limit_price"] else None,
|
||||
status=row["status"],
|
||||
broker_account=row["broker_account"],
|
||||
submitted_at=submitted_at,
|
||||
recommendation_id=str(row["recommendation_id"]) if row["recommendation_id"] else "",
|
||||
execution_mode=row["execution_mode"],
|
||||
)
|
||||
|
||||
|
||||
async def publish_fills_job(
|
||||
pool: asyncpg.Pool,
|
||||
minio_client: Minio,
|
||||
entity_id: str,
|
||||
) -> list[str]:
|
||||
"""Publish trade fill facts for an order from PostgreSQL to the lake."""
|
||||
rows = await pool.fetch(_FETCH_ORDER_FILLS, entity_id)
|
||||
if not rows:
|
||||
logger.info("No fill events for order %s", entity_id)
|
||||
return []
|
||||
|
||||
refs: list[str] = []
|
||||
for row in rows:
|
||||
data = row["data"] if isinstance(row["data"], dict) else json.loads(row["data"] or "{}")
|
||||
filled_at = row["broker_timestamp"] or datetime.now(timezone.utc)
|
||||
ref = publish_trade_fill(
|
||||
client=minio_client,
|
||||
fill_id=str(row["fill_id"]),
|
||||
order_id=str(row["order_id"]),
|
||||
ticker=row["ticker"],
|
||||
side=row["side"],
|
||||
fill_price=float(data.get("fill_price", data.get("price", 0))),
|
||||
fill_quantity=float(data.get("fill_quantity", data.get("qty", 0))),
|
||||
broker_account=row["broker_account"],
|
||||
filled_at=filled_at,
|
||||
commission=float(data.get("commission", 0)),
|
||||
)
|
||||
refs.append(ref)
|
||||
return refs
|
||||
|
||||
|
||||
async def publish_positions_job(
|
||||
pool: asyncpg.Pool,
|
||||
minio_client: Minio,
|
||||
entity_id: str,
|
||||
) -> str:
|
||||
"""Publish daily position snapshots for a broker account."""
|
||||
rows = await pool.fetch(_FETCH_POSITIONS, entity_id)
|
||||
if not rows:
|
||||
logger.info("No open positions for account %s", entity_id)
|
||||
return ""
|
||||
|
||||
snapshot_at = datetime.now(timezone.utc)
|
||||
positions = [
|
||||
{
|
||||
"ticker": row["ticker"],
|
||||
"quantity": float(row["quantity"]),
|
||||
"avg_entry_price": float(row["avg_entry_price"] or 0),
|
||||
"close_price": float(row["current_price"] or 0),
|
||||
"unrealized_pnl": float(row["unrealized_pnl"] or 0),
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
broker_account = rows[0]["broker_account"] if rows else ""
|
||||
return publish_positions_daily_batch(
|
||||
client=minio_client,
|
||||
positions=positions,
|
||||
broker_account=broker_account,
|
||||
snapshot_at=snapshot_at,
|
||||
)
|
||||
|
||||
|
||||
async def publish_pnl_job(
|
||||
pool: asyncpg.Pool,
|
||||
minio_client: Minio,
|
||||
entity_id: str,
|
||||
) -> list[str]:
|
||||
"""Publish daily PnL facts for a broker account's positions."""
|
||||
rows = await pool.fetch(_FETCH_POSITIONS, entity_id)
|
||||
if not rows:
|
||||
logger.info("No positions for PnL snapshot, account %s", entity_id)
|
||||
return []
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
refs: list[str] = []
|
||||
for row in rows:
|
||||
realized = float(row["realized_pnl"] or 0)
|
||||
unrealized = float(row["unrealized_pnl"] or 0)
|
||||
total = realized + unrealized
|
||||
ref = publish_pnl_daily(
|
||||
client=minio_client,
|
||||
ticker=row["ticker"],
|
||||
realized_pnl=realized,
|
||||
unrealized_pnl=unrealized,
|
||||
total_pnl=total,
|
||||
broker_account=row["broker_account"],
|
||||
dt=now,
|
||||
execution_mode=row["execution_mode"],
|
||||
)
|
||||
refs.append(ref)
|
||||
return refs
|
||||
|
||||
|
||||
async def publish_bulk_documents_job(
|
||||
pool: asyncpg.Pool,
|
||||
minio_client: Minio,
|
||||
since: datetime,
|
||||
) -> list[str]:
|
||||
"""Publish all documents created since a cutoff as a batch."""
|
||||
rows = await pool.fetch(_FETCH_BULK_DOCUMENTS, since)
|
||||
if not rows:
|
||||
logger.info("No documents to bulk-publish since %s", since)
|
||||
return []
|
||||
|
||||
doc_rows: list[dict[str, object]] = []
|
||||
for row in rows:
|
||||
published_at = row["published_at"] or row["retrieved_at"]
|
||||
doc_rows.append({
|
||||
"document_id": str(row["id"]),
|
||||
"document_type": row["document_type"],
|
||||
"source_type": row["source_type"],
|
||||
"ticker": row["ticker"] or "",
|
||||
"publisher": row["publisher"] or "",
|
||||
"title": row["title"] or "",
|
||||
"url": row["url"] or "",
|
||||
"canonical_url": row["canonical_url"] or "",
|
||||
"language": row["language"] or "en",
|
||||
"published_at": published_at,
|
||||
"retrieved_at": row["retrieved_at"],
|
||||
"content_hash": row["content_hash"],
|
||||
"confidence": float(row["parse_quality_score"] or 0.0),
|
||||
**partition_values(published_at),
|
||||
})
|
||||
|
||||
ref = publish_documents_batch(minio_client, doc_rows, since)
|
||||
return [ref] if ref else []
|
||||
|
||||
|
||||
async def publish_bulk_extractions_job(
|
||||
pool: asyncpg.Pool,
|
||||
minio_client: Minio,
|
||||
since: datetime,
|
||||
) -> list[str]:
|
||||
"""Publish all extractions created since a cutoff as a batch."""
|
||||
rows = await pool.fetch(_FETCH_BULK_EXTRACTIONS, since)
|
||||
if not rows:
|
||||
logger.info("No extractions to bulk-publish since %s", since)
|
||||
return []
|
||||
|
||||
extraction_rows: list[dict[str, object]] = []
|
||||
for row in rows:
|
||||
model_ver = row["schema_version"] or row["prompt_version"] or ""
|
||||
extraction_rows.append({
|
||||
"document_id": str(row["document_id"]),
|
||||
"ticker": row["ticker"],
|
||||
"company_name": row["company_name"] or "",
|
||||
"relevance": float(row["relevance"] or 0.0),
|
||||
"sentiment": row["sentiment"] or "neutral",
|
||||
"impact_score": float(row["impact_score"] or 0.0),
|
||||
"impact_horizon": row["impact_horizon"] or "",
|
||||
"catalyst_type": row["catalyst_type"] or "other",
|
||||
"confidence": float(row["confidence"] or 0.0),
|
||||
"novelty_score": float(row["novelty_score"] or 0.0),
|
||||
"source_credibility": float(row["source_credibility"] or 0.0),
|
||||
"key_facts": _jsonb_to_str(row["key_facts"]),
|
||||
"risks": _jsonb_to_str(row["risks"]),
|
||||
"macro_themes": _jsonb_to_str(row["macro_themes"]),
|
||||
"model_name": row["model_name"] or "",
|
||||
"prompt_version": row["prompt_version"] or "",
|
||||
"schema_version": row["schema_version"] or "",
|
||||
"extraction_at": row["extraction_at"],
|
||||
**partition_values(row["extraction_at"], {"model_version": model_ver}),
|
||||
})
|
||||
|
||||
model_ver = extraction_rows[0].get("model_version", "") if extraction_rows else ""
|
||||
ref = publish_document_extractions_batch(
|
||||
minio_client, extraction_rows, since,
|
||||
model_version=str(model_ver),
|
||||
)
|
||||
return [ref] if ref else []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Job dispatcher
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
JOB_TYPES = {
|
||||
"document",
|
||||
"document_extraction",
|
||||
"market_snapshot",
|
||||
"trade_order",
|
||||
"trade_fill",
|
||||
"positions_snapshot",
|
||||
"pnl_snapshot",
|
||||
"company_event",
|
||||
"bulk_documents",
|
||||
"bulk_extractions",
|
||||
}
|
||||
|
||||
|
||||
async def dispatch_job(
|
||||
pool: asyncpg.Pool,
|
||||
minio_client: Minio,
|
||||
job: dict[str, str],
|
||||
) -> dict[str, object]:
|
||||
"""Dispatch a lake publish job to the appropriate handler.
|
||||
|
||||
Args:
|
||||
pool: PostgreSQL connection pool.
|
||||
minio_client: MinIO client for writing Parquet files.
|
||||
job: Job dict with at least 'job_type' and 'entity_id'.
|
||||
|
||||
Returns:
|
||||
A result dict with 'job_type', 'entity_id', 'refs' (list of s3 URIs),
|
||||
and 'error' (None on success).
|
||||
"""
|
||||
job_type = job.get("job_type", "")
|
||||
entity_id = job.get("entity_id", "")
|
||||
since_str = job.get("since")
|
||||
|
||||
result: dict[str, object] = {
|
||||
"job_type": job_type,
|
||||
"entity_id": entity_id,
|
||||
"refs": [],
|
||||
"error": None,
|
||||
}
|
||||
|
||||
try:
|
||||
if job_type == "document":
|
||||
ref = await publish_document_job(pool, minio_client, entity_id)
|
||||
result["refs"] = [ref] if ref else []
|
||||
|
||||
elif job_type == "document_extraction":
|
||||
refs = await publish_extraction_job(pool, minio_client, entity_id)
|
||||
result["refs"] = refs
|
||||
|
||||
elif job_type == "market_snapshot":
|
||||
refs = await publish_market_snapshot_job(pool, minio_client, entity_id)
|
||||
result["refs"] = refs
|
||||
|
||||
elif job_type == "trade_order":
|
||||
ref = await publish_order_job(pool, minio_client, entity_id)
|
||||
result["refs"] = [ref] if ref else []
|
||||
|
||||
elif job_type == "trade_fill":
|
||||
refs = await publish_fills_job(pool, minio_client, entity_id)
|
||||
result["refs"] = refs
|
||||
|
||||
elif job_type == "positions_snapshot":
|
||||
ref = await publish_positions_job(pool, minio_client, entity_id)
|
||||
result["refs"] = [ref] if ref else []
|
||||
|
||||
elif job_type == "pnl_snapshot":
|
||||
refs = await publish_pnl_job(pool, minio_client, entity_id)
|
||||
result["refs"] = refs
|
||||
|
||||
elif job_type == "bulk_documents":
|
||||
since = datetime.fromisoformat(since_str) if since_str else datetime.now(timezone.utc)
|
||||
refs = await publish_bulk_documents_job(pool, minio_client, since)
|
||||
result["refs"] = refs
|
||||
|
||||
elif job_type == "bulk_extractions":
|
||||
since = datetime.fromisoformat(since_str) if since_str else datetime.now(timezone.utc)
|
||||
refs = await publish_bulk_extractions_job(pool, minio_client, since)
|
||||
result["refs"] = refs
|
||||
|
||||
else:
|
||||
result["error"] = f"Unknown job_type: {job_type}"
|
||||
logger.warning("Unknown lake publish job type: %s", job_type)
|
||||
|
||||
except Exception as exc:
|
||||
result["error"] = str(exc)
|
||||
logger.exception("Lake publish job failed: %s/%s", job_type, entity_id)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Async worker loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def run_worker(
|
||||
pool: asyncpg.Pool,
|
||||
rds: aioredis.Redis,
|
||||
minio_client: Minio,
|
||||
poll_interval: float = 2.0,
|
||||
) -> None:
|
||||
"""Main worker loop — reads jobs from Redis and dispatches them.
|
||||
|
||||
Runs indefinitely until cancelled. Each job is processed sequentially
|
||||
to keep MinIO write ordering predictable.
|
||||
"""
|
||||
queue = queue_key(QUEUE_LAKE_PUBLISH)
|
||||
logger.info("Lake publisher worker started, listening on %s", queue)
|
||||
|
||||
while True:
|
||||
raw = await rds.lpop(queue) # type: ignore[misc]
|
||||
if raw is None:
|
||||
await asyncio.sleep(poll_interval)
|
||||
continue
|
||||
|
||||
try:
|
||||
job = json.loads(str(raw))
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.error("Invalid lake publish job payload: %s", raw)
|
||||
continue
|
||||
|
||||
result = await dispatch_job(pool, minio_client, job)
|
||||
refs = result.get("refs") or []
|
||||
error = result.get("error")
|
||||
|
||||
if error:
|
||||
logger.error(
|
||||
"Lake publish job %s/%s failed: %s",
|
||||
result["job_type"], result["entity_id"], error,
|
||||
)
|
||||
else:
|
||||
ref_count = len(refs) if isinstance(refs, list) else 0
|
||||
logger.info(
|
||||
"Lake publish job %s/%s completed: %d facts written",
|
||||
result["job_type"], result["entity_id"], ref_count,
|
||||
)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Entry point for the lake publisher worker process."""
|
||||
config = load_config()
|
||||
pool = await get_pg_pool(config)
|
||||
rds = get_redis(config)
|
||||
minio_client = get_minio(config)
|
||||
|
||||
try:
|
||||
await run_worker(pool, rds, minio_client)
|
||||
finally:
|
||||
await pool.close()
|
||||
await rds.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cfg = load_config()
|
||||
setup_logging("lake_publisher", level=cfg.log_level, json_output=cfg.json_logs)
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,128 @@
|
||||
"""Hive-compatible partition layout conventions for the MinIO lakehouse.
|
||||
|
||||
Centralizes partition path generation, partition column injection, and
|
||||
bucket provisioning so that all lake publisher writers produce layouts
|
||||
that Trino's Hive and Iceberg connectors can discover and prune.
|
||||
|
||||
Design ref: Section 5.2, 5.3 (Lakehouse model)
|
||||
Requirements: 9.4, 9.5, N4, N6
|
||||
|
||||
Layout convention:
|
||||
s3://stonks-lakehouse/warehouse/{table_name}/dt={YYYY-MM-DD}[/{extra_key}={value}]/part-{uuid}.parquet
|
||||
|
||||
Rules:
|
||||
- Every fact table is partitioned by ``dt`` (DATE) derived from the row timestamp.
|
||||
- Some tables have a second partition key (e.g. ``model_version``).
|
||||
- Partition columns MUST appear in the Parquet file so Trino can read them
|
||||
without relying solely on path parsing.
|
||||
- File names use a UUID suffix to avoid collisions on concurrent writes.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import date, datetime, timezone
|
||||
|
||||
|
||||
LAKEHOUSE_BUCKET = "stonks-lakehouse"
|
||||
WAREHOUSE_PREFIX = "warehouse"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PartitionSpec:
|
||||
"""Describes the partition layout for a single fact table."""
|
||||
|
||||
table_name: str
|
||||
extra_keys: tuple[str, ...] = field(default_factory=tuple)
|
||||
|
||||
@property
|
||||
def all_keys(self) -> tuple[str, ...]:
|
||||
"""Return all partition keys in order (dt first, then extras)."""
|
||||
return ("dt", *self.extra_keys)
|
||||
|
||||
|
||||
# Registry of every analytical fact table and its partition keys.
|
||||
# This is the single source of truth — DDL, publisher, and tests should agree.
|
||||
TABLE_PARTITIONS: dict[str, PartitionSpec] = {
|
||||
"market_bars": PartitionSpec("market_bars"),
|
||||
"market_quotes": PartitionSpec("market_quotes"),
|
||||
"company_events": PartitionSpec("company_events"),
|
||||
"documents": PartitionSpec("documents"),
|
||||
"document_extractions": PartitionSpec("document_extractions", extra_keys=("model_version",)),
|
||||
"trade_signals": PartitionSpec("trade_signals"),
|
||||
"trade_orders": PartitionSpec("trade_orders"),
|
||||
"trade_fills": PartitionSpec("trade_fills"),
|
||||
"positions_daily": PartitionSpec("positions_daily"),
|
||||
"pnl_daily": PartitionSpec("pnl_daily"),
|
||||
"prediction_vs_outcome": PartitionSpec("prediction_vs_outcome", extra_keys=("model_version",)),
|
||||
"model_performance": PartitionSpec("model_performance", extra_keys=("model_version",)),
|
||||
}
|
||||
|
||||
|
||||
def partition_path(
|
||||
table_name: str,
|
||||
dt: datetime | date,
|
||||
extra_partitions: dict[str, str] | None = None,
|
||||
file_id: str | None = None,
|
||||
) -> str:
|
||||
"""Build a Hive-compatible object path for a Parquet file.
|
||||
|
||||
Args:
|
||||
table_name: Logical fact table name (must be in TABLE_PARTITIONS).
|
||||
dt: Row timestamp or date used to derive the ``dt=`` partition.
|
||||
extra_partitions: Additional partition key/value pairs (e.g. model_version).
|
||||
file_id: Optional override for the file suffix (defaults to a UUID4).
|
||||
|
||||
Returns:
|
||||
Object key relative to the bucket root, e.g.
|
||||
``warehouse/trade_signals/dt=2026-04-11/part-<uuid>.parquet``
|
||||
"""
|
||||
spec = TABLE_PARTITIONS.get(table_name)
|
||||
if spec is None:
|
||||
raise ValueError(f"Unknown table: {table_name}. Register it in TABLE_PARTITIONS.")
|
||||
|
||||
if isinstance(dt, datetime):
|
||||
dt_str = dt.strftime("%Y-%m-%d")
|
||||
else:
|
||||
dt_str = dt.isoformat()
|
||||
|
||||
segments = [WAREHOUSE_PREFIX, table_name, f"dt={dt_str}"]
|
||||
|
||||
# Append extra partition directories in the order declared by the spec.
|
||||
extras = extra_partitions or {}
|
||||
for key in spec.extra_keys:
|
||||
value = extras.get(key, "__NONE__")
|
||||
segments.append(f"{key}={value}")
|
||||
|
||||
suffix = file_id or uuid.uuid4().hex[:16]
|
||||
segments.append(f"part-{suffix}.parquet")
|
||||
|
||||
return "/".join(segments)
|
||||
|
||||
|
||||
def partition_values(
|
||||
dt: datetime | date,
|
||||
extra_partitions: dict[str, str] | None = None,
|
||||
) -> dict[str, object]:
|
||||
"""Return partition column values to inject into Parquet row data.
|
||||
|
||||
Trino's Hive connector can read partition values from the directory path,
|
||||
but embedding them in the Parquet file as well ensures compatibility with
|
||||
engines that don't parse Hive paths (e.g. plain PyArrow reads, DuckDB).
|
||||
|
||||
Returns a dict like ``{"dt": date(2026, 4, 11), "model_version": "v2"}``.
|
||||
"""
|
||||
if isinstance(dt, datetime):
|
||||
dt_date = dt.date()
|
||||
else:
|
||||
dt_date = dt
|
||||
|
||||
values: dict[str, object] = {"dt": dt_date}
|
||||
if extra_partitions:
|
||||
values.update(extra_partitions)
|
||||
return values
|
||||
|
||||
|
||||
def s3_uri(path: str) -> str:
|
||||
"""Build an s3:// URI from a bucket-relative object path."""
|
||||
return f"s3://{LAKEHOUSE_BUCKET}/{path}"
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,858 @@
|
||||
"""HTML-to-text parsing pipeline using BeautifulSoup.
|
||||
|
||||
Provides structured HTML parsing with boilerplate removal, metadata extraction,
|
||||
outbound link extraction, and quality scoring. Inspired by Noctipede crawler
|
||||
patterns: BeautifulSoup + content hashing, boilerplate stripping, quality scoring.
|
||||
|
||||
Requirements: 4.1, 4.2, 4.3
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from bs4 import BeautifulSoup, Tag
|
||||
|
||||
logger = logging.getLogger("html_parser")
|
||||
|
||||
# Tags that never contain useful article content
|
||||
STRIP_TAGS = [
|
||||
"script", "style", "nav", "footer", "header", "aside",
|
||||
"iframe", "noscript", "svg", "form", "button",
|
||||
]
|
||||
|
||||
# CSS class / id substrings that signal boilerplate containers
|
||||
BOILERPLATE_SIGNALS = [
|
||||
"sidebar", "widget", "advert", "promo", "newsletter",
|
||||
"social-share", "share-bar", "related-posts", "comment",
|
||||
"cookie", "popup", "modal", "banner", "breadcrumb",
|
||||
"pagination", "nav-", "menu", "toolbar", "signup",
|
||||
"subscribe", "follow-us", "social-media", "share-button",
|
||||
"ad-slot", "ad-container", "sponsored",
|
||||
]
|
||||
|
||||
# Regex patterns for residual boilerplate in extracted text
|
||||
BOILERPLATE_TEXT_PATTERNS = [
|
||||
re.compile(r"(?i)subscribe to our newsletter.*?(?:\n|$)"),
|
||||
re.compile(r"(?i)click here to read more.*?(?:\n|$)"),
|
||||
re.compile(r"(?i)advertisement\s*\n?"),
|
||||
re.compile(r"(?i)copyright ©.*?(?:\n|$)"),
|
||||
re.compile(r"(?i)all rights reserved.*?(?:\n|$)"),
|
||||
re.compile(r"(?i)terms of (use|service).*?(?:\n|$)"),
|
||||
re.compile(r"(?i)privacy policy.*?(?:\n|$)"),
|
||||
re.compile(r"\s*\[.*?ad.*?\]\s*", re.IGNORECASE),
|
||||
re.compile(r"(?i)sign up for .*?(?:\n|$)"),
|
||||
re.compile(r"(?i)follow us on .*?(?:\n|$)"),
|
||||
re.compile(r"(?i)share this (article|story|post).*?(?:\n|$)"),
|
||||
re.compile(r"(?i)read more:?\s*$"),
|
||||
re.compile(r"(?i)recommended for you.*?(?:\n|$)"),
|
||||
re.compile(r"(?i)you may also like.*?(?:\n|$)"),
|
||||
re.compile(r"(?i)trending now.*?(?:\n|$)"),
|
||||
re.compile(r"(?i)most (popular|read).*?(?:\n|$)"),
|
||||
re.compile(r"(?i)^tags:\s*$"),
|
||||
re.compile(r"(?i)^\s*photo\s*:.*?(?:\n|$)"),
|
||||
re.compile(r"(?i)^\s*image\s*(credit|source|courtesy)\s*:.*?(?:\n|$)"),
|
||||
]
|
||||
|
||||
# Selectors for article body candidates, in priority order
|
||||
ARTICLE_SELECTORS = [
|
||||
"article",
|
||||
"[role='main']",
|
||||
".article-body",
|
||||
".post-content",
|
||||
".entry-content",
|
||||
".story-body",
|
||||
".article-content",
|
||||
"#article-body",
|
||||
"#story-body",
|
||||
".article-text",
|
||||
".post-body",
|
||||
".content-body",
|
||||
"main",
|
||||
]
|
||||
|
||||
# Minimum text density (text chars / total chars including markup) for a block
|
||||
# to be considered content-rich rather than boilerplate
|
||||
_MIN_TEXT_DENSITY = 0.25
|
||||
|
||||
# Minimum word count for a block to be a viable body candidate
|
||||
_MIN_BLOCK_WORDS = 20
|
||||
|
||||
|
||||
@dataclass
|
||||
class QualitySignals:
|
||||
"""Individual quality signals contributing to the overall parse score.
|
||||
|
||||
Each signal is a float in [0, 1] representing how well the parsed
|
||||
content performs on that dimension.
|
||||
|
||||
Requirements: 4.3
|
||||
"""
|
||||
word_count_signal: float = 0.0
|
||||
diversity_signal: float = 0.0
|
||||
sentence_signal: float = 0.0
|
||||
paragraph_signal: float = 0.0
|
||||
body_found_signal: float = 0.0
|
||||
metadata_signal: float = 0.0
|
||||
|
||||
def as_dict(self) -> dict[str, float]:
|
||||
return {
|
||||
"word_count": self.word_count_signal,
|
||||
"diversity": self.diversity_signal,
|
||||
"sentence": self.sentence_signal,
|
||||
"paragraph": self.paragraph_signal,
|
||||
"body_found": self.body_found_signal,
|
||||
"metadata": self.metadata_signal,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompanyMention:
|
||||
"""A detected company mention in parsed text.
|
||||
|
||||
Requirements: 1.3, 4.1
|
||||
"""
|
||||
company_id: str
|
||||
ticker: str
|
||||
mention_type: str # ticker, legal_name, alias, brand
|
||||
confidence: float
|
||||
match_count: int = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedDocument:
|
||||
"""Result of HTML-to-text parsing pipeline."""
|
||||
body_text: str = ""
|
||||
title: str = ""
|
||||
author: str = ""
|
||||
publisher: str = ""
|
||||
published_at: str | None = None
|
||||
canonical_url: str | None = None
|
||||
language: str = "en"
|
||||
description: str = ""
|
||||
document_type: str = "article"
|
||||
outbound_links: list[str] = field(default_factory=list)
|
||||
tags: list[str] = field(default_factory=list)
|
||||
mentioned_companies: list[CompanyMention] = field(default_factory=list)
|
||||
quality_score: float = 0.0
|
||||
confidence: str = "low"
|
||||
word_count: int = 0
|
||||
quality_signals: QualitySignals = field(default_factory=QualitySignals)
|
||||
low_quality_flag: bool = False
|
||||
quality_warnings: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
def _attr_str(tag: Tag, attr: str) -> str:
|
||||
"""Safely get a tag attribute as a joined string."""
|
||||
val = tag.get(attr, "")
|
||||
if isinstance(val, list):
|
||||
return " ".join(val)
|
||||
return str(val) if val else ""
|
||||
|
||||
|
||||
def _is_boilerplate_container(tag: Tag) -> bool:
|
||||
"""Check if a tag looks like a boilerplate container by class/id."""
|
||||
cls = _attr_str(tag, "class").lower()
|
||||
tag_id = _attr_str(tag, "id").lower()
|
||||
combined = f"{cls} {tag_id}"
|
||||
return any(sig in combined for sig in BOILERPLATE_SIGNALS)
|
||||
|
||||
|
||||
def _strip_boilerplate_tags(soup: BeautifulSoup) -> None:
|
||||
"""Remove known non-content tags and boilerplate containers in-place."""
|
||||
for tag_name in STRIP_TAGS:
|
||||
for tag in soup.find_all(tag_name):
|
||||
tag.decompose()
|
||||
|
||||
for tag in soup.find_all(True):
|
||||
if _is_boilerplate_container(tag):
|
||||
tag.decompose()
|
||||
|
||||
|
||||
def _reduce_boilerplate_text(text: str) -> str:
|
||||
"""Apply regex patterns to strip residual boilerplate from extracted text."""
|
||||
for pattern in BOILERPLATE_TEXT_PATTERNS:
|
||||
text = pattern.sub("", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def _text_density(tag: Tag) -> float:
|
||||
"""Compute text density for a tag: ratio of text length to total markup length.
|
||||
|
||||
Higher density means more actual text relative to HTML structure,
|
||||
which is a strong signal for content blocks vs boilerplate.
|
||||
|
||||
Requirements: 4.2
|
||||
"""
|
||||
markup_len = len(str(tag))
|
||||
if markup_len == 0:
|
||||
return 0.0
|
||||
text_len = len(tag.get_text(strip=True))
|
||||
return text_len / markup_len
|
||||
|
||||
|
||||
def _link_density(tag: Tag) -> float:
|
||||
"""Compute link density: ratio of text inside <a> tags to total text.
|
||||
|
||||
High link density signals navigation/boilerplate blocks (menus, sidebars).
|
||||
Low link density signals content paragraphs.
|
||||
|
||||
Requirements: 4.2
|
||||
"""
|
||||
total_text = len(tag.get_text(strip=True))
|
||||
if total_text == 0:
|
||||
return 1.0
|
||||
link_text = sum(len(a.get_text(strip=True)) for a in tag.find_all("a"))
|
||||
return link_text / total_text
|
||||
|
||||
|
||||
def _block_score(tag: Tag) -> float:
|
||||
"""Score a block element as a body candidate using text density heuristics.
|
||||
|
||||
Combines text density, link density, paragraph count, and word count
|
||||
into a composite score. Higher is more likely to be the article body.
|
||||
|
||||
Requirements: 4.2
|
||||
"""
|
||||
text = tag.get_text(strip=True)
|
||||
word_count = len(text.split())
|
||||
if word_count < _MIN_BLOCK_WORDS:
|
||||
return 0.0
|
||||
|
||||
td = _text_density(tag)
|
||||
ld = _link_density(tag)
|
||||
p_count = len(tag.find_all("p"))
|
||||
|
||||
# Base score from text density (0-1), penalized by link density
|
||||
score = td * (1.0 - ld)
|
||||
|
||||
# Bonus for paragraph-rich blocks (structured article content)
|
||||
if p_count >= 2:
|
||||
score += 0.1 * min(p_count, 10)
|
||||
|
||||
# Bonus for word count (log-scaled to avoid runaway scores)
|
||||
score += 0.05 * math.log(max(word_count, 1))
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def _find_article_body(soup: BeautifulSoup) -> Tag | None:
|
||||
"""Find the most likely article body element.
|
||||
|
||||
First tries semantic selectors (article, [role=main], etc.).
|
||||
If no semantic match, falls back to text-density scoring across
|
||||
candidate block elements to find the content-richest container.
|
||||
|
||||
Requirements: 4.2
|
||||
"""
|
||||
# Priority 1: semantic selectors
|
||||
for selector in ARTICLE_SELECTORS:
|
||||
result = soup.select_one(selector)
|
||||
if result:
|
||||
text = result.get_text(strip=True)
|
||||
if len(text.split()) >= _MIN_BLOCK_WORDS:
|
||||
return result
|
||||
|
||||
# Priority 2: text-density scoring on block-level containers
|
||||
candidates: list[tuple[float, Tag]] = []
|
||||
for tag in soup.find_all(["div", "section", "td"]):
|
||||
score = _block_score(tag)
|
||||
if score > 0:
|
||||
candidates.append((score, tag))
|
||||
|
||||
if candidates:
|
||||
candidates.sort(key=lambda x: x[0], reverse=True)
|
||||
return candidates[0][1]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _collapse_whitespace(text: str) -> str:
|
||||
"""Collapse runs of blank lines into single separators."""
|
||||
lines = [line.strip() for line in text.splitlines()]
|
||||
result: list[str] = []
|
||||
prev_blank = False
|
||||
for line in lines:
|
||||
if not line:
|
||||
if not prev_blank:
|
||||
result.append("")
|
||||
prev_blank = True
|
||||
else:
|
||||
result.append(line)
|
||||
prev_blank = False
|
||||
return "\n".join(result).strip()
|
||||
|
||||
|
||||
def _remove_short_orphan_lines(text: str, min_words: int = 3) -> str:
|
||||
"""Remove very short orphan lines that are likely UI fragments or captions.
|
||||
|
||||
Lines shorter than min_words that don't end with sentence punctuation
|
||||
are stripped. This catches leftover button labels, image captions,
|
||||
and navigation fragments.
|
||||
|
||||
Requirements: 4.2
|
||||
"""
|
||||
lines = text.splitlines()
|
||||
kept: list[str] = []
|
||||
for line in lines:
|
||||
stripped = line.strip()
|
||||
words = stripped.split()
|
||||
if len(words) < min_words and not stripped.endswith((".", "!", "?", ":")):
|
||||
continue
|
||||
kept.append(line)
|
||||
return "\n".join(kept)
|
||||
|
||||
|
||||
def _detect_repeated_blocks(text: str, min_len: int = 40) -> str:
|
||||
"""Remove repeated text blocks that appear more than once.
|
||||
|
||||
Template text (disclaimers, repeated footers) often appears verbatim
|
||||
in multiple places. This strips exact duplicate blocks.
|
||||
|
||||
Requirements: 4.2
|
||||
"""
|
||||
lines = text.splitlines()
|
||||
seen: dict[str, int] = {}
|
||||
for line in lines:
|
||||
stripped = line.strip()
|
||||
if len(stripped) >= min_len:
|
||||
seen[stripped] = seen.get(stripped, 0) + 1
|
||||
|
||||
duplicates = {k for k, v in seen.items() if v > 1}
|
||||
if not duplicates:
|
||||
return text
|
||||
|
||||
kept: list[str] = []
|
||||
emitted: set[str] = set()
|
||||
for line in lines:
|
||||
stripped = line.strip()
|
||||
if stripped in duplicates:
|
||||
if stripped not in emitted:
|
||||
kept.append(line)
|
||||
emitted.add(stripped)
|
||||
# Skip subsequent duplicates
|
||||
else:
|
||||
kept.append(line)
|
||||
return "\n".join(kept)
|
||||
|
||||
|
||||
def extract_body_text(html: str) -> str:
|
||||
"""Extract main body text from HTML with boilerplate removal.
|
||||
|
||||
Pipeline:
|
||||
1. Strip non-content tags (script, style, nav, footer, etc.)
|
||||
2. Strip boilerplate containers by class/id signals
|
||||
3. Find article body via semantic selectors or text-density scoring
|
||||
4. Extract text from best candidate
|
||||
5. Remove residual boilerplate via regex patterns
|
||||
6. Remove short orphan lines (UI fragments)
|
||||
7. Detect and collapse repeated template blocks
|
||||
8. Collapse whitespace
|
||||
|
||||
Requirements: 4.1, 4.2
|
||||
"""
|
||||
soup = BeautifulSoup(html, "html.parser")
|
||||
_strip_boilerplate_tags(soup)
|
||||
|
||||
article = _find_article_body(soup)
|
||||
if article:
|
||||
raw_text = article.get_text(separator="\n", strip=True)
|
||||
else:
|
||||
body = soup.find("body")
|
||||
raw_text = (body or soup).get_text(separator="\n", strip=True)
|
||||
|
||||
# Multi-stage text cleaning
|
||||
text = _reduce_boilerplate_text(raw_text)
|
||||
text = _remove_short_orphan_lines(text)
|
||||
text = _detect_repeated_blocks(text)
|
||||
text = _collapse_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
def extract_metadata(html: str, url: str = "") -> dict[str, str | None]:
|
||||
"""Extract document metadata from HTML head elements.
|
||||
|
||||
Extracts title, author, publisher, published date, canonical URL,
|
||||
language, description, and tags/keywords.
|
||||
|
||||
Requirements: 4.1
|
||||
"""
|
||||
soup = BeautifulSoup(html, "html.parser")
|
||||
meta: dict[str, str | None] = {}
|
||||
|
||||
# Title: og:title > <title>
|
||||
og_title = soup.find("meta", property="og:title")
|
||||
if og_title and og_title.get("content"):
|
||||
content = og_title["content"]
|
||||
meta["title"] = content.strip() if isinstance(content, str) else ""
|
||||
elif soup.title and soup.title.string:
|
||||
meta["title"] = soup.title.string.strip()
|
||||
else:
|
||||
meta["title"] = ""
|
||||
|
||||
# Author
|
||||
author_tag = soup.find("meta", attrs={"name": "author"})
|
||||
if author_tag and author_tag.get("content"):
|
||||
content = author_tag["content"]
|
||||
meta["author"] = content.strip() if isinstance(content, str) else ""
|
||||
else:
|
||||
meta["author"] = ""
|
||||
|
||||
# Publisher: og:site_name > hostname
|
||||
site_name = soup.find("meta", property="og:site_name")
|
||||
if site_name and site_name.get("content"):
|
||||
content = site_name["content"]
|
||||
meta["publisher"] = content.strip() if isinstance(content, str) else ""
|
||||
else:
|
||||
meta["publisher"] = urlparse(url).hostname or "" if url else ""
|
||||
|
||||
# Published date: article:published_time > JSON-LD datePublished
|
||||
pub_time = soup.find("meta", property="article:published_time")
|
||||
if pub_time and pub_time.get("content"):
|
||||
content = pub_time["content"]
|
||||
meta["published_at"] = content.strip() if isinstance(content, str) else None
|
||||
else:
|
||||
meta["published_at"] = _extract_jsonld_date(soup)
|
||||
|
||||
# Canonical URL
|
||||
canonical = soup.find("link", rel="canonical")
|
||||
if canonical and canonical.get("href"):
|
||||
meta["canonical_url"] = str(canonical["href"])
|
||||
else:
|
||||
og_url = soup.find("meta", property="og:url")
|
||||
if og_url and og_url.get("content"):
|
||||
meta["canonical_url"] = str(og_url["content"])
|
||||
else:
|
||||
meta["canonical_url"] = url or None
|
||||
|
||||
# Language
|
||||
html_tag = soup.find("html")
|
||||
if html_tag and html_tag.get("lang"):
|
||||
lang = html_tag["lang"]
|
||||
meta["language"] = str(lang)[:5] if lang else "en"
|
||||
else:
|
||||
meta["language"] = "en"
|
||||
|
||||
# Description
|
||||
desc = soup.find("meta", property="og:description") or soup.find(
|
||||
"meta", attrs={"name": "description"}
|
||||
)
|
||||
if desc and desc.get("content"):
|
||||
content = desc["content"]
|
||||
meta["description"] = content.strip() if isinstance(content, str) else ""
|
||||
else:
|
||||
meta["description"] = ""
|
||||
|
||||
# Tags / keywords
|
||||
keywords = soup.find("meta", attrs={"name": "keywords"})
|
||||
if keywords and keywords.get("content"):
|
||||
content = keywords["content"]
|
||||
raw = content.strip() if isinstance(content, str) else ""
|
||||
meta["tags"] = raw # comma-separated string
|
||||
else:
|
||||
meta["tags"] = ""
|
||||
|
||||
return meta
|
||||
|
||||
|
||||
def _extract_jsonld_date(soup: BeautifulSoup) -> str | None:
|
||||
"""Try to extract datePublished from JSON-LD script tags."""
|
||||
for script in soup.find_all("script", type="application/ld+json"):
|
||||
if script.string and "datePublished" in script.string:
|
||||
try:
|
||||
ld = json.loads(script.string)
|
||||
if isinstance(ld, dict) and "datePublished" in ld:
|
||||
return str(ld["datePublished"])
|
||||
if isinstance(ld, list):
|
||||
for item in ld:
|
||||
if isinstance(item, dict) and "datePublished" in item:
|
||||
return str(item["datePublished"])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def extract_outbound_links(html: str, base_url: str = "") -> list[str]:
|
||||
"""Extract outbound links from HTML, filtering out self-references.
|
||||
|
||||
Requirements: 4.1
|
||||
"""
|
||||
soup = BeautifulSoup(html, "html.parser")
|
||||
base_host = urlparse(base_url).hostname or "" if base_url else ""
|
||||
links: list[str] = []
|
||||
|
||||
for a_tag in soup.find_all("a", href=True):
|
||||
href = str(a_tag["href"]).strip()
|
||||
if not href or href.startswith("#") or href.startswith("javascript:"):
|
||||
continue
|
||||
parsed = urlparse(href)
|
||||
# Only include absolute URLs that point to different hosts
|
||||
if parsed.scheme in ("http", "https") and parsed.hostname:
|
||||
if parsed.hostname != base_host:
|
||||
links.append(href)
|
||||
|
||||
# Dedupe while preserving order
|
||||
seen: set[str] = set()
|
||||
unique: list[str] = []
|
||||
for link in links:
|
||||
if link not in seen:
|
||||
seen.add(link)
|
||||
unique.append(link)
|
||||
return unique
|
||||
|
||||
|
||||
def _count_sentences(text: str) -> int:
|
||||
"""Count approximate sentence count by terminal punctuation."""
|
||||
return len(re.findall(r"[.!?]+(?:\s|$)", text))
|
||||
|
||||
|
||||
def _count_paragraphs(text: str) -> int:
|
||||
"""Count non-empty paragraph blocks separated by blank lines."""
|
||||
blocks = re.split(r"\n\s*\n", text.strip())
|
||||
return sum(1 for b in blocks if len(b.strip().split()) >= 5)
|
||||
|
||||
|
||||
def score_parse_quality(
|
||||
text: str,
|
||||
*,
|
||||
body_found: bool = True,
|
||||
has_title: bool = False,
|
||||
has_author: bool = False,
|
||||
has_publisher: bool = False,
|
||||
has_published_at: bool = False,
|
||||
) -> tuple[float, str, QualitySignals, list[str]]:
|
||||
"""Score parse quality using multiple content and metadata signals.
|
||||
|
||||
Returns (score, confidence_label, signals, warnings).
|
||||
|
||||
Signals considered:
|
||||
- word_count_signal: length of extracted text
|
||||
- diversity_signal: vocabulary richness (unique/total words)
|
||||
- sentence_signal: presence of proper sentence structure
|
||||
- paragraph_signal: multi-paragraph structure
|
||||
- body_found_signal: whether a semantic article body was located
|
||||
- metadata_signal: presence of title, author, publisher, date
|
||||
|
||||
Requirements: 4.3
|
||||
"""
|
||||
warnings: list[str] = []
|
||||
words = text.split()
|
||||
word_count = len(words)
|
||||
|
||||
# --- word count signal ---
|
||||
if word_count < 20:
|
||||
wc_sig = 0.1
|
||||
warnings.append("very_short_text")
|
||||
elif word_count < 50:
|
||||
wc_sig = 0.3
|
||||
warnings.append("short_text")
|
||||
elif word_count < 150:
|
||||
wc_sig = 0.6
|
||||
elif word_count < 300:
|
||||
wc_sig = 0.8
|
||||
else:
|
||||
wc_sig = 1.0
|
||||
|
||||
# --- diversity signal ---
|
||||
if word_count > 0:
|
||||
unique = len(set(w.lower() for w in words))
|
||||
diversity = unique / word_count
|
||||
else:
|
||||
diversity = 0.0
|
||||
if diversity < 0.2:
|
||||
div_sig = 0.2
|
||||
if word_count >= 20:
|
||||
warnings.append("low_vocabulary_diversity")
|
||||
elif diversity < 0.4:
|
||||
div_sig = 0.5
|
||||
else:
|
||||
div_sig = 1.0
|
||||
|
||||
# --- sentence signal ---
|
||||
sentence_count = _count_sentences(text)
|
||||
if sentence_count == 0:
|
||||
sent_sig = 0.1
|
||||
if word_count >= 20:
|
||||
warnings.append("no_sentence_structure")
|
||||
elif sentence_count < 3:
|
||||
sent_sig = 0.5
|
||||
else:
|
||||
sent_sig = 1.0
|
||||
|
||||
# --- paragraph signal ---
|
||||
para_count = _count_paragraphs(text)
|
||||
if para_count == 0:
|
||||
para_sig = 0.2
|
||||
elif para_count == 1:
|
||||
para_sig = 0.5
|
||||
else:
|
||||
para_sig = 1.0
|
||||
|
||||
# --- body found signal ---
|
||||
body_sig = 1.0 if body_found else 0.3
|
||||
if not body_found:
|
||||
warnings.append("no_article_body_found")
|
||||
|
||||
# --- metadata signal ---
|
||||
meta_hits = sum([has_title, has_author, has_publisher, has_published_at])
|
||||
meta_sig = meta_hits / 4.0
|
||||
|
||||
signals = QualitySignals(
|
||||
word_count_signal=wc_sig,
|
||||
diversity_signal=div_sig,
|
||||
sentence_signal=sent_sig,
|
||||
paragraph_signal=para_sig,
|
||||
body_found_signal=body_sig,
|
||||
metadata_signal=meta_sig,
|
||||
)
|
||||
|
||||
# Weighted composite score
|
||||
score = (
|
||||
0.30 * wc_sig
|
||||
+ 0.15 * div_sig
|
||||
+ 0.15 * sent_sig
|
||||
+ 0.10 * para_sig
|
||||
+ 0.20 * body_sig
|
||||
+ 0.10 * meta_sig
|
||||
)
|
||||
score = round(min(score, 0.95), 2)
|
||||
|
||||
# Confidence label
|
||||
if score < 0.35:
|
||||
confidence = "low"
|
||||
elif score < 0.65:
|
||||
confidence = "medium"
|
||||
else:
|
||||
confidence = "high"
|
||||
|
||||
return score, confidence, signals, warnings
|
||||
|
||||
|
||||
def score_quality(text: str) -> tuple[float, str]:
|
||||
"""Score parse quality based on extracted text characteristics.
|
||||
|
||||
Returns (score, confidence_label) where confidence is low/medium/high.
|
||||
Thin wrapper around score_parse_quality for backward compatibility.
|
||||
|
||||
Requirements: 4.3
|
||||
"""
|
||||
score, confidence, _signals, _warnings = score_parse_quality(text)
|
||||
return score, confidence
|
||||
|
||||
|
||||
def infer_document_type(html: str, url: str = "") -> str:
|
||||
"""Infer document type from URL patterns and HTML content.
|
||||
|
||||
Requirements: 4.1
|
||||
"""
|
||||
url_lower = url.lower()
|
||||
if any(kw in url_lower for kw in ["sec.gov", "edgar", "filing", "10-k", "10-q", "8-k"]):
|
||||
return "filing"
|
||||
if any(kw in url_lower for kw in ["transcript", "earnings-call", "earnings_call"]):
|
||||
return "transcript"
|
||||
if any(kw in url_lower for kw in ["press-release", "press_release", "newsroom"]):
|
||||
return "press_release"
|
||||
# html reserved for future content-based inference
|
||||
_ = html
|
||||
return "article"
|
||||
|
||||
|
||||
def parse_html(html: str, url: str = "", aliases: list[dict[str, str]] | None = None) -> ParsedDocument:
|
||||
"""Full HTML-to-text parsing pipeline.
|
||||
|
||||
Combines body extraction, metadata extraction, link extraction,
|
||||
quality scoring, document type inference, and company mention
|
||||
detection into a single result.
|
||||
|
||||
Requirements: 1.3, 4.1, 4.2, 4.3
|
||||
"""
|
||||
soup = BeautifulSoup(html, "html.parser")
|
||||
_strip_boilerplate_tags(soup)
|
||||
|
||||
article = _find_article_body(soup)
|
||||
body_found = article is not None
|
||||
if article:
|
||||
raw_text = article.get_text(separator="\n", strip=True)
|
||||
else:
|
||||
body = soup.find("body")
|
||||
raw_text = (body or soup).get_text(separator="\n", strip=True)
|
||||
|
||||
# Multi-stage text cleaning
|
||||
text = _reduce_boilerplate_text(raw_text)
|
||||
text = _remove_short_orphan_lines(text)
|
||||
text = _detect_repeated_blocks(text)
|
||||
text = _collapse_whitespace(text)
|
||||
|
||||
metadata = extract_metadata(html, url)
|
||||
outbound_links = extract_outbound_links(html, url)
|
||||
doc_type = infer_document_type(html, url)
|
||||
word_count = len(text.split())
|
||||
|
||||
tags_raw = metadata.get("tags", "") or ""
|
||||
tags = [t.strip() for t in tags_raw.split(",") if t.strip()] if tags_raw else []
|
||||
|
||||
# Rich quality scoring with all available signals
|
||||
quality, confidence, signals, warnings = score_parse_quality(
|
||||
text,
|
||||
body_found=body_found,
|
||||
has_title=bool(metadata.get("title")),
|
||||
has_author=bool(metadata.get("author")),
|
||||
has_publisher=bool(metadata.get("publisher")),
|
||||
has_published_at=bool(metadata.get("published_at")),
|
||||
)
|
||||
|
||||
low_quality_flag = confidence == "low"
|
||||
|
||||
# Company mention detection
|
||||
mentioned: list[CompanyMention] = []
|
||||
if aliases and text:
|
||||
# Search title + body for mentions
|
||||
search_text = f"{metadata.get('title', '')} {text}"
|
||||
raw_mentions = detect_company_mentions(search_text, aliases)
|
||||
for m in raw_mentions:
|
||||
mentioned.append(CompanyMention(
|
||||
company_id=str(m["company_id"]),
|
||||
ticker=str(m["ticker"]),
|
||||
mention_type=str(m["mention_type"]),
|
||||
confidence=float(m["confidence"]),
|
||||
match_count=int(m["match_count"]),
|
||||
))
|
||||
|
||||
return ParsedDocument(
|
||||
body_text=text,
|
||||
title=metadata.get("title", "") or "",
|
||||
author=metadata.get("author", "") or "",
|
||||
publisher=metadata.get("publisher", "") or "",
|
||||
published_at=metadata.get("published_at"),
|
||||
canonical_url=metadata.get("canonical_url"),
|
||||
language=metadata.get("language", "en") or "en",
|
||||
description=metadata.get("description", "") or "",
|
||||
document_type=doc_type,
|
||||
outbound_links=outbound_links,
|
||||
tags=tags,
|
||||
mentioned_companies=mentioned,
|
||||
quality_score=quality,
|
||||
confidence=confidence,
|
||||
word_count=word_count,
|
||||
quality_signals=signals,
|
||||
low_quality_flag=low_quality_flag,
|
||||
quality_warnings=warnings,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class AliasEntry:
|
||||
"""A company alias used for mention detection."""
|
||||
company_id: str
|
||||
alias: str
|
||||
alias_type: str = "alias"
|
||||
ticker: str = ""
|
||||
|
||||
|
||||
# Confidence by alias type — tickers are most precise, brands least
|
||||
_CONFIDENCE_BY_TYPE: dict[str, float] = {
|
||||
"ticker": 0.9,
|
||||
"legal_name": 0.85,
|
||||
"alias": 0.7,
|
||||
"brand": 0.6,
|
||||
}
|
||||
|
||||
|
||||
def _build_alias_entries(aliases: list[dict[str, str]]) -> list[AliasEntry]:
|
||||
"""Convert raw alias dicts to typed AliasEntry objects."""
|
||||
entries: list[AliasEntry] = []
|
||||
for a in aliases:
|
||||
alias_val = a.get("alias", "")
|
||||
if not alias_val:
|
||||
continue
|
||||
entries.append(AliasEntry(
|
||||
company_id=a.get("company_id", ""),
|
||||
alias=alias_val,
|
||||
alias_type=a.get("alias_type", "alias"),
|
||||
ticker=a.get("ticker", ""),
|
||||
))
|
||||
return entries
|
||||
|
||||
|
||||
def _count_matches(text: str, pattern: re.Pattern[str]) -> int:
|
||||
"""Count non-overlapping matches of pattern in text."""
|
||||
return len(pattern.findall(text))
|
||||
|
||||
|
||||
def detect_company_mentions(
|
||||
text: str,
|
||||
aliases: list[dict[str, str]],
|
||||
) -> list[dict[str, str | float | int]]:
|
||||
"""Detect company mentions using ticker, alias, and name matching.
|
||||
|
||||
Matching strategy by alias length:
|
||||
- 1-2 chars: case-sensitive word-boundary match (avoids "A" matching "a")
|
||||
- 3-4 chars: case-insensitive word-boundary match (standard tickers)
|
||||
- 5+ chars: case-insensitive substring match (company names, brands)
|
||||
|
||||
Confidence varies by alias_type: ticker > legal_name > alias > brand.
|
||||
Multiple alias hits for the same company are deduplicated, keeping the
|
||||
highest-confidence match and summing match counts.
|
||||
|
||||
Requirements: 1.3, 4.1
|
||||
"""
|
||||
if not text:
|
||||
return []
|
||||
|
||||
entries = _build_alias_entries(aliases)
|
||||
text_upper = text.upper()
|
||||
|
||||
# Track best match per company: company_id -> (confidence, ticker, mention_type, count)
|
||||
best: dict[str, tuple[float, str, str, int]] = {}
|
||||
|
||||
for entry in entries:
|
||||
alias = entry.alias
|
||||
alias_type = entry.alias_type
|
||||
base_confidence = _CONFIDENCE_BY_TYPE.get(alias_type, 0.7)
|
||||
|
||||
match_count = 0
|
||||
|
||||
if len(alias) <= 2:
|
||||
# Very short: case-sensitive word boundary
|
||||
pattern = re.compile(r"\b" + re.escape(alias) + r"\b")
|
||||
match_count = _count_matches(text, pattern)
|
||||
elif len(alias) <= 4:
|
||||
# Standard ticker length: case-insensitive word boundary
|
||||
pattern = re.compile(r"\b" + re.escape(alias.upper()) + r"\b")
|
||||
match_count = _count_matches(text_upper, pattern)
|
||||
else:
|
||||
# Longer names: case-insensitive substring
|
||||
alias_up = alias.upper()
|
||||
match_count = text_upper.count(alias_up)
|
||||
|
||||
if match_count == 0:
|
||||
continue
|
||||
|
||||
cid = entry.company_id
|
||||
existing = best.get(cid)
|
||||
if existing is None:
|
||||
best[cid] = (base_confidence, entry.ticker, alias_type, match_count)
|
||||
else:
|
||||
# Keep highest confidence, accumulate match count
|
||||
prev_conf, prev_ticker, prev_type, prev_count = existing
|
||||
if base_confidence > prev_conf:
|
||||
best[cid] = (base_confidence, entry.ticker, alias_type, prev_count + match_count)
|
||||
else:
|
||||
best[cid] = (prev_conf, prev_ticker, prev_type, prev_count + match_count)
|
||||
|
||||
mentions: list[dict[str, str | float | int]] = []
|
||||
for cid, (confidence, ticker, mention_type, count) in best.items():
|
||||
mentions.append({
|
||||
"company_id": cid,
|
||||
"ticker": ticker,
|
||||
"mention_type": mention_type,
|
||||
"confidence": confidence,
|
||||
"match_count": count,
|
||||
})
|
||||
|
||||
return mentions
|
||||
+108
-107
@@ -1,84 +1,41 @@
|
||||
"""Parser worker - HTML-to-text, boilerplate reduction, quality scoring."""
|
||||
"""Parser worker - HTML-to-text, boilerplate reduction, quality scoring.
|
||||
|
||||
Uses BeautifulSoup-based parsing pipeline for structured HTML extraction,
|
||||
metadata extraction, outbound link extraction, and quality scoring.
|
||||
Persists normalized text and structured parser output to MinIO,
|
||||
and updates document metadata in PostgreSQL.
|
||||
|
||||
Requirements: 4.1, 4.2, 4.3, 9.1, 9.2
|
||||
"""
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Tuple
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
import asyncpg
|
||||
import httpx
|
||||
import redis.asyncio as aioredis
|
||||
from minio import Minio
|
||||
|
||||
from services.parser.html_parser import ParsedDocument, detect_company_mentions, parse_html
|
||||
from services.shared.config import load_config
|
||||
from services.shared.db import get_minio, get_pg_pool, get_redis
|
||||
from services.shared.logging import Span, extract_trace_context, inject_trace_context, new_trace_id, set_trace_context, setup_logging
|
||||
from services.shared.metrics import (
|
||||
ACTIVE_JOBS,
|
||||
PARSE_DURATION,
|
||||
PARSE_JOBS_TOTAL,
|
||||
PARSE_LOW_QUALITY_TOTAL,
|
||||
PARSE_QUALITY_SCORE,
|
||||
)
|
||||
from services.shared.metadata import update_document_parse_results
|
||||
from services.shared.redis_keys import QUEUE_EXTRACTION, QUEUE_PARSING, queue_key
|
||||
from services.shared.storage import upload_normalized_text, upload_parser_output
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("parser_worker")
|
||||
|
||||
# Simple boilerplate patterns to strip
|
||||
BOILERPLATE_PATTERNS = [
|
||||
re.compile(r"(?i)subscribe to our newsletter.*?(?:\n|$)"),
|
||||
re.compile(r"(?i)click here to read more.*?(?:\n|$)"),
|
||||
re.compile(r"(?i)advertisement\s*\n"),
|
||||
re.compile(r"(?i)copyright ©.*?(?:\n|$)"),
|
||||
re.compile(r"(?i)all rights reserved.*?(?:\n|$)"),
|
||||
re.compile(r"(?i)terms of (use|service).*?(?:\n|$)"),
|
||||
re.compile(r"(?i)privacy policy.*?(?:\n|$)"),
|
||||
re.compile(r"\s*\[.*?ad.*?\]\s*", re.IGNORECASE),
|
||||
]
|
||||
|
||||
|
||||
def strip_html_tags(html: str) -> str:
|
||||
"""Basic HTML tag removal."""
|
||||
text = re.sub(r"<script[^>]*>.*?</script>", "", html, flags=re.DOTALL | re.IGNORECASE)
|
||||
text = re.sub(r"<style[^>]*>.*?</style>", "", text, flags=re.DOTALL | re.IGNORECASE)
|
||||
text = re.sub(r"<[^>]+>", " ", text)
|
||||
text = re.sub(r" ", " ", text)
|
||||
text = re.sub(r"&", "&", text)
|
||||
text = re.sub(r"<", "<", text)
|
||||
text = re.sub(r">", ">", text)
|
||||
text = re.sub(r"&#\d+;", "", text)
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
return text
|
||||
|
||||
|
||||
def reduce_boilerplate(text: str) -> str:
|
||||
for pattern in BOILERPLATE_PATTERNS:
|
||||
text = pattern.sub("", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def score_quality(text: str) -> Tuple[float, str]:
|
||||
"""Score parse quality. Returns (score, confidence_label)."""
|
||||
word_count = len(text.split())
|
||||
if word_count < 20:
|
||||
return 0.1, "low"
|
||||
if word_count < 50:
|
||||
return 0.3, "low"
|
||||
if word_count < 150:
|
||||
return 0.6, "medium"
|
||||
return 0.85, "high"
|
||||
|
||||
|
||||
def detect_company_mentions(text: str, aliases: List[dict]) -> List[dict]:
|
||||
"""Detect company mentions using ticker, alias, and name matching."""
|
||||
mentions = []
|
||||
text_upper = text.upper()
|
||||
for alias_info in aliases:
|
||||
alias = alias_info["alias"]
|
||||
if alias.upper() in text_upper:
|
||||
mentions.append({
|
||||
"company_id": alias_info["company_id"],
|
||||
"ticker": alias_info.get("ticker", ""),
|
||||
"mention_type": alias_info.get("alias_type", "alias"),
|
||||
"confidence": 0.7,
|
||||
})
|
||||
return mentions
|
||||
|
||||
|
||||
async def fetch_html(url: str) -> Optional[str]:
|
||||
"""Fetch article HTML for scraping."""
|
||||
@@ -94,48 +51,65 @@ async def fetch_html(url: str) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def build_parser_output_json(parsed: ParsedDocument, mentions: list[dict[str, Any]]) -> dict[str, Any]:
|
||||
"""Build a structured JSON dict from ParsedDocument and detected mentions.
|
||||
|
||||
This captures the full parser output for audit and downstream use:
|
||||
metadata, quality signals, warnings, outbound links, tags, and mentions.
|
||||
"""
|
||||
return {
|
||||
"title": parsed.title,
|
||||
"author": parsed.author,
|
||||
"publisher": parsed.publisher,
|
||||
"published_at": parsed.published_at,
|
||||
"canonical_url": parsed.canonical_url,
|
||||
"language": parsed.language,
|
||||
"description": parsed.description,
|
||||
"document_type": parsed.document_type,
|
||||
"word_count": parsed.word_count,
|
||||
"outbound_links": parsed.outbound_links,
|
||||
"tags": parsed.tags,
|
||||
"quality_score": parsed.quality_score,
|
||||
"confidence": parsed.confidence,
|
||||
"low_quality_flag": parsed.low_quality_flag,
|
||||
"quality_warnings": parsed.quality_warnings,
|
||||
"quality_signals": parsed.quality_signals.as_dict(),
|
||||
"mentioned_companies": mentions,
|
||||
}
|
||||
|
||||
|
||||
async def process_job(
|
||||
job: dict,
|
||||
job: dict[str, Any],
|
||||
pool: asyncpg.Pool,
|
||||
rds: aioredis.Redis,
|
||||
minio_client: Minio,
|
||||
):
|
||||
) -> None:
|
||||
doc_id = job["document_id"]
|
||||
ticker = job["ticker"]
|
||||
url = job.get("url", "")
|
||||
now = datetime.now(timezone.utc)
|
||||
_parse_start = time.monotonic()
|
||||
|
||||
set_trace_context(trace_id=job.get("_trace_id") or new_trace_id())
|
||||
|
||||
# Fetch HTML if we have a URL
|
||||
html = await fetch_html(url) if url else None
|
||||
|
||||
if html:
|
||||
# Store raw HTML
|
||||
html_bytes = html.encode("utf-8")
|
||||
now = datetime.utcnow()
|
||||
html_path = f"scrape/{ticker}/{now.year}/{now.month:02d}/{now.day:02d}/{doc_id}/raw.html"
|
||||
minio_client.put_object(
|
||||
"stonks-raw-news", html_path, io.BytesIO(html_bytes), len(html_bytes),
|
||||
content_type="text/html",
|
||||
)
|
||||
|
||||
# Parse
|
||||
text = strip_html_tags(html)
|
||||
text = reduce_boilerplate(text)
|
||||
# Parse using BeautifulSoup pipeline
|
||||
parsed = parse_html(html, url)
|
||||
else:
|
||||
text = ""
|
||||
parsed = ParsedDocument()
|
||||
|
||||
quality_score, confidence = score_quality(text)
|
||||
text = parsed.body_text
|
||||
|
||||
# Store normalized text
|
||||
# Upload normalized text to MinIO
|
||||
norm_ref: str | None = None
|
||||
if text:
|
||||
text_bytes = text.encode("utf-8")
|
||||
now = datetime.utcnow()
|
||||
norm_path = f"parsed/{ticker}/{now.year}/{now.month:02d}/{now.day:02d}/{doc_id}/normalized.txt"
|
||||
minio_client.put_object(
|
||||
"stonks-normalized", norm_path, io.BytesIO(text_bytes), len(text_bytes),
|
||||
content_type="text/plain",
|
||||
norm_ref = upload_normalized_text(
|
||||
minio_client, ticker, doc_id,
|
||||
text.encode("utf-8"), timestamp=now,
|
||||
)
|
||||
else:
|
||||
norm_path = None
|
||||
|
||||
# Detect company mentions
|
||||
aliases = await pool.fetch(
|
||||
@@ -150,14 +124,24 @@ async def process_job(
|
||||
)
|
||||
mentions = detect_company_mentions(text, [dict(a) for a in aliases]) if text else []
|
||||
|
||||
# Update document
|
||||
status = "parsed" if confidence != "low" else "low_quality"
|
||||
await pool.execute(
|
||||
"""UPDATE documents SET
|
||||
normalized_storage_ref=$2, parse_quality_score=$3, parse_confidence=$4, status=$5, updated_at=NOW()
|
||||
WHERE id=$1""",
|
||||
doc_id, f"s3://stonks-normalized/{norm_path}" if norm_path else None,
|
||||
quality_score, confidence, status,
|
||||
# Build and upload structured parser output JSON
|
||||
output_json = build_parser_output_json(parsed, mentions)
|
||||
output_bytes = json.dumps(output_json, default=str, indent=2).encode("utf-8")
|
||||
parser_output_ref = upload_parser_output(
|
||||
minio_client, ticker, doc_id,
|
||||
output_bytes, timestamp=now,
|
||||
)
|
||||
|
||||
# Update document in PostgreSQL
|
||||
status = "parsed" if parsed.confidence != "low" else "low_quality"
|
||||
await update_document_parse_results(
|
||||
pool,
|
||||
document_id=doc_id,
|
||||
normalized_storage_ref=norm_ref,
|
||||
parser_output_ref=parser_output_ref,
|
||||
parse_quality_score=parsed.quality_score,
|
||||
parse_confidence=parsed.confidence,
|
||||
status=status,
|
||||
)
|
||||
|
||||
# Insert company mentions
|
||||
@@ -169,19 +153,36 @@ async def process_job(
|
||||
)
|
||||
|
||||
# Only enqueue for extraction if quality is acceptable
|
||||
if confidence != "low":
|
||||
await rds.rpush(queue_key(QUEUE_EXTRACTION), json.dumps({
|
||||
if parsed.confidence != "low":
|
||||
await rds.rpush(queue_key(QUEUE_EXTRACTION), json.dumps(inject_trace_context({
|
||||
"document_id": doc_id,
|
||||
"ticker": ticker,
|
||||
"normalized_text": text[:8000], # Truncate for prompt
|
||||
}))
|
||||
logger.info(f"Parsed doc {doc_id} for {ticker}: quality={quality_score:.2f}, confidence={confidence}")
|
||||
"normalized_text": text[:8000],
|
||||
})))
|
||||
PARSE_JOBS_TOTAL.labels(status="parsed").inc()
|
||||
PARSE_QUALITY_SCORE.observe(parsed.quality_score)
|
||||
PARSE_DURATION.observe(time.monotonic() - _parse_start)
|
||||
logger.info(
|
||||
"Parsed doc %s for %s: quality=%.2f, confidence=%s",
|
||||
doc_id, ticker, parsed.quality_score, parsed.confidence,
|
||||
extra={"ticker": ticker, "document_id": doc_id},
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Low quality parse for doc {doc_id}, skipping extraction")
|
||||
PARSE_JOBS_TOTAL.labels(status="low_quality").inc()
|
||||
PARSE_LOW_QUALITY_TOTAL.inc()
|
||||
PARSE_QUALITY_SCORE.observe(parsed.quality_score)
|
||||
PARSE_DURATION.observe(time.monotonic() - _parse_start)
|
||||
logger.warning(
|
||||
"Low quality parse for doc %s, skipping extraction",
|
||||
doc_id,
|
||||
extra={"ticker": ticker, "document_id": doc_id},
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
async def main() -> None:
|
||||
config = load_config()
|
||||
setup_logging("parser_worker", level=config.log_level, json_output=config.json_logs)
|
||||
|
||||
pool = await get_pg_pool(config)
|
||||
rds = get_redis(config)
|
||||
minio_client = get_minio(config)
|
||||
@@ -197,7 +198,7 @@ async def main():
|
||||
try:
|
||||
await process_job(job, pool, rds, minio_client)
|
||||
except Exception as e:
|
||||
logger.error(f"Parse error: {e}")
|
||||
logger.error("Parse error: %s", e, exc_info=True)
|
||||
else:
|
||||
await asyncio.sleep(2)
|
||||
finally:
|
||||
|
||||
@@ -0,0 +1,354 @@
|
||||
"""Deterministic recommendation eligibility logic.
|
||||
|
||||
Evaluates trend summaries against configurable thresholds to decide:
|
||||
- Whether a recommendation should be generated at all
|
||||
- What action type (buy/sell/hold/watch) is appropriate
|
||||
- What execution mode (informational/paper_eligible/live_eligible) is allowed
|
||||
- Position sizing guidance based on portfolio rules
|
||||
|
||||
All decisions are rule-based with no model involvement. The LLM is only
|
||||
used downstream for optional thesis wording (a separate task).
|
||||
|
||||
Requirements: 7.1, 7.2, 7.3, 7.4
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
from services.shared.schemas import (
|
||||
ActionType,
|
||||
PositionSizing,
|
||||
RecommendationMode,
|
||||
TrendDirection,
|
||||
TrendSummary,
|
||||
)
|
||||
|
||||
|
||||
class RejectionReason(str, Enum):
|
||||
"""Why a trend summary was deemed ineligible for a recommendation."""
|
||||
|
||||
LOW_CONFIDENCE = "low_confidence"
|
||||
LOW_TREND_STRENGTH = "low_trend_strength"
|
||||
HIGH_CONTRADICTION = "high_contradiction"
|
||||
INSUFFICIENT_EVIDENCE = "insufficient_evidence"
|
||||
NEUTRAL_DIRECTION = "neutral_direction"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EligibilityConfig:
|
||||
"""Tunable thresholds for recommendation eligibility.
|
||||
|
||||
All thresholds are deterministic — no model inference involved.
|
||||
"""
|
||||
|
||||
# --- Gate thresholds (below these → no recommendation) ---
|
||||
min_confidence: float = 0.35
|
||||
min_trend_strength: float = 0.10
|
||||
max_contradiction_score: float = 0.60
|
||||
min_evidence_count: int = 2 # combined supporting + opposing
|
||||
|
||||
# --- Action mapping thresholds ---
|
||||
# Trend strength above this → buy/sell; below → hold/watch
|
||||
action_strength_threshold: float = 0.25
|
||||
# Confidence above this → hold (rather than watch) for weak signals
|
||||
hold_confidence_threshold: float = 0.50
|
||||
|
||||
# --- Mode escalation thresholds ---
|
||||
# Confidence required for paper_eligible (below → informational)
|
||||
paper_confidence_threshold: float = 0.50
|
||||
# Confidence required for live_eligible (below → paper at most)
|
||||
live_confidence_threshold: float = 0.70
|
||||
# Contradiction must be below this for live eligibility
|
||||
live_max_contradiction: float = 0.25
|
||||
# Minimum evidence count for live eligibility
|
||||
live_min_evidence: int = 5
|
||||
|
||||
# --- Position sizing rules (Requirement 7.3) ---
|
||||
# Base portfolio allocation percentage
|
||||
base_portfolio_pct: float = 0.02
|
||||
# Maximum portfolio allocation percentage
|
||||
max_portfolio_pct: float = 0.05
|
||||
# Base max loss percentage
|
||||
base_max_loss_pct: float = 0.005
|
||||
# Maximum max loss percentage
|
||||
max_max_loss_pct: float = 0.01
|
||||
# Confidence scaling: higher confidence → larger position (linear)
|
||||
confidence_sizing_weight: float = 0.5
|
||||
# Contradiction penalty: higher contradiction → smaller position
|
||||
contradiction_sizing_penalty: float = 0.3
|
||||
|
||||
|
||||
DEFAULT_ELIGIBILITY_CONFIG = EligibilityConfig()
|
||||
|
||||
|
||||
@dataclass
|
||||
class EligibilityResult:
|
||||
"""Output of the deterministic eligibility evaluation.
|
||||
|
||||
Captures the decision, the reasoning, and all inputs used so the
|
||||
full decision trace is reproducible (Requirement 8.3).
|
||||
"""
|
||||
|
||||
eligible: bool
|
||||
action: ActionType
|
||||
mode: RecommendationMode
|
||||
position_sizing: PositionSizing
|
||||
rejection_reasons: list[RejectionReason] = field(default_factory=list)
|
||||
time_horizon: str = ""
|
||||
invalidation_conditions: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gate checks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _check_gates(
|
||||
summary: TrendSummary,
|
||||
config: EligibilityConfig,
|
||||
) -> list[RejectionReason]:
|
||||
"""Apply hard gate checks. Returns a list of rejection reasons (empty = pass)."""
|
||||
reasons: list[RejectionReason] = []
|
||||
|
||||
if summary.confidence < config.min_confidence:
|
||||
reasons.append(RejectionReason.LOW_CONFIDENCE)
|
||||
|
||||
if summary.trend_strength < config.min_trend_strength:
|
||||
reasons.append(RejectionReason.LOW_TREND_STRENGTH)
|
||||
|
||||
if summary.contradiction_score > config.max_contradiction_score:
|
||||
reasons.append(RejectionReason.HIGH_CONTRADICTION)
|
||||
|
||||
evidence_count = len(summary.top_supporting_evidence) + len(summary.top_opposing_evidence)
|
||||
if evidence_count < config.min_evidence_count:
|
||||
reasons.append(RejectionReason.INSUFFICIENT_EVIDENCE)
|
||||
|
||||
if summary.trend_direction == TrendDirection.NEUTRAL:
|
||||
reasons.append(RejectionReason.NEUTRAL_DIRECTION)
|
||||
|
||||
return reasons
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Action mapping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _determine_action(
|
||||
summary: TrendSummary,
|
||||
config: EligibilityConfig,
|
||||
) -> ActionType:
|
||||
"""Map trend direction and strength to an action type.
|
||||
|
||||
Strong bullish → BUY, strong bearish → SELL.
|
||||
Weak but directional → HOLD if confidence is decent, else WATCH.
|
||||
Mixed → WATCH.
|
||||
"""
|
||||
direction = summary.trend_direction
|
||||
strength = summary.trend_strength
|
||||
|
||||
if direction == TrendDirection.MIXED:
|
||||
return ActionType.WATCH
|
||||
|
||||
if direction == TrendDirection.NEUTRAL:
|
||||
return ActionType.WATCH
|
||||
|
||||
strong_signal = strength >= config.action_strength_threshold
|
||||
|
||||
if direction == TrendDirection.BULLISH:
|
||||
if strong_signal:
|
||||
return ActionType.BUY
|
||||
return ActionType.HOLD if summary.confidence >= config.hold_confidence_threshold else ActionType.WATCH
|
||||
|
||||
if direction == TrendDirection.BEARISH:
|
||||
if strong_signal:
|
||||
return ActionType.SELL
|
||||
return ActionType.HOLD if summary.confidence >= config.hold_confidence_threshold else ActionType.WATCH
|
||||
|
||||
return ActionType.WATCH
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mode escalation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _determine_mode(
|
||||
summary: TrendSummary,
|
||||
action: ActionType,
|
||||
config: EligibilityConfig,
|
||||
) -> RecommendationMode:
|
||||
"""Determine the highest execution mode allowed.
|
||||
|
||||
WATCH and HOLD actions are always informational — they don't trigger trades.
|
||||
BUY/SELL can escalate to paper_eligible or live_eligible based on
|
||||
confidence, contradiction, and evidence thresholds.
|
||||
"""
|
||||
if action in (ActionType.WATCH, ActionType.HOLD):
|
||||
return RecommendationMode.INFORMATIONAL
|
||||
|
||||
evidence_count = len(summary.top_supporting_evidence) + len(summary.top_opposing_evidence)
|
||||
|
||||
# Check live eligibility first (strictest)
|
||||
if (
|
||||
summary.confidence >= config.live_confidence_threshold
|
||||
and summary.contradiction_score <= config.live_max_contradiction
|
||||
and evidence_count >= config.live_min_evidence
|
||||
):
|
||||
return RecommendationMode.LIVE_ELIGIBLE
|
||||
|
||||
# Check paper eligibility
|
||||
if summary.confidence >= config.paper_confidence_threshold:
|
||||
return RecommendationMode.PAPER_ELIGIBLE
|
||||
|
||||
return RecommendationMode.INFORMATIONAL
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Position sizing (Requirement 7.3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compute_position_sizing(
|
||||
summary: TrendSummary,
|
||||
config: EligibilityConfig,
|
||||
) -> PositionSizing:
|
||||
"""Compute position sizing guidance from portfolio rules and signal quality.
|
||||
|
||||
Higher confidence → larger allocation (up to max).
|
||||
Higher contradiction → smaller allocation (penalty).
|
||||
"""
|
||||
# Start from base allocation
|
||||
confidence_scale = config.base_portfolio_pct + (
|
||||
config.confidence_sizing_weight
|
||||
* summary.confidence
|
||||
* (config.max_portfolio_pct - config.base_portfolio_pct)
|
||||
)
|
||||
|
||||
# Apply contradiction penalty
|
||||
contradiction_penalty = config.contradiction_sizing_penalty * summary.contradiction_score
|
||||
portfolio_pct = confidence_scale * (1.0 - contradiction_penalty)
|
||||
|
||||
# Clamp to bounds
|
||||
portfolio_pct = max(config.base_portfolio_pct * 0.5, min(portfolio_pct, config.max_portfolio_pct))
|
||||
|
||||
# Max loss scales similarly
|
||||
loss_scale = config.base_max_loss_pct + (
|
||||
config.confidence_sizing_weight
|
||||
* summary.confidence
|
||||
* (config.max_max_loss_pct - config.base_max_loss_pct)
|
||||
)
|
||||
max_loss_pct = loss_scale * (1.0 - contradiction_penalty)
|
||||
max_loss_pct = max(config.base_max_loss_pct * 0.5, min(max_loss_pct, config.max_max_loss_pct))
|
||||
|
||||
return PositionSizing(
|
||||
portfolio_pct=round(portfolio_pct, 6),
|
||||
max_loss_pct=round(max_loss_pct, 6),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Time horizon mapping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_WINDOW_TO_HORIZON: dict[str, str] = {
|
||||
"intraday": "intraday",
|
||||
"1d": "swing_1d_3d",
|
||||
"7d": "swing_1d_10d",
|
||||
"30d": "position_10d_30d",
|
||||
"90d": "position_30d_90d",
|
||||
}
|
||||
|
||||
|
||||
def _map_time_horizon(window: str) -> str:
|
||||
"""Map a trend window to a human-readable time horizon label."""
|
||||
return _WINDOW_TO_HORIZON.get(window, f"window_{window}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Invalidation conditions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _derive_invalidation_conditions(
|
||||
summary: TrendSummary,
|
||||
action: ActionType,
|
||||
) -> list[str]:
|
||||
"""Generate deterministic invalidation conditions for the recommendation.
|
||||
|
||||
These describe when the recommendation should be considered stale or wrong.
|
||||
"""
|
||||
conditions: list[str] = []
|
||||
|
||||
if action == ActionType.BUY:
|
||||
conditions.append(
|
||||
f"Trend direction for {summary.entity_id} reverses to bearish"
|
||||
)
|
||||
elif action == ActionType.SELL:
|
||||
conditions.append(
|
||||
f"Trend direction for {summary.entity_id} reverses to bullish"
|
||||
)
|
||||
|
||||
if summary.contradiction_score > 0.0:
|
||||
conditions.append(
|
||||
f"Contradiction score exceeds 0.60 (currently {summary.contradiction_score:.2f})"
|
||||
)
|
||||
|
||||
if summary.confidence > 0.0:
|
||||
conditions.append(
|
||||
f"Confidence drops below {summary.confidence * 0.7:.2f}"
|
||||
)
|
||||
|
||||
if summary.material_risks:
|
||||
conditions.append(
|
||||
f"Material risk materialises: {summary.material_risks[0]}"
|
||||
)
|
||||
|
||||
return conditions
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def evaluate_eligibility(
|
||||
summary: TrendSummary,
|
||||
config: EligibilityConfig = DEFAULT_ELIGIBILITY_CONFIG,
|
||||
) -> EligibilityResult:
|
||||
"""Evaluate a trend summary for recommendation eligibility.
|
||||
|
||||
This is the single deterministic entry point. It:
|
||||
1. Applies gate checks (confidence, strength, contradiction, evidence)
|
||||
2. Maps trend direction + strength to an action type
|
||||
3. Determines the highest allowed execution mode
|
||||
4. Computes position sizing from portfolio rules
|
||||
5. Derives invalidation conditions
|
||||
|
||||
Returns an EligibilityResult with the full decision trace.
|
||||
"""
|
||||
rejection_reasons = _check_gates(summary, config)
|
||||
|
||||
# Even if rejected, we still compute action/mode for the trace
|
||||
action = _determine_action(summary, config)
|
||||
mode = _determine_mode(summary, action, config)
|
||||
sizing = _compute_position_sizing(summary, config)
|
||||
horizon = _map_time_horizon(summary.window.value)
|
||||
invalidation = _derive_invalidation_conditions(summary, action)
|
||||
|
||||
eligible = len(rejection_reasons) == 0
|
||||
|
||||
# If not eligible, force mode to informational (Requirement 7.4)
|
||||
if not eligible:
|
||||
mode = RecommendationMode.INFORMATIONAL
|
||||
|
||||
return EligibilityResult(
|
||||
eligible=eligible,
|
||||
action=action,
|
||||
mode=mode,
|
||||
position_sizing=sizing,
|
||||
rejection_reasons=rejection_reasons,
|
||||
time_horizon=horizon,
|
||||
invalidation_conditions=invalidation,
|
||||
)
|
||||
@@ -0,0 +1,71 @@
|
||||
"""Recommendation worker entrypoint - polls Redis for recommendation jobs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
|
||||
import asyncpg
|
||||
from minio import Minio
|
||||
|
||||
from services.recommendation.worker import generate_recommendation
|
||||
from services.shared.config import load_config
|
||||
from services.shared.logging import setup_logging
|
||||
from services.shared.redis_keys import QUEUE_RECOMMENDATION, queue_key
|
||||
|
||||
logger = logging.getLogger("recommendation_main")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
config = load_config()
|
||||
setup_logging("recommendation", level=config.log_level, json_output=config.json_logs)
|
||||
|
||||
pool = await asyncpg.create_pool(dsn=config.postgres.dsn, min_size=2, max_size=8)
|
||||
minio_client = Minio(
|
||||
config.minio.endpoint,
|
||||
access_key=config.minio.access_key,
|
||||
secret_key=config.minio.secret_key,
|
||||
secure=config.minio.secure,
|
||||
)
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
redis_client = aioredis.from_url(config.redis.url)
|
||||
queue = queue_key(QUEUE_RECOMMENDATION)
|
||||
logger.info("Recommendation worker started, polling %s", queue)
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw = await redis_client.lpop(queue)
|
||||
if raw is None:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
payload = raw
|
||||
job = json.loads(payload)
|
||||
ticker = job.get("ticker", "")
|
||||
window = job.get("window", "7d")
|
||||
|
||||
logger.info("Processing recommendation job for %s/%s", ticker, window)
|
||||
|
||||
try:
|
||||
rec = await generate_recommendation(
|
||||
pool, ticker, window,
|
||||
minio_client=minio_client,
|
||||
)
|
||||
if rec:
|
||||
logger.info(
|
||||
"Recommendation generated for %s: %s %s",
|
||||
ticker, rec.action.value, rec.mode.value,
|
||||
)
|
||||
else:
|
||||
logger.info("No recommendation generated for %s (no trend data)", ticker)
|
||||
except Exception:
|
||||
logger.exception("Recommendation failed for %s", ticker)
|
||||
finally:
|
||||
await pool.close()
|
||||
await redis_client.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,241 @@
|
||||
"""Suppression logic for low-quality data or low confidence.
|
||||
|
||||
Evaluates the quality of the underlying data feeding a trend summary
|
||||
and suppresses automated trade eligibility when data quality is poor.
|
||||
Suppressed recommendations are marked as informational only.
|
||||
|
||||
This layer runs *before* the eligibility engine and acts as a pre-filter
|
||||
on data quality. The eligibility engine handles signal-level thresholds
|
||||
(confidence, strength, contradiction); this module handles data-level
|
||||
quality concerns (stale evidence, low extraction quality, poor source
|
||||
diversity, insufficient valid documents).
|
||||
|
||||
Requirements: 7.4
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
|
||||
from services.shared.schemas import TrendSummary
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SuppressionReason(str, Enum):
|
||||
"""Why a recommendation was suppressed due to data quality."""
|
||||
|
||||
LOW_DATA_CONFIDENCE = "low_data_confidence"
|
||||
STALE_EVIDENCE = "stale_evidence"
|
||||
LOW_SOURCE_DIVERSITY = "low_source_diversity"
|
||||
HIGH_EXTRACTION_FAILURE_RATE = "high_extraction_failure_rate"
|
||||
INSUFFICIENT_VALID_DOCUMENTS = "insufficient_valid_documents"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SuppressionConfig:
|
||||
"""Tunable thresholds for data quality suppression.
|
||||
|
||||
These thresholds focus on the quality of the *input data* rather
|
||||
than the trend signal itself (which is handled by EligibilityConfig).
|
||||
"""
|
||||
|
||||
# Minimum average extraction confidence across evidence documents.
|
||||
# Below this, the underlying data is too unreliable for trade decisions.
|
||||
min_avg_extraction_confidence: float = 0.40
|
||||
|
||||
# Maximum age (hours) of the most recent evidence document.
|
||||
# If the freshest evidence is older than this, the trend is stale.
|
||||
max_evidence_staleness_hours: float = 168.0 # 7 days
|
||||
|
||||
# Minimum number of distinct source types (e.g. news, filings, market)
|
||||
# represented in the evidence. Low diversity means the signal may be
|
||||
# driven by a single unreliable source class.
|
||||
min_source_types: int = 1
|
||||
|
||||
# Maximum tolerable extraction failure rate (0-1).
|
||||
# If more than this fraction of documents failed extraction,
|
||||
# the data pipeline is unreliable for this ticker.
|
||||
max_extraction_failure_rate: float = 0.50
|
||||
|
||||
# Minimum number of valid (non-failed) documents that contributed
|
||||
# to the trend. Below this, there isn't enough data to act on.
|
||||
min_valid_documents: int = 2
|
||||
|
||||
# Overall data quality confidence threshold.
|
||||
# The computed data quality score must exceed this for the
|
||||
# recommendation to be eligible for automated trading.
|
||||
min_data_quality_score: float = 0.30
|
||||
|
||||
|
||||
DEFAULT_SUPPRESSION_CONFIG = SuppressionConfig()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataQualityContext:
|
||||
"""Quality metrics about the data underlying a trend summary.
|
||||
|
||||
Populated by querying document and extraction metadata for the
|
||||
ticker and window. When not available from the database, callers
|
||||
can construct this from the trend summary itself.
|
||||
"""
|
||||
|
||||
total_documents: int = 0
|
||||
valid_documents: int = 0
|
||||
failed_documents: int = 0
|
||||
avg_extraction_confidence: float = 0.0
|
||||
newest_evidence_at: datetime | None = None
|
||||
source_types: set[str] = field(default_factory=set)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SuppressionResult:
|
||||
"""Output of the suppression evaluation."""
|
||||
|
||||
suppressed: bool
|
||||
reasons: list[SuppressionReason] = field(default_factory=list)
|
||||
data_quality_score: float = 0.0
|
||||
context: DataQualityContext | None = None
|
||||
|
||||
|
||||
def build_quality_context_from_summary(
|
||||
summary: TrendSummary,
|
||||
) -> DataQualityContext:
|
||||
"""Build a minimal DataQualityContext from a TrendSummary.
|
||||
|
||||
This is a fallback when full document-level quality metrics aren't
|
||||
available. It uses the trend summary's evidence counts and confidence
|
||||
as proxies.
|
||||
"""
|
||||
total = len(summary.top_supporting_evidence) + len(summary.top_opposing_evidence)
|
||||
return DataQualityContext(
|
||||
total_documents=total,
|
||||
valid_documents=total,
|
||||
failed_documents=0,
|
||||
avg_extraction_confidence=summary.confidence,
|
||||
newest_evidence_at=summary.generated_at,
|
||||
source_types=set(),
|
||||
)
|
||||
|
||||
|
||||
def _compute_data_quality_score(
|
||||
ctx: DataQualityContext,
|
||||
config: SuppressionConfig,
|
||||
reference_time: datetime,
|
||||
) -> float:
|
||||
"""Compute an overall data quality score from the context.
|
||||
|
||||
Returns a value in [0, 1] where higher is better quality.
|
||||
Components:
|
||||
- Extraction confidence (40% weight)
|
||||
- Evidence freshness (30% weight)
|
||||
- Document coverage (30% weight)
|
||||
"""
|
||||
# Extraction confidence component
|
||||
conf_component = min(ctx.avg_extraction_confidence / 0.8, 1.0)
|
||||
|
||||
# Freshness component
|
||||
if ctx.newest_evidence_at is not None:
|
||||
if ctx.newest_evidence_at.tzinfo is None:
|
||||
newest = ctx.newest_evidence_at.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
newest = ctx.newest_evidence_at
|
||||
age_hours = (reference_time - newest).total_seconds() / 3600.0
|
||||
max_hours = config.max_evidence_staleness_hours
|
||||
freshness_component = max(0.0, 1.0 - (age_hours / max_hours))
|
||||
else:
|
||||
freshness_component = 0.0
|
||||
|
||||
# Document coverage component
|
||||
if ctx.total_documents > 0:
|
||||
valid_ratio = ctx.valid_documents / ctx.total_documents
|
||||
count_factor = min(ctx.valid_documents / 10.0, 1.0)
|
||||
coverage_component = valid_ratio * count_factor
|
||||
else:
|
||||
coverage_component = 0.0
|
||||
|
||||
score = (0.4 * conf_component) + (0.3 * freshness_component) + (0.3 * coverage_component)
|
||||
return round(max(0.0, min(1.0, score)), 4)
|
||||
|
||||
|
||||
def evaluate_suppression(
|
||||
summary: TrendSummary,
|
||||
quality_ctx: DataQualityContext | None = None,
|
||||
config: SuppressionConfig = DEFAULT_SUPPRESSION_CONFIG,
|
||||
reference_time: datetime | None = None,
|
||||
) -> SuppressionResult:
|
||||
"""Evaluate whether a recommendation should be suppressed due to data quality.
|
||||
|
||||
Checks multiple data quality dimensions and returns a SuppressionResult
|
||||
indicating whether the recommendation should be suppressed and why.
|
||||
|
||||
Args:
|
||||
summary: The trend summary to evaluate.
|
||||
quality_ctx: Data quality context. If None, a minimal context is
|
||||
built from the trend summary itself.
|
||||
config: Suppression thresholds.
|
||||
reference_time: Reference time for staleness checks.
|
||||
|
||||
Returns:
|
||||
SuppressionResult with suppression decision and reasons.
|
||||
"""
|
||||
if reference_time is None:
|
||||
reference_time = datetime.now(timezone.utc)
|
||||
|
||||
ctx = quality_ctx or build_quality_context_from_summary(summary)
|
||||
reasons: list[SuppressionReason] = []
|
||||
|
||||
# Check average extraction confidence
|
||||
if ctx.avg_extraction_confidence < config.min_avg_extraction_confidence:
|
||||
reasons.append(SuppressionReason.LOW_DATA_CONFIDENCE)
|
||||
|
||||
# Check evidence staleness
|
||||
if ctx.newest_evidence_at is not None:
|
||||
newest = ctx.newest_evidence_at
|
||||
if newest.tzinfo is None:
|
||||
newest = newest.replace(tzinfo=timezone.utc)
|
||||
age_hours = (reference_time - newest).total_seconds() / 3600.0
|
||||
if age_hours > config.max_evidence_staleness_hours:
|
||||
reasons.append(SuppressionReason.STALE_EVIDENCE)
|
||||
elif ctx.total_documents > 0:
|
||||
# Have documents but no timestamp — treat as stale
|
||||
reasons.append(SuppressionReason.STALE_EVIDENCE)
|
||||
|
||||
# Check source diversity
|
||||
if len(ctx.source_types) < config.min_source_types and ctx.total_documents > 0:
|
||||
reasons.append(SuppressionReason.LOW_SOURCE_DIVERSITY)
|
||||
|
||||
# Check extraction failure rate
|
||||
if ctx.total_documents > 0:
|
||||
failure_rate = ctx.failed_documents / ctx.total_documents
|
||||
if failure_rate > config.max_extraction_failure_rate:
|
||||
reasons.append(SuppressionReason.HIGH_EXTRACTION_FAILURE_RATE)
|
||||
|
||||
# Check minimum valid documents
|
||||
if ctx.valid_documents < config.min_valid_documents:
|
||||
reasons.append(SuppressionReason.INSUFFICIENT_VALID_DOCUMENTS)
|
||||
|
||||
# Compute overall data quality score
|
||||
quality_score = _compute_data_quality_score(ctx, config, reference_time)
|
||||
|
||||
# If quality score is below threshold, add a general suppression reason
|
||||
if quality_score < config.min_data_quality_score and SuppressionReason.LOW_DATA_CONFIDENCE not in reasons:
|
||||
reasons.append(SuppressionReason.LOW_DATA_CONFIDENCE)
|
||||
|
||||
suppressed = len(reasons) > 0
|
||||
|
||||
if suppressed:
|
||||
logger.info(
|
||||
"Recommendation suppressed for %s/%s: reasons=%s quality_score=%.3f",
|
||||
summary.entity_id, summary.window.value,
|
||||
[r.value for r in reasons], quality_score,
|
||||
)
|
||||
|
||||
return SuppressionResult(
|
||||
suppressed=suppressed,
|
||||
reasons=reasons,
|
||||
data_quality_score=quality_score,
|
||||
context=ctx,
|
||||
)
|
||||
@@ -0,0 +1,175 @@
|
||||
"""Optional LLM wording layer for thesis generation.
|
||||
|
||||
Takes a deterministic thesis string (built from trend data and eligibility
|
||||
rules) and rewrites it into natural, analyst-quality prose using a local
|
||||
Ollama model. The deterministic thesis is always preserved as the fallback
|
||||
and audit reference.
|
||||
|
||||
This module is opt-in: callers must explicitly request LLM rewriting.
|
||||
If the LLM call fails or is disabled, the original deterministic thesis
|
||||
is returned unchanged.
|
||||
|
||||
Requirements: 7.1, 7.2
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
import httpx
|
||||
|
||||
from services.shared.config import OllamaConfig
|
||||
from services.shared.schemas import TrendSummary
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
THESIS_PROMPT_VERSION = "thesis-rewrite-v1"
|
||||
|
||||
THESIS_SYSTEM_PROMPT = """\
|
||||
You are a concise financial analyst. You rewrite structured trade thesis \
|
||||
summaries into clear, professional prose suitable for an internal research note.
|
||||
|
||||
STRICT RULES:
|
||||
1. Do NOT add any information that is not present in the input.
|
||||
2. Do NOT fabricate numbers, dates, company names, or analyst opinions.
|
||||
3. Keep the rewrite under 150 words.
|
||||
4. Preserve all factual claims, risk notes, and evidence counts from the input.
|
||||
5. Use a neutral, professional tone. Avoid hype or marketing language.
|
||||
6. Return ONLY the rewritten thesis text. No JSON, no markdown, no commentary."""
|
||||
|
||||
|
||||
def build_thesis_rewrite_prompt(
|
||||
deterministic_thesis: str,
|
||||
summary: TrendSummary,
|
||||
) -> dict[str, str]:
|
||||
"""Build system and user prompts for thesis rewriting.
|
||||
|
||||
Provides the model with the deterministic thesis and key trend
|
||||
context so it can produce a natural-language version.
|
||||
"""
|
||||
context_parts = [
|
||||
f"Ticker: {summary.entity_id}",
|
||||
f"Window: {summary.window.value}",
|
||||
f"Direction: {summary.trend_direction.value}",
|
||||
f"Strength: {summary.trend_strength:.2f}",
|
||||
f"Confidence: {summary.confidence:.2f}",
|
||||
f"Contradiction score: {summary.contradiction_score:.2f}",
|
||||
]
|
||||
if summary.dominant_catalysts:
|
||||
context_parts.append(f"Catalysts: {', '.join(summary.dominant_catalysts[:3])}")
|
||||
if summary.material_risks:
|
||||
context_parts.append(f"Risks: {'; '.join(summary.material_risks[:2])}")
|
||||
|
||||
context_block = "\n".join(context_parts)
|
||||
|
||||
user_prompt = f"""\
|
||||
Rewrite the following structured thesis into clear, professional analyst prose.
|
||||
|
||||
--- STRUCTURED THESIS ---
|
||||
{deterministic_thesis}
|
||||
--- END STRUCTURED THESIS ---
|
||||
|
||||
--- CONTEXT ---
|
||||
{context_block}
|
||||
--- END CONTEXT ---
|
||||
|
||||
Return ONLY the rewritten thesis. No other text."""
|
||||
|
||||
return {
|
||||
"system": THESIS_SYSTEM_PROMPT,
|
||||
"user": user_prompt,
|
||||
}
|
||||
|
||||
|
||||
async def rewrite_thesis_with_llm(
|
||||
deterministic_thesis: str,
|
||||
summary: TrendSummary,
|
||||
config: OllamaConfig,
|
||||
http_client: httpx.AsyncClient | None = None,
|
||||
) -> str:
|
||||
"""Rewrite a deterministic thesis using a local Ollama model.
|
||||
|
||||
If the LLM call fails for any reason, returns the original
|
||||
deterministic thesis unchanged. This ensures the LLM layer is
|
||||
purely additive and never blocks recommendation generation.
|
||||
|
||||
Args:
|
||||
deterministic_thesis: The rule-based thesis string.
|
||||
summary: The trend summary that produced the thesis.
|
||||
config: Ollama connection and model configuration.
|
||||
http_client: Optional shared HTTP client for connection reuse.
|
||||
|
||||
Returns:
|
||||
The LLM-rewritten thesis on success, or the original on failure.
|
||||
"""
|
||||
prompts = build_thesis_rewrite_prompt(deterministic_thesis, summary)
|
||||
|
||||
owns_client = http_client is None
|
||||
client = http_client or httpx.AsyncClient(timeout=config.timeout)
|
||||
|
||||
try:
|
||||
rewritten = await _call_ollama_thesis(client, config, prompts)
|
||||
if rewritten:
|
||||
logger.info(
|
||||
"LLM thesis rewrite succeeded for %s (%d chars → %d chars)",
|
||||
summary.entity_id,
|
||||
len(deterministic_thesis),
|
||||
len(rewritten),
|
||||
)
|
||||
return rewritten
|
||||
|
||||
logger.warning(
|
||||
"LLM thesis rewrite returned empty for %s — using deterministic thesis",
|
||||
summary.entity_id,
|
||||
)
|
||||
return deterministic_thesis
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"LLM thesis rewrite failed for %s — using deterministic thesis",
|
||||
summary.entity_id,
|
||||
)
|
||||
return deterministic_thesis
|
||||
finally:
|
||||
if owns_client:
|
||||
await client.aclose()
|
||||
|
||||
|
||||
async def _call_ollama_thesis(
|
||||
client: httpx.AsyncClient,
|
||||
config: OllamaConfig,
|
||||
prompts: dict[str, str],
|
||||
) -> str:
|
||||
"""Make a single Ollama chat call for thesis rewriting.
|
||||
|
||||
Returns the model's text response, or empty string on failure.
|
||||
"""
|
||||
start = time.monotonic()
|
||||
|
||||
payload = {
|
||||
"model": config.model,
|
||||
"messages": [
|
||||
{"role": "system", "content": prompts["system"]},
|
||||
{"role": "user", "content": prompts["user"]},
|
||||
],
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
resp = await client.post(
|
||||
f"{config.base_url}/api/chat",
|
||||
json=payload,
|
||||
)
|
||||
_ = resp.raise_for_status()
|
||||
|
||||
duration_ms = int((time.monotonic() - start) * 1000)
|
||||
|
||||
body: dict[str, object] = resp.json()
|
||||
msg = body.get("message")
|
||||
content: str = msg.get("content", "") if isinstance(msg, dict) else ""
|
||||
|
||||
logger.debug(
|
||||
"Ollama thesis call completed in %dms, response length=%d",
|
||||
duration_ms,
|
||||
len(content),
|
||||
)
|
||||
|
||||
return content.strip()
|
||||
@@ -1 +1,721 @@
|
||||
"""Recommendation worker - generates explainable trade recommendations from trend data."""
|
||||
"""Recommendation worker - generates explainable trade recommendations from trend data.
|
||||
|
||||
Fetches the latest trend summaries for a ticker, evaluates eligibility
|
||||
using deterministic rules, builds Recommendation objects with thesis
|
||||
and evidence citations, and persists them to PostgreSQL.
|
||||
|
||||
Requirements: 7.1, 7.2, 7.3, 7.4
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import asyncpg
|
||||
|
||||
from services.recommendation.eligibility import (
|
||||
EligibilityConfig,
|
||||
EligibilityResult,
|
||||
evaluate_eligibility,
|
||||
)
|
||||
from services.recommendation.suppression import (
|
||||
DataQualityContext,
|
||||
SuppressionConfig,
|
||||
SuppressionResult,
|
||||
evaluate_suppression,
|
||||
)
|
||||
from services.recommendation.thesis_llm import (
|
||||
THESIS_PROMPT_VERSION,
|
||||
rewrite_thesis_with_llm,
|
||||
)
|
||||
from minio import Minio
|
||||
|
||||
from services.lake_publisher.worker import publish_recommendation_facts
|
||||
from services.shared.config import OllamaConfig
|
||||
from services.shared.schemas import (
|
||||
ModelMetadata,
|
||||
PositionSizing,
|
||||
Recommendation,
|
||||
RecommendationMode,
|
||||
TrendDirection,
|
||||
TrendSummary,
|
||||
TrendWindow,
|
||||
)
|
||||
from services.shared.metrics import (
|
||||
RECOMMENDATION_CONFIDENCE,
|
||||
RECOMMENDATION_GENERATED,
|
||||
RECOMMENDATION_SUPPRESSED,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fetch latest trend summary for a ticker + window
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fetch data quality context for suppression checks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_DATA_QUALITY_QUERY = """
|
||||
SELECT
|
||||
COUNT(*) AS total_documents,
|
||||
COUNT(*) FILTER (WHERE di.validation_status = 'valid') AS valid_documents,
|
||||
COUNT(*) FILTER (WHERE di.validation_status = 'failed') AS failed_documents,
|
||||
AVG(di.confidence) FILTER (WHERE di.validation_status = 'valid') AS avg_extraction_confidence,
|
||||
MAX(d.published_at) AS newest_evidence_at,
|
||||
ARRAY_AGG(DISTINCT s.source_class) FILTER (WHERE s.source_class IS NOT NULL) AS source_types
|
||||
FROM documents d
|
||||
JOIN document_intelligence di ON di.document_id = d.id
|
||||
LEFT JOIN sources s ON d.source_id = s.id
|
||||
WHERE d.id = ANY(
|
||||
SELECT UNNEST(
|
||||
COALESCE(tw.top_supporting_evidence, '[]'::jsonb)
|
||||
|| COALESCE(tw.top_opposing_evidence, '[]'::jsonb)
|
||||
)::uuid
|
||||
FROM trend_windows tw
|
||||
WHERE tw.entity_id = $1 AND tw.window = $2
|
||||
ORDER BY tw.generated_at DESC
|
||||
LIMIT 1
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
async def fetch_data_quality_context(
|
||||
pool: asyncpg.Pool,
|
||||
ticker: str,
|
||||
window: str,
|
||||
) -> DataQualityContext | None:
|
||||
"""Fetch data quality metrics for the documents underlying a trend.
|
||||
|
||||
Returns None if the query fails or returns no data, in which case
|
||||
the suppression module will fall back to summary-based estimation.
|
||||
"""
|
||||
try:
|
||||
row = await pool.fetchrow(_DATA_QUALITY_QUERY, ticker, window)
|
||||
if row is None or row["total_documents"] == 0:
|
||||
return None
|
||||
|
||||
source_types_raw = row["source_types"]
|
||||
source_types = set(source_types_raw) if source_types_raw else set()
|
||||
|
||||
return DataQualityContext(
|
||||
total_documents=int(row["total_documents"]),
|
||||
valid_documents=int(row["valid_documents"] or 0),
|
||||
failed_documents=int(row["failed_documents"] or 0),
|
||||
avg_extraction_confidence=float(row["avg_extraction_confidence"] or 0.0),
|
||||
newest_evidence_at=row["newest_evidence_at"],
|
||||
source_types=source_types,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to fetch data quality context for %s/%s — will use summary fallback",
|
||||
ticker, window, exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
_LATEST_TREND_QUERY = """
|
||||
SELECT
|
||||
entity_type, entity_id, window, trend_direction, trend_strength,
|
||||
confidence, top_supporting_evidence, top_opposing_evidence,
|
||||
dominant_catalysts, material_risks, contradiction_score,
|
||||
disagreement_details, market_context, generated_at
|
||||
FROM trend_windows
|
||||
WHERE entity_id = $1 AND window = $2
|
||||
ORDER BY generated_at DESC
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
|
||||
def _parse_trend_row(row: asyncpg.Record) -> TrendSummary:
|
||||
"""Convert a trend_windows row into a TrendSummary."""
|
||||
supporting = row["top_supporting_evidence"]
|
||||
if isinstance(supporting, str):
|
||||
supporting = json.loads(supporting)
|
||||
|
||||
opposing = row["top_opposing_evidence"]
|
||||
if isinstance(opposing, str):
|
||||
opposing = json.loads(opposing)
|
||||
|
||||
catalysts = row["dominant_catalysts"]
|
||||
if isinstance(catalysts, str):
|
||||
catalysts = json.loads(catalysts)
|
||||
|
||||
risks = row["material_risks"]
|
||||
if isinstance(risks, str):
|
||||
risks = json.loads(risks)
|
||||
|
||||
return TrendSummary(
|
||||
entity_type=row["entity_type"],
|
||||
entity_id=row["entity_id"],
|
||||
window=TrendWindow(row["window"]),
|
||||
trend_direction=TrendDirection(row["trend_direction"]),
|
||||
trend_strength=float(row["trend_strength"]),
|
||||
confidence=float(row["confidence"]),
|
||||
top_supporting_evidence=supporting or [],
|
||||
top_opposing_evidence=opposing or [],
|
||||
dominant_catalysts=catalysts or [],
|
||||
material_risks=risks or [],
|
||||
contradiction_score=float(row["contradiction_score"] or 0.0),
|
||||
generated_at=row["generated_at"],
|
||||
)
|
||||
|
||||
|
||||
async def fetch_latest_trend(
|
||||
pool: asyncpg.Pool,
|
||||
ticker: str,
|
||||
window: str,
|
||||
) -> TrendSummary | None:
|
||||
"""Fetch the most recent trend summary for a ticker and window."""
|
||||
row = await pool.fetchrow(_LATEST_TREND_QUERY, ticker, window)
|
||||
if row is None:
|
||||
return None
|
||||
return _parse_trend_row(row)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build thesis from trend summary (deterministic, no LLM)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def build_thesis(
|
||||
summary: TrendSummary,
|
||||
result: EligibilityResult,
|
||||
) -> str:
|
||||
"""Generate a deterministic thesis string from trend data.
|
||||
|
||||
This is the descriptive analysis portion (Requirement 7.2).
|
||||
The LLM wording layer is a separate optional task.
|
||||
"""
|
||||
direction = summary.trend_direction.value
|
||||
ticker = summary.entity_id
|
||||
window = summary.window.value
|
||||
strength = summary.trend_strength
|
||||
confidence = summary.confidence
|
||||
|
||||
parts: list[str] = []
|
||||
|
||||
# Opening: direction and strength
|
||||
parts.append(
|
||||
f"{ticker} shows a {direction} trend over the {window} window "
|
||||
+ f"with strength {strength:.2f} and confidence {confidence:.2f}."
|
||||
)
|
||||
|
||||
# Catalysts
|
||||
if summary.dominant_catalysts:
|
||||
catalyst_str = ", ".join(summary.dominant_catalysts[:3])
|
||||
parts.append(f"Dominant catalysts: {catalyst_str}.")
|
||||
|
||||
# Contradiction note (Requirement 7.2 — separate descriptive from prescriptive)
|
||||
if summary.contradiction_score > 0.15:
|
||||
parts.append(
|
||||
"Notable signal disagreement detected "
|
||||
+ f"(contradiction score: {summary.contradiction_score:.2f})."
|
||||
)
|
||||
|
||||
# Risks
|
||||
if summary.material_risks:
|
||||
risk_str = "; ".join(summary.material_risks[:2])
|
||||
parts.append(f"Key risks: {risk_str}.")
|
||||
|
||||
# Evidence count
|
||||
supporting_count = len(summary.top_supporting_evidence)
|
||||
opposing_count = len(summary.top_opposing_evidence)
|
||||
parts.append(
|
||||
f"Based on {supporting_count} supporting and "
|
||||
+ f"{opposing_count} opposing evidence documents."
|
||||
)
|
||||
|
||||
# Prescriptive action (separated per Requirement 7.2)
|
||||
action = result.action.value.upper()
|
||||
mode = result.mode.value.replace("_", " ")
|
||||
parts.append(f"Recommendation: {action} ({mode}).")
|
||||
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build risk classification (Requirement 7.2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def classify_risk(
|
||||
summary: TrendSummary,
|
||||
result: EligibilityResult,
|
||||
) -> str:
|
||||
"""Assign a risk classification label based on signal quality.
|
||||
|
||||
Returns one of: low, moderate, high, very_high.
|
||||
"""
|
||||
score = 0.0
|
||||
|
||||
# Contradiction raises risk
|
||||
score += summary.contradiction_score * 2.0
|
||||
|
||||
# Low confidence raises risk
|
||||
score += (1.0 - summary.confidence) * 1.5
|
||||
|
||||
# Low evidence count raises risk
|
||||
evidence_count = len(summary.top_supporting_evidence) + len(summary.top_opposing_evidence)
|
||||
if evidence_count < 3:
|
||||
score += 1.0
|
||||
elif evidence_count < 5:
|
||||
score += 0.5
|
||||
|
||||
# Rejection reasons raise risk
|
||||
score += len(result.rejection_reasons) * 0.5
|
||||
|
||||
if score >= 3.0:
|
||||
return "very_high"
|
||||
if score >= 2.0:
|
||||
return "high"
|
||||
if score >= 1.0:
|
||||
return "moderate"
|
||||
return "low"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build Recommendation from eligibility result
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def build_recommendation(
|
||||
summary: TrendSummary,
|
||||
result: EligibilityResult,
|
||||
reference_time: datetime | None = None,
|
||||
llm_thesis: str | None = None,
|
||||
suppression_result: SuppressionResult | None = None,
|
||||
) -> Recommendation:
|
||||
"""Assemble a Recommendation object from a trend summary and eligibility result.
|
||||
|
||||
Combines all evidence refs (supporting + opposing) into the recommendation
|
||||
so the full decision trace is available (Requirement 8.3).
|
||||
|
||||
If ``llm_thesis`` is provided (from the optional LLM wording layer),
|
||||
it replaces the deterministic thesis text while preserving the risk
|
||||
classification prefix.
|
||||
|
||||
If ``suppression_result`` indicates suppression, a suppression note
|
||||
is appended to the thesis for audit visibility (Requirement 7.4).
|
||||
"""
|
||||
if reference_time is None:
|
||||
reference_time = datetime.now(timezone.utc)
|
||||
|
||||
# Combine evidence refs — supporting first, then opposing
|
||||
evidence_refs = list(summary.top_supporting_evidence) + list(summary.top_opposing_evidence)
|
||||
|
||||
deterministic_thesis = build_thesis(summary, result)
|
||||
risk_class = classify_risk(summary, result)
|
||||
|
||||
# Use LLM-rewritten thesis if available, otherwise deterministic
|
||||
thesis_body = llm_thesis if llm_thesis else deterministic_thesis
|
||||
|
||||
# Append suppression note if suppressed (Requirement 7.4)
|
||||
if suppression_result and suppression_result.suppressed:
|
||||
reason_strs = [r.value for r in suppression_result.reasons]
|
||||
thesis_body += (
|
||||
f" [SUPPRESSED: data quality below threshold "
|
||||
f"(score={suppression_result.data_quality_score:.2f}, "
|
||||
f"reasons={', '.join(reason_strs)})]"
|
||||
)
|
||||
|
||||
# Track whether the thesis was LLM-generated for audit
|
||||
if llm_thesis:
|
||||
provider = "ollama"
|
||||
model_name = "thesis-rewrite"
|
||||
prompt_version = THESIS_PROMPT_VERSION
|
||||
else:
|
||||
provider = "deterministic"
|
||||
model_name = "eligibility-v1"
|
||||
prompt_version = ""
|
||||
|
||||
return Recommendation(
|
||||
ticker=summary.entity_id,
|
||||
action=result.action,
|
||||
mode=result.mode,
|
||||
confidence=summary.confidence,
|
||||
time_horizon=result.time_horizon,
|
||||
thesis=f"[risk:{risk_class}] {thesis_body}",
|
||||
invalidation_conditions=result.invalidation_conditions,
|
||||
position_sizing=PositionSizing(
|
||||
portfolio_pct=result.position_sizing.portfolio_pct,
|
||||
max_loss_pct=result.position_sizing.max_loss_pct,
|
||||
),
|
||||
evidence_refs=evidence_refs,
|
||||
model_metadata=ModelMetadata(
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
prompt_version=prompt_version,
|
||||
schema_version="1.0.0",
|
||||
),
|
||||
generated_at=reference_time,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Persist recommendation to PostgreSQL
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_INSERT_RECOMMENDATION = """
|
||||
INSERT INTO recommendations (
|
||||
ticker, action, mode, confidence, time_horizon,
|
||||
thesis, invalidation_conditions, portfolio_pct, max_loss_pct,
|
||||
model_version, model_provider, prompt_version, schema_version,
|
||||
risk_classification, generated_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5,
|
||||
$6, $7::jsonb, $8, $9,
|
||||
$10, $11, $12, $13,
|
||||
$14, $15
|
||||
)
|
||||
RETURNING id
|
||||
"""
|
||||
|
||||
_INSERT_REC_EVIDENCE = """
|
||||
INSERT INTO recommendation_evidence (
|
||||
recommendation_id, document_id, evidence_type, weight
|
||||
) VALUES ($1, $2::uuid, $3, $4)
|
||||
"""
|
||||
|
||||
_INSERT_RISK_EVALUATION = """
|
||||
INSERT INTO risk_evaluations (
|
||||
recommendation_id, eligible, allowed_mode, rejection_reasons, risk_checks, evaluated_at
|
||||
) VALUES ($1::uuid, $2, $3, $4::jsonb, $5::jsonb, $6)
|
||||
"""
|
||||
|
||||
_FETCH_RECOMMENDATION = """
|
||||
SELECT
|
||||
id, ticker, action, mode, confidence, time_horizon,
|
||||
thesis, invalidation_conditions, portfolio_pct, max_loss_pct,
|
||||
model_version, model_provider, prompt_version, schema_version,
|
||||
risk_classification, generated_at
|
||||
FROM recommendations
|
||||
WHERE id = $1::uuid
|
||||
"""
|
||||
|
||||
_FETCH_REC_EVIDENCE = """
|
||||
SELECT document_id, evidence_type, weight
|
||||
FROM recommendation_evidence
|
||||
WHERE recommendation_id = $1::uuid
|
||||
ORDER BY evidence_type, weight DESC
|
||||
"""
|
||||
|
||||
_FETCH_LATEST_RECS_FOR_TICKER = """
|
||||
SELECT
|
||||
id, ticker, action, mode, confidence, time_horizon,
|
||||
thesis, invalidation_conditions, portfolio_pct, max_loss_pct,
|
||||
model_version, model_provider, prompt_version, schema_version,
|
||||
risk_classification, generated_at
|
||||
FROM recommendations
|
||||
WHERE ticker = $1
|
||||
ORDER BY generated_at DESC
|
||||
LIMIT $2
|
||||
"""
|
||||
|
||||
|
||||
def _extract_risk_classification(thesis: str) -> str:
|
||||
"""Extract the risk classification from the thesis prefix."""
|
||||
if thesis.startswith("[risk:"):
|
||||
end = thesis.find("]")
|
||||
if end > 6:
|
||||
return thesis[6:end]
|
||||
return "moderate"
|
||||
|
||||
|
||||
async def persist_recommendation(
|
||||
pool: asyncpg.Pool,
|
||||
rec: Recommendation,
|
||||
supporting_ids: list[str],
|
||||
opposing_ids: list[str],
|
||||
eligibility_result: EligibilityResult | None = None,
|
||||
) -> str:
|
||||
"""Insert a recommendation, evidence citations, and risk evaluation.
|
||||
|
||||
Persists the full model metadata and risk classification for audit
|
||||
trail (Requirement 8.3). Also writes the eligibility decision to
|
||||
the risk_evaluations table when provided.
|
||||
|
||||
Returns the recommendation UUID.
|
||||
"""
|
||||
risk_class = _extract_risk_classification(rec.thesis)
|
||||
|
||||
row = await pool.fetchrow(
|
||||
_INSERT_RECOMMENDATION,
|
||||
rec.ticker,
|
||||
rec.action.value,
|
||||
rec.mode.value,
|
||||
rec.confidence,
|
||||
rec.time_horizon,
|
||||
rec.thesis,
|
||||
json.dumps(rec.invalidation_conditions),
|
||||
rec.position_sizing.portfolio_pct,
|
||||
rec.position_sizing.max_loss_pct,
|
||||
rec.model_metadata.model_name,
|
||||
rec.model_metadata.provider,
|
||||
rec.model_metadata.prompt_version,
|
||||
rec.model_metadata.schema_version,
|
||||
risk_class,
|
||||
rec.generated_at,
|
||||
)
|
||||
rec_id = str(row["id"])
|
||||
|
||||
# Insert evidence citations with position-based weighting
|
||||
evidence_rows: list[tuple[str, str, str, float]] = []
|
||||
for idx, doc_id in enumerate(supporting_ids):
|
||||
weight = round(1.0 / (1.0 + idx * 0.1), 4) # rank decay
|
||||
evidence_rows.append((rec_id, doc_id, "supporting", weight))
|
||||
for idx, doc_id in enumerate(opposing_ids):
|
||||
weight = round(1.0 / (1.0 + idx * 0.1), 4)
|
||||
evidence_rows.append((rec_id, doc_id, "opposing", weight))
|
||||
|
||||
if evidence_rows:
|
||||
await pool.executemany(_INSERT_REC_EVIDENCE, evidence_rows)
|
||||
|
||||
# Persist the eligibility/risk evaluation for audit trail
|
||||
if eligibility_result is not None:
|
||||
rejection_reasons_json = json.dumps(
|
||||
[r.value for r in eligibility_result.rejection_reasons]
|
||||
)
|
||||
risk_checks = {
|
||||
"time_horizon": eligibility_result.time_horizon,
|
||||
"position_sizing": {
|
||||
"portfolio_pct": eligibility_result.position_sizing.portfolio_pct,
|
||||
"max_loss_pct": eligibility_result.position_sizing.max_loss_pct,
|
||||
},
|
||||
"invalidation_conditions": eligibility_result.invalidation_conditions,
|
||||
"risk_classification": risk_class,
|
||||
}
|
||||
await pool.execute(
|
||||
_INSERT_RISK_EVALUATION,
|
||||
rec_id,
|
||||
eligibility_result.eligible,
|
||||
eligibility_result.mode.value,
|
||||
rejection_reasons_json,
|
||||
json.dumps(risk_checks),
|
||||
rec.generated_at,
|
||||
)
|
||||
|
||||
return rec_id
|
||||
|
||||
|
||||
async def fetch_recommendation_by_id(
|
||||
pool: asyncpg.Pool,
|
||||
recommendation_id: str,
|
||||
) -> dict[str, object] | None:
|
||||
"""Fetch a persisted recommendation with its evidence citations.
|
||||
|
||||
Returns a dict with the recommendation fields and an 'evidence' list,
|
||||
or None if not found.
|
||||
"""
|
||||
row = await pool.fetchrow(_FETCH_RECOMMENDATION, recommendation_id)
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
rec_dict = dict(row)
|
||||
# Parse JSONB fields
|
||||
if isinstance(rec_dict.get("invalidation_conditions"), str):
|
||||
rec_dict["invalidation_conditions"] = json.loads(rec_dict["invalidation_conditions"])
|
||||
|
||||
# Fetch evidence
|
||||
evidence_rows = await pool.fetch(_FETCH_REC_EVIDENCE, recommendation_id)
|
||||
rec_dict["evidence"] = [
|
||||
{
|
||||
"document_id": str(e["document_id"]),
|
||||
"evidence_type": e["evidence_type"],
|
||||
"weight": float(e["weight"]),
|
||||
}
|
||||
for e in evidence_rows
|
||||
]
|
||||
|
||||
return rec_dict
|
||||
|
||||
|
||||
async def fetch_latest_recommendations(
|
||||
pool: asyncpg.Pool,
|
||||
ticker: str,
|
||||
limit: int = 10,
|
||||
) -> list[dict[str, object]]:
|
||||
"""Fetch the most recent recommendations for a ticker.
|
||||
|
||||
Returns a list of recommendation dicts (without evidence — use
|
||||
fetch_recommendation_by_id for full detail).
|
||||
"""
|
||||
rows = await pool.fetch(_FETCH_LATEST_RECS_FOR_TICKER, ticker, limit)
|
||||
results = []
|
||||
for row in rows:
|
||||
rec_dict = dict(row)
|
||||
if isinstance(rec_dict.get("invalidation_conditions"), str):
|
||||
rec_dict["invalidation_conditions"] = json.loads(rec_dict["invalidation_conditions"])
|
||||
results.append(rec_dict)
|
||||
return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main entry point: generate recommendation for a ticker
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def generate_recommendation(
|
||||
pool: asyncpg.Pool,
|
||||
ticker: str,
|
||||
window: str = TrendWindow.SEVEN_DAY.value,
|
||||
config: EligibilityConfig | None = None,
|
||||
reference_time: datetime | None = None,
|
||||
ollama_config: OllamaConfig | None = None,
|
||||
suppression_config: SuppressionConfig | None = None,
|
||||
minio_client: Minio | None = None,
|
||||
) -> Recommendation | None:
|
||||
"""Generate and persist a recommendation for a ticker from its latest trend.
|
||||
|
||||
Steps:
|
||||
1. Fetch the latest trend summary for the ticker + window.
|
||||
2. Evaluate data quality suppression (Requirement 7.4).
|
||||
3. Evaluate eligibility using deterministic rules.
|
||||
4. Build a Recommendation object with thesis and evidence.
|
||||
- If ``ollama_config`` is provided, the deterministic thesis is
|
||||
rewritten into analyst-quality prose via the LLM wording layer.
|
||||
5. Persist the recommendation and evidence citations.
|
||||
|
||||
Returns the Recommendation, or None if no trend data exists.
|
||||
"""
|
||||
if reference_time is None:
|
||||
reference_time = datetime.now(timezone.utc)
|
||||
|
||||
cfg = config or EligibilityConfig()
|
||||
sup_cfg = suppression_config or SuppressionConfig()
|
||||
|
||||
# 1. Fetch latest trend
|
||||
summary = await fetch_latest_trend(pool, ticker, window)
|
||||
if summary is None:
|
||||
logger.info("No trend data for %s/%s — skipping recommendation", ticker, window)
|
||||
return None
|
||||
|
||||
# 2. Evaluate data quality suppression (Requirement 7.4)
|
||||
quality_ctx = await fetch_data_quality_context(pool, ticker, window)
|
||||
suppression = evaluate_suppression(
|
||||
summary, quality_ctx=quality_ctx, config=sup_cfg, reference_time=reference_time,
|
||||
)
|
||||
|
||||
# 3. Evaluate eligibility
|
||||
result = evaluate_eligibility(summary, cfg)
|
||||
|
||||
# Apply suppression: force mode to informational if suppressed
|
||||
if suppression.suppressed:
|
||||
result = EligibilityResult(
|
||||
eligible=False,
|
||||
action=result.action,
|
||||
mode=RecommendationMode.INFORMATIONAL,
|
||||
position_sizing=result.position_sizing,
|
||||
rejection_reasons=result.rejection_reasons,
|
||||
time_horizon=result.time_horizon,
|
||||
invalidation_conditions=result.invalidation_conditions,
|
||||
)
|
||||
|
||||
# 4. Optional LLM thesis rewrite
|
||||
llm_thesis: str | None = None
|
||||
if ollama_config is not None:
|
||||
deterministic_thesis = build_thesis(summary, result)
|
||||
llm_thesis = await rewrite_thesis_with_llm(
|
||||
deterministic_thesis=deterministic_thesis,
|
||||
summary=summary,
|
||||
config=ollama_config,
|
||||
)
|
||||
# If the LLM returned the same text as the deterministic thesis,
|
||||
# treat it as a no-op (fallback was used).
|
||||
if llm_thesis == deterministic_thesis:
|
||||
llm_thesis = None
|
||||
|
||||
# 5. Build recommendation
|
||||
rec = build_recommendation(
|
||||
summary, result, reference_time, llm_thesis=llm_thesis,
|
||||
suppression_result=suppression,
|
||||
)
|
||||
|
||||
# 6. Persist recommendation, evidence citations, and risk evaluation
|
||||
rec_id = await persist_recommendation(
|
||||
pool,
|
||||
rec,
|
||||
supporting_ids=list(summary.top_supporting_evidence),
|
||||
opposing_ids=list(summary.top_opposing_evidence),
|
||||
eligibility_result=result,
|
||||
)
|
||||
|
||||
# 7. Publish prediction facts to analytical tables (Requirement 9.4)
|
||||
if minio_client is not None:
|
||||
try:
|
||||
lake_refs = publish_recommendation_facts(
|
||||
minio_client,
|
||||
rec,
|
||||
trend_direction=summary.trend_direction.value,
|
||||
trend_strength=summary.trend_strength,
|
||||
)
|
||||
logger.info(
|
||||
"Published analytical facts for %s: %s",
|
||||
ticker, lake_refs,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to publish analytical facts for %s/%s — recommendation "
|
||||
"persisted but lake publication failed",
|
||||
ticker, rec_id, exc_info=True,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Generated recommendation %s for %s: action=%s mode=%s confidence=%.3f "
|
||||
"eligible=%s suppressed=%s quality_score=%.3f llm_thesis=%s",
|
||||
rec_id, ticker, rec.action.value, rec.mode.value, rec.confidence,
|
||||
result.eligible, suppression.suppressed, suppression.data_quality_score,
|
||||
llm_thesis is not None,
|
||||
)
|
||||
|
||||
# Prometheus metrics
|
||||
RECOMMENDATION_GENERATED.labels(action=rec.action.value, mode=rec.mode.value).inc()
|
||||
RECOMMENDATION_CONFIDENCE.observe(rec.confidence)
|
||||
if suppression.suppressed:
|
||||
RECOMMENDATION_SUPPRESSED.inc()
|
||||
|
||||
return rec
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Batch: generate recommendations for multiple tickers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def generate_recommendations_batch(
|
||||
pool: asyncpg.Pool,
|
||||
tickers: list[str],
|
||||
window: str = TrendWindow.SEVEN_DAY.value,
|
||||
config: EligibilityConfig | None = None,
|
||||
ollama_config: OllamaConfig | None = None,
|
||||
suppression_config: SuppressionConfig | None = None,
|
||||
minio_client: Minio | None = None,
|
||||
) -> list[Recommendation]:
|
||||
"""Generate recommendations for a list of tickers.
|
||||
|
||||
Processes each ticker sequentially. Returns only the successfully
|
||||
generated recommendations (tickers with no trend data are skipped).
|
||||
|
||||
If ``ollama_config`` is provided, each recommendation's thesis will
|
||||
be rewritten using the LLM wording layer.
|
||||
"""
|
||||
results: list[Recommendation] = []
|
||||
reference_time = datetime.now(timezone.utc)
|
||||
|
||||
for ticker in tickers:
|
||||
rec = await generate_recommendation(
|
||||
pool, ticker, window, config, reference_time,
|
||||
ollama_config=ollama_config,
|
||||
suppression_config=suppression_config,
|
||||
minio_client=minio_client,
|
||||
)
|
||||
if rec is not None:
|
||||
results.append(rec)
|
||||
|
||||
logger.info(
|
||||
"Batch recommendation: %d/%d tickers produced recommendations",
|
||||
len(results), len(tickers),
|
||||
)
|
||||
return results
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
"""Risk Engine API - FastAPI application for order risk evaluation and approval workflow."""
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import asyncpg
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from services.risk.approval import (
|
||||
expire_stale_approvals,
|
||||
get_approval_by_id,
|
||||
get_pending_approvals,
|
||||
review_approval,
|
||||
)
|
||||
from services.risk.engine import (
|
||||
AccountRiskState,
|
||||
PortfolioRiskConfig,
|
||||
ProposedOrder,
|
||||
RiskEvaluation,
|
||||
evaluate_order,
|
||||
)
|
||||
from services.shared.config import load_config
|
||||
from services.shared.logging import setup_logging
|
||||
|
||||
config = load_config()
|
||||
pool: asyncpg.Pool | None = None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global pool
|
||||
setup_logging("risk_engine", level=config.log_level, json_output=config.json_logs)
|
||||
pool = await asyncpg.create_pool(dsn=config.postgres.dsn, min_size=2, max_size=8)
|
||||
yield
|
||||
if pool:
|
||||
await pool.close()
|
||||
|
||||
|
||||
app = FastAPI(title="Stonks Oracle - Risk Engine", lifespan=lifespan)
|
||||
|
||||
|
||||
class EvaluateRequest(BaseModel):
|
||||
order: ProposedOrder
|
||||
config: PortfolioRiskConfig | None = None
|
||||
state: AccountRiskState | None = None
|
||||
|
||||
|
||||
@app.post("/evaluate", response_model=RiskEvaluation)
|
||||
async def evaluate(req: EvaluateRequest) -> RiskEvaluation:
|
||||
risk_config = req.config or PortfolioRiskConfig()
|
||||
return evaluate_order(req.order, risk_config, req.state)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
class ReviewRequest(BaseModel):
|
||||
approved: bool
|
||||
reviewed_by: str = "operator"
|
||||
review_note: str = ""
|
||||
|
||||
|
||||
@app.get("/approvals/pending")
|
||||
async def list_pending():
|
||||
if not pool:
|
||||
raise HTTPException(503, "Database not ready")
|
||||
requests = await get_pending_approvals(pool)
|
||||
return [r.to_dict() for r in requests]
|
||||
|
||||
|
||||
@app.get("/approvals/{approval_id}")
|
||||
async def get_approval(approval_id: str):
|
||||
if not pool:
|
||||
raise HTTPException(503, "Database not ready")
|
||||
req = await get_approval_by_id(pool, approval_id)
|
||||
if not req:
|
||||
raise HTTPException(404, "Approval not found")
|
||||
return req.to_dict()
|
||||
|
||||
|
||||
@app.post("/approvals/{approval_id}/review")
|
||||
async def review(approval_id: str, body: ReviewRequest):
|
||||
if not pool:
|
||||
raise HTTPException(503, "Database not ready")
|
||||
status = await review_approval(
|
||||
pool, approval_id, body.approved, body.reviewed_by, body.review_note,
|
||||
)
|
||||
if status is None:
|
||||
raise HTTPException(404, "Approval not found or no longer pending")
|
||||
return {"approval_id": approval_id, "status": status.value}
|
||||
|
||||
|
||||
@app.post("/approvals/expire")
|
||||
async def expire():
|
||||
if not pool:
|
||||
raise HTTPException(503, "Database not ready")
|
||||
expired = await expire_stale_approvals(pool)
|
||||
return {"expired": expired}
|
||||
@@ -0,0 +1,300 @@
|
||||
"""Operator approval workflow for live trading mode.
|
||||
|
||||
When live trading is enabled and operator approval is required,
|
||||
orders are held in a pending state until an operator explicitly
|
||||
approves or rejects them. Expired approvals are treated as rejections.
|
||||
|
||||
Requirements: 8.2
|
||||
Design: Section 4.8 - Risk Engine (operator approval rules)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import asyncpg
|
||||
|
||||
from services.risk.engine import (
|
||||
OperatorApproval,
|
||||
PortfolioRiskConfig,
|
||||
TradingMode,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("operator_approval")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Enums
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ApprovalStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
REJECTED = "rejected"
|
||||
EXPIRED = "expired"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core logic: does this order need approval?
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def requires_approval(
|
||||
config: PortfolioRiskConfig,
|
||||
trading_mode: TradingMode | None = None,
|
||||
) -> bool:
|
||||
"""Determine whether an order requires operator approval.
|
||||
|
||||
Paper orders are auto-approved when auto_approve_paper is True.
|
||||
Live orders require approval when require_approval_for_live is True.
|
||||
Disabled mode always returns False (orders are blocked upstream).
|
||||
"""
|
||||
mode = trading_mode or config.trading_mode
|
||||
|
||||
if mode == TradingMode.DISABLED:
|
||||
return False
|
||||
|
||||
if mode == TradingMode.PAPER:
|
||||
return not config.operator_approval.auto_approve_paper
|
||||
|
||||
# Live mode
|
||||
return config.operator_approval.require_approval_for_live
|
||||
|
||||
|
||||
def compute_expiry(
|
||||
config: PortfolioRiskConfig,
|
||||
now: datetime | None = None,
|
||||
) -> datetime:
|
||||
"""Compute the expiry timestamp for a new approval request."""
|
||||
now = now or datetime.now(timezone.utc)
|
||||
return now + timedelta(minutes=config.operator_approval.approval_timeout_minutes)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Approval request model (in-memory representation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ApprovalRequest:
|
||||
"""Represents a pending operator approval request."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
approval_id: str | None = None,
|
||||
order_job: dict[str, Any] | None = None,
|
||||
recommendation_id: str | None = None,
|
||||
ticker: str = "",
|
||||
side: str = "buy",
|
||||
quantity: float = 0.0,
|
||||
estimated_value: float = 0.0,
|
||||
risk_evaluation_id: str | None = None,
|
||||
status: ApprovalStatus = ApprovalStatus.PENDING,
|
||||
requested_by: str = "system",
|
||||
reviewed_by: str | None = None,
|
||||
review_note: str | None = None,
|
||||
expires_at: datetime | None = None,
|
||||
requested_at: datetime | None = None,
|
||||
reviewed_at: datetime | None = None,
|
||||
) -> None:
|
||||
self.approval_id = approval_id or str(uuid.uuid4())
|
||||
self.order_job = order_job or {}
|
||||
self.recommendation_id = recommendation_id
|
||||
self.ticker = ticker
|
||||
self.side = side
|
||||
self.quantity = quantity
|
||||
self.estimated_value = estimated_value
|
||||
self.risk_evaluation_id = risk_evaluation_id
|
||||
self.status = status
|
||||
self.requested_by = requested_by
|
||||
self.reviewed_by = reviewed_by
|
||||
self.review_note = review_note
|
||||
self.expires_at = expires_at or (datetime.now(timezone.utc) + timedelta(minutes=30))
|
||||
self.requested_at = requested_at or datetime.now(timezone.utc)
|
||||
self.reviewed_at = reviewed_at
|
||||
|
||||
@property
|
||||
def is_pending(self) -> bool:
|
||||
return self.status == ApprovalStatus.PENDING
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
if self.status == ApprovalStatus.EXPIRED:
|
||||
return True
|
||||
if self.status == ApprovalStatus.PENDING:
|
||||
return datetime.now(timezone.utc) >= self.expires_at
|
||||
return False
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"approval_id": self.approval_id,
|
||||
"recommendation_id": self.recommendation_id,
|
||||
"ticker": self.ticker,
|
||||
"side": self.side,
|
||||
"quantity": self.quantity,
|
||||
"estimated_value": self.estimated_value,
|
||||
"risk_evaluation_id": self.risk_evaluation_id,
|
||||
"status": self.status.value,
|
||||
"requested_by": self.requested_by,
|
||||
"reviewed_by": self.reviewed_by,
|
||||
"review_note": self.review_note,
|
||||
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||
"requested_at": self.requested_at.isoformat() if self.requested_at else None,
|
||||
"reviewed_at": self.reviewed_at.isoformat() if self.reviewed_at else None,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DB persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_INSERT_APPROVAL = """
|
||||
INSERT INTO operator_approvals (
|
||||
id, order_job, recommendation_id, ticker, side, quantity,
|
||||
estimated_value, status, risk_evaluation_id, requested_by,
|
||||
expires_at, requested_at
|
||||
) VALUES (
|
||||
$1::uuid, $2::jsonb, $3, $4, $5, $6,
|
||||
$7, $8, $9, $10,
|
||||
$11, $12
|
||||
)
|
||||
"""
|
||||
|
||||
_UPDATE_APPROVAL_STATUS = """
|
||||
UPDATE operator_approvals
|
||||
SET status = $2, reviewed_by = $3, review_note = $4, reviewed_at = $5, updated_at = NOW()
|
||||
WHERE id = $1::uuid AND status = 'pending'
|
||||
RETURNING id, status
|
||||
"""
|
||||
|
||||
_EXPIRE_STALE_APPROVALS = """
|
||||
UPDATE operator_approvals
|
||||
SET status = 'expired', updated_at = NOW()
|
||||
WHERE status = 'pending' AND expires_at <= $1
|
||||
RETURNING id, ticker
|
||||
"""
|
||||
|
||||
_FETCH_PENDING_APPROVALS = """
|
||||
SELECT id, order_job, recommendation_id, ticker, side, quantity,
|
||||
estimated_value, status, risk_evaluation_id, requested_by,
|
||||
reviewed_by, review_note, expires_at, requested_at, reviewed_at
|
||||
FROM operator_approvals
|
||||
WHERE status = 'pending'
|
||||
ORDER BY requested_at ASC
|
||||
"""
|
||||
|
||||
_FETCH_APPROVAL_BY_ID = """
|
||||
SELECT id, order_job, recommendation_id, ticker, side, quantity,
|
||||
estimated_value, status, risk_evaluation_id, requested_by,
|
||||
reviewed_by, review_note, expires_at, requested_at, reviewed_at
|
||||
FROM operator_approvals
|
||||
WHERE id = $1::uuid
|
||||
"""
|
||||
|
||||
|
||||
def _row_to_request(row: Any) -> ApprovalRequest:
|
||||
"""Convert a DB row to an ApprovalRequest."""
|
||||
order_job = row["order_job"]
|
||||
if isinstance(order_job, str):
|
||||
order_job = json.loads(order_job)
|
||||
return ApprovalRequest(
|
||||
approval_id=str(row["id"]),
|
||||
order_job=order_job,
|
||||
recommendation_id=str(row["recommendation_id"]) if row["recommendation_id"] else None,
|
||||
ticker=row["ticker"],
|
||||
side=row["side"],
|
||||
quantity=float(row["quantity"]),
|
||||
estimated_value=float(row["estimated_value"]),
|
||||
risk_evaluation_id=str(row["risk_evaluation_id"]) if row.get("risk_evaluation_id") else None,
|
||||
status=ApprovalStatus(row["status"]),
|
||||
requested_by=row["requested_by"],
|
||||
reviewed_by=row["reviewed_by"],
|
||||
review_note=row["review_note"],
|
||||
expires_at=row["expires_at"],
|
||||
requested_at=row["requested_at"],
|
||||
reviewed_at=row["reviewed_at"],
|
||||
)
|
||||
|
||||
|
||||
async def create_approval_request(
|
||||
pool: asyncpg.Pool,
|
||||
request: ApprovalRequest,
|
||||
) -> str:
|
||||
"""Persist a new approval request. Returns the approval ID."""
|
||||
await pool.execute(
|
||||
_INSERT_APPROVAL,
|
||||
request.approval_id,
|
||||
json.dumps(request.order_job, default=str),
|
||||
request.recommendation_id,
|
||||
request.ticker,
|
||||
request.side,
|
||||
request.quantity,
|
||||
request.estimated_value,
|
||||
request.status.value,
|
||||
request.risk_evaluation_id,
|
||||
request.requested_by,
|
||||
request.expires_at,
|
||||
request.requested_at,
|
||||
)
|
||||
return request.approval_id
|
||||
|
||||
|
||||
async def review_approval(
|
||||
pool: asyncpg.Pool,
|
||||
approval_id: str,
|
||||
approved: bool,
|
||||
reviewed_by: str = "operator",
|
||||
review_note: str = "",
|
||||
) -> ApprovalStatus | None:
|
||||
"""Approve or reject a pending approval request.
|
||||
|
||||
Returns the new status, or None if the approval was not found
|
||||
or was no longer pending (already expired/reviewed).
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
new_status = ApprovalStatus.APPROVED if approved else ApprovalStatus.REJECTED
|
||||
|
||||
row = await pool.fetchrow(
|
||||
_UPDATE_APPROVAL_STATUS,
|
||||
approval_id,
|
||||
new_status.value,
|
||||
reviewed_by,
|
||||
review_note,
|
||||
now,
|
||||
)
|
||||
if row:
|
||||
return ApprovalStatus(row["status"])
|
||||
return None
|
||||
|
||||
|
||||
async def expire_stale_approvals(
|
||||
pool: asyncpg.Pool,
|
||||
now: datetime | None = None,
|
||||
) -> list[dict[str, str]]:
|
||||
"""Mark all expired pending approvals. Returns list of expired items."""
|
||||
now = now or datetime.now(timezone.utc)
|
||||
rows = await pool.fetch(_EXPIRE_STALE_APPROVALS, now)
|
||||
return [{"id": str(r["id"]), "ticker": r["ticker"]} for r in rows]
|
||||
|
||||
|
||||
async def get_pending_approvals(
|
||||
pool: asyncpg.Pool,
|
||||
) -> list[ApprovalRequest]:
|
||||
"""Fetch all pending approval requests, oldest first."""
|
||||
rows = await pool.fetch(_FETCH_PENDING_APPROVALS)
|
||||
return [_row_to_request(r) for r in rows]
|
||||
|
||||
|
||||
async def get_approval_by_id(
|
||||
pool: asyncpg.Pool,
|
||||
approval_id: str,
|
||||
) -> ApprovalRequest | None:
|
||||
"""Fetch a single approval request by ID."""
|
||||
row = await pool.fetchrow(_FETCH_APPROVAL_BY_ID, approval_id)
|
||||
if row:
|
||||
return _row_to_request(row)
|
||||
return None
|
||||
+616
-1
@@ -1 +1,616 @@
|
||||
"""Risk engine - enforces guardrails, position limits, and trade eligibility checks."""
|
||||
"""Risk engine - portfolio and account risk configuration and enforcement.
|
||||
|
||||
Defines the configuration and state models used to enforce guardrails
|
||||
on trade execution: max position size, sector exposure, daily loss limits,
|
||||
news-shock lockouts, and operator approval rules.
|
||||
|
||||
Also implements the hard-block evaluation logic that decides whether a
|
||||
proposed order is allowed before it reaches the broker adapter.
|
||||
|
||||
Requirements: 8.1, 8.2, 8.3, 8.4, 8.5
|
||||
Design: Section 4.8 - Risk Engine
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Enums
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TradingMode(str, Enum):
|
||||
"""Execution environment separation (Requirement 8.1)."""
|
||||
PAPER = "paper"
|
||||
LIVE = "live"
|
||||
DISABLED = "disabled"
|
||||
|
||||
|
||||
class RiskCheckResult(str, Enum):
|
||||
"""Outcome of a single risk check."""
|
||||
PASS = "pass"
|
||||
FAIL = "fail"
|
||||
WARN = "warn"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Portfolio-level risk configuration (Requirement 8.2, 8.4)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class PositionLimits(BaseModel):
|
||||
"""Per-position size constraints."""
|
||||
max_position_pct: float = Field(
|
||||
default=0.05, ge=0, le=1,
|
||||
description="Maximum portfolio percentage for a single position",
|
||||
)
|
||||
max_position_value: float = Field(
|
||||
default=10_000.0, ge=0,
|
||||
description="Maximum dollar value for a single position",
|
||||
)
|
||||
max_shares_per_order: float = Field(
|
||||
default=1000.0, ge=0,
|
||||
description="Maximum shares in a single order",
|
||||
)
|
||||
|
||||
|
||||
class SectorExposureLimits(BaseModel):
|
||||
"""Sector-level concentration limits."""
|
||||
max_sector_pct: float = Field(
|
||||
default=0.25, ge=0, le=1,
|
||||
description="Maximum portfolio percentage exposed to one sector",
|
||||
)
|
||||
max_sectors: int = Field(
|
||||
default=10, ge=1,
|
||||
description="Maximum number of sectors with open positions",
|
||||
)
|
||||
|
||||
|
||||
class DailyLossLimits(BaseModel):
|
||||
"""Daily drawdown controls."""
|
||||
max_daily_loss_pct: float = Field(
|
||||
default=0.02, ge=0, le=1,
|
||||
description="Maximum portfolio loss percentage in a single day before halting",
|
||||
)
|
||||
max_daily_loss_value: float = Field(
|
||||
default=1_000.0, ge=0,
|
||||
description="Maximum dollar loss in a single day before halting",
|
||||
)
|
||||
max_daily_trades: int = Field(
|
||||
default=20, ge=0,
|
||||
description="Maximum number of trades per day",
|
||||
)
|
||||
|
||||
|
||||
class NewsShockLockout(BaseModel):
|
||||
"""News-shock lockout configuration.
|
||||
|
||||
When a symbol has a high-impact news event, trading is paused
|
||||
for a configurable cooldown period.
|
||||
"""
|
||||
enabled: bool = True
|
||||
lockout_minutes: int = Field(
|
||||
default=60, ge=0,
|
||||
description="Minutes to lock out trading after a high-impact news event",
|
||||
)
|
||||
impact_threshold: float = Field(
|
||||
default=0.80, ge=0, le=1,
|
||||
description="Minimum impact_score from document intelligence to trigger lockout",
|
||||
)
|
||||
catalyst_types: list[str] = Field(
|
||||
default_factory=lambda: ["earnings", "legal", "m_and_a"],
|
||||
description="Catalyst types that trigger lockout when above threshold",
|
||||
)
|
||||
|
||||
|
||||
class OperatorApproval(BaseModel):
|
||||
"""Operator approval workflow for live trading (Requirement 8.2)."""
|
||||
require_approval_for_live: bool = Field(
|
||||
default=True,
|
||||
description="Whether live orders require operator approval",
|
||||
)
|
||||
auto_approve_paper: bool = Field(
|
||||
default=True,
|
||||
description="Whether paper orders are auto-approved",
|
||||
)
|
||||
approval_timeout_minutes: int = Field(
|
||||
default=30, ge=1,
|
||||
description="Minutes before a pending approval expires",
|
||||
)
|
||||
|
||||
|
||||
class SymbolCooldown(BaseModel):
|
||||
"""Per-symbol cooldown after a trade."""
|
||||
cooldown_minutes: int = Field(
|
||||
default=15, ge=0,
|
||||
description="Minutes to wait before trading the same symbol again",
|
||||
)
|
||||
max_open_positions_per_symbol: int = Field(
|
||||
default=1, ge=1,
|
||||
description="Maximum concurrent open positions for a single symbol",
|
||||
)
|
||||
|
||||
|
||||
class PortfolioRiskConfig(BaseModel):
|
||||
"""Complete portfolio-level risk configuration.
|
||||
|
||||
This is the top-level config that governs all risk checks.
|
||||
Persisted in PostgreSQL and loaded at engine startup.
|
||||
"""
|
||||
config_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
name: str = "default"
|
||||
trading_mode: TradingMode = TradingMode.PAPER
|
||||
position_limits: PositionLimits = Field(default_factory=PositionLimits)
|
||||
sector_exposure: SectorExposureLimits = Field(default_factory=SectorExposureLimits)
|
||||
daily_loss: DailyLossLimits = Field(default_factory=DailyLossLimits)
|
||||
news_shock: NewsShockLockout = Field(default_factory=NewsShockLockout)
|
||||
operator_approval: OperatorApproval = Field(default_factory=OperatorApproval)
|
||||
symbol_cooldown: SymbolCooldown = Field(default_factory=SymbolCooldown)
|
||||
active: bool = True
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
def to_db_json(self) -> dict[str, Any]:
|
||||
"""Serialize the full config to a JSON-compatible dict for DB storage."""
|
||||
return self.model_dump(mode="json")
|
||||
|
||||
@classmethod
|
||||
def from_db_json(cls, data: dict[str, Any]) -> PortfolioRiskConfig:
|
||||
"""Deserialize from a DB JSON column."""
|
||||
return cls.model_validate(data)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Account risk state (runtime snapshot)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AccountRiskState(BaseModel):
|
||||
"""Runtime snapshot of an account's risk posture.
|
||||
|
||||
Computed from broker positions, today's trades, and current P&L.
|
||||
Used by risk checks to evaluate whether a new order is allowed.
|
||||
"""
|
||||
account_id: str = ""
|
||||
portfolio_value: float = 0.0
|
||||
cash: float = 0.0
|
||||
buying_power: float = 0.0
|
||||
daily_pnl: float = 0.0
|
||||
daily_trade_count: int = 0
|
||||
open_position_count: int = 0
|
||||
positions_by_symbol: dict[str, float] = Field(
|
||||
default_factory=dict,
|
||||
description="Map of ticker → current market value",
|
||||
)
|
||||
positions_by_sector: dict[str, float] = Field(
|
||||
default_factory=dict,
|
||||
description="Map of sector → total market value",
|
||||
)
|
||||
last_trade_times: dict[str, datetime] = Field(
|
||||
default_factory=dict,
|
||||
description="Map of ticker → last trade timestamp for cooldown checks",
|
||||
)
|
||||
locked_symbols: dict[str, datetime] = Field(
|
||||
default_factory=dict,
|
||||
description="Map of ticker → lockout expiry for news-shock lockouts",
|
||||
)
|
||||
snapshot_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Risk check output (Requirement 8.3 - full decision trace)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class RiskCheckDetail(BaseModel):
|
||||
"""Result of a single risk check."""
|
||||
check_name: str
|
||||
result: RiskCheckResult
|
||||
message: str = ""
|
||||
threshold: float | None = None
|
||||
actual: float | None = None
|
||||
|
||||
|
||||
class RiskEvaluation(BaseModel):
|
||||
"""Complete risk evaluation for a proposed order.
|
||||
|
||||
Captures every check performed so the full decision trace
|
||||
is reproducible (Requirement 8.3).
|
||||
"""
|
||||
evaluation_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
recommendation_id: str | None = None
|
||||
ticker: str = ""
|
||||
eligible: bool = False
|
||||
allowed_mode: TradingMode = TradingMode.DISABLED
|
||||
checks: list[RiskCheckDetail] = Field(default_factory=list)
|
||||
rejection_reasons: list[str] = Field(default_factory=list)
|
||||
config_snapshot: PortfolioRiskConfig | None = None
|
||||
state_snapshot: AccountRiskState | None = None
|
||||
evaluated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
@property
|
||||
def passed(self) -> bool:
|
||||
return self.eligible and len(self.rejection_reasons) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Default configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DEFAULT_RISK_CONFIG = PortfolioRiskConfig()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Proposed order (input to risk evaluation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ProposedOrder(BaseModel):
|
||||
"""A proposed order to be evaluated by the risk engine before submission.
|
||||
|
||||
This is the input to evaluate_order(). It carries enough context
|
||||
for every risk check to run without external lookups.
|
||||
"""
|
||||
recommendation_id: str | None = None
|
||||
ticker: str
|
||||
sector: str = ""
|
||||
action: str = "buy" # buy | sell
|
||||
quantity: float = 0.0
|
||||
estimated_value: float = 0.0
|
||||
confidence: float = 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Individual risk checks (Requirement 8.4)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _check_trading_mode(
|
||||
config: PortfolioRiskConfig,
|
||||
) -> RiskCheckDetail:
|
||||
"""Block all orders when trading is disabled."""
|
||||
if config.trading_mode == TradingMode.DISABLED:
|
||||
return RiskCheckDetail(
|
||||
check_name="trading_mode",
|
||||
result=RiskCheckResult.FAIL,
|
||||
message="Trading is disabled",
|
||||
)
|
||||
return RiskCheckDetail(
|
||||
check_name="trading_mode",
|
||||
result=RiskCheckResult.PASS,
|
||||
message=f"Trading mode: {config.trading_mode.value}",
|
||||
)
|
||||
|
||||
|
||||
def _check_max_position_size(
|
||||
order: ProposedOrder,
|
||||
config: PortfolioRiskConfig,
|
||||
state: AccountRiskState,
|
||||
) -> list[RiskCheckDetail]:
|
||||
"""Enforce per-position size limits (value, percentage, shares)."""
|
||||
checks: list[RiskCheckDetail] = []
|
||||
limits = config.position_limits
|
||||
|
||||
# Check max position value
|
||||
existing_value = state.positions_by_symbol.get(order.ticker, 0.0)
|
||||
new_total_value = existing_value + order.estimated_value
|
||||
checks.append(RiskCheckDetail(
|
||||
check_name="max_position_value",
|
||||
result=(
|
||||
RiskCheckResult.PASS
|
||||
if new_total_value <= limits.max_position_value
|
||||
else RiskCheckResult.FAIL
|
||||
),
|
||||
message=(
|
||||
f"Position value {new_total_value:.2f} "
|
||||
f"{'within' if new_total_value <= limits.max_position_value else 'exceeds'} "
|
||||
f"limit {limits.max_position_value:.2f}"
|
||||
),
|
||||
threshold=limits.max_position_value,
|
||||
actual=new_total_value,
|
||||
))
|
||||
|
||||
# Check max position percentage of portfolio
|
||||
if state.portfolio_value > 0:
|
||||
position_pct = new_total_value / state.portfolio_value
|
||||
else:
|
||||
position_pct = 1.0 if new_total_value > 0 else 0.0
|
||||
checks.append(RiskCheckDetail(
|
||||
check_name="max_position_pct",
|
||||
result=(
|
||||
RiskCheckResult.PASS
|
||||
if position_pct <= limits.max_position_pct
|
||||
else RiskCheckResult.FAIL
|
||||
),
|
||||
message=(
|
||||
f"Position {position_pct:.4f} of portfolio "
|
||||
f"{'within' if position_pct <= limits.max_position_pct else 'exceeds'} "
|
||||
f"limit {limits.max_position_pct:.4f}"
|
||||
),
|
||||
threshold=limits.max_position_pct,
|
||||
actual=position_pct,
|
||||
))
|
||||
|
||||
# Check max shares per order
|
||||
checks.append(RiskCheckDetail(
|
||||
check_name="max_shares_per_order",
|
||||
result=(
|
||||
RiskCheckResult.PASS
|
||||
if order.quantity <= limits.max_shares_per_order
|
||||
else RiskCheckResult.FAIL
|
||||
),
|
||||
message=(
|
||||
f"Order quantity {order.quantity:.0f} "
|
||||
f"{'within' if order.quantity <= limits.max_shares_per_order else 'exceeds'} "
|
||||
f"limit {limits.max_shares_per_order:.0f}"
|
||||
),
|
||||
threshold=limits.max_shares_per_order,
|
||||
actual=order.quantity,
|
||||
))
|
||||
|
||||
return checks
|
||||
|
||||
|
||||
def _check_sector_exposure(
|
||||
order: ProposedOrder,
|
||||
config: PortfolioRiskConfig,
|
||||
state: AccountRiskState,
|
||||
) -> RiskCheckDetail:
|
||||
"""Enforce sector concentration limits."""
|
||||
limits = config.sector_exposure
|
||||
|
||||
if not order.sector:
|
||||
return RiskCheckDetail(
|
||||
check_name="sector_exposure",
|
||||
result=RiskCheckResult.WARN,
|
||||
message="No sector provided on order; skipping sector check",
|
||||
)
|
||||
|
||||
existing_sector_value = state.positions_by_sector.get(order.sector, 0.0)
|
||||
new_sector_value = existing_sector_value + order.estimated_value
|
||||
|
||||
if state.portfolio_value > 0:
|
||||
sector_pct = new_sector_value / state.portfolio_value
|
||||
else:
|
||||
sector_pct = 1.0 if new_sector_value > 0 else 0.0
|
||||
|
||||
return RiskCheckDetail(
|
||||
check_name="sector_exposure",
|
||||
result=(
|
||||
RiskCheckResult.PASS
|
||||
if sector_pct <= limits.max_sector_pct
|
||||
else RiskCheckResult.FAIL
|
||||
),
|
||||
message=(
|
||||
f"Sector '{order.sector}' exposure {sector_pct:.4f} "
|
||||
f"{'within' if sector_pct <= limits.max_sector_pct else 'exceeds'} "
|
||||
f"limit {limits.max_sector_pct:.4f}"
|
||||
),
|
||||
threshold=limits.max_sector_pct,
|
||||
actual=sector_pct,
|
||||
)
|
||||
|
||||
|
||||
def _check_daily_loss(
|
||||
config: PortfolioRiskConfig,
|
||||
state: AccountRiskState,
|
||||
) -> list[RiskCheckDetail]:
|
||||
"""Enforce daily loss and trade count limits."""
|
||||
checks: list[RiskCheckDetail] = []
|
||||
limits = config.daily_loss
|
||||
|
||||
# Daily loss percentage
|
||||
if state.portfolio_value > 0:
|
||||
loss_pct = abs(min(state.daily_pnl, 0.0)) / state.portfolio_value
|
||||
else:
|
||||
loss_pct = 0.0
|
||||
|
||||
checks.append(RiskCheckDetail(
|
||||
check_name="daily_loss_pct",
|
||||
result=(
|
||||
RiskCheckResult.PASS
|
||||
if loss_pct <= limits.max_daily_loss_pct
|
||||
else RiskCheckResult.FAIL
|
||||
),
|
||||
message=(
|
||||
f"Daily loss {loss_pct:.4f} "
|
||||
f"{'within' if loss_pct <= limits.max_daily_loss_pct else 'exceeds'} "
|
||||
f"limit {limits.max_daily_loss_pct:.4f}"
|
||||
),
|
||||
threshold=limits.max_daily_loss_pct,
|
||||
actual=loss_pct,
|
||||
))
|
||||
|
||||
# Daily loss absolute value
|
||||
abs_loss = abs(min(state.daily_pnl, 0.0))
|
||||
checks.append(RiskCheckDetail(
|
||||
check_name="daily_loss_value",
|
||||
result=(
|
||||
RiskCheckResult.PASS
|
||||
if abs_loss <= limits.max_daily_loss_value
|
||||
else RiskCheckResult.FAIL
|
||||
),
|
||||
message=(
|
||||
f"Daily loss ${abs_loss:.2f} "
|
||||
f"{'within' if abs_loss <= limits.max_daily_loss_value else 'exceeds'} "
|
||||
f"limit ${limits.max_daily_loss_value:.2f}"
|
||||
),
|
||||
threshold=limits.max_daily_loss_value,
|
||||
actual=abs_loss,
|
||||
))
|
||||
|
||||
# Daily trade count
|
||||
checks.append(RiskCheckDetail(
|
||||
check_name="daily_trade_count",
|
||||
result=(
|
||||
RiskCheckResult.PASS
|
||||
if state.daily_trade_count < limits.max_daily_trades
|
||||
else RiskCheckResult.FAIL
|
||||
),
|
||||
message=(
|
||||
f"Daily trades {state.daily_trade_count} "
|
||||
f"{'within' if state.daily_trade_count < limits.max_daily_trades else 'at/exceeds'} "
|
||||
f"limit {limits.max_daily_trades}"
|
||||
),
|
||||
threshold=float(limits.max_daily_trades),
|
||||
actual=float(state.daily_trade_count),
|
||||
))
|
||||
|
||||
return checks
|
||||
|
||||
|
||||
def _check_news_shock_lockout(
|
||||
order: ProposedOrder,
|
||||
config: PortfolioRiskConfig,
|
||||
state: AccountRiskState,
|
||||
now: datetime | None = None,
|
||||
) -> RiskCheckDetail:
|
||||
"""Block trading on symbols under news-shock lockout."""
|
||||
lockout_cfg = config.news_shock
|
||||
|
||||
if not lockout_cfg.enabled:
|
||||
return RiskCheckDetail(
|
||||
check_name="news_shock_lockout",
|
||||
result=RiskCheckResult.PASS,
|
||||
message="News-shock lockout is disabled",
|
||||
)
|
||||
|
||||
now = now or datetime.now(timezone.utc)
|
||||
lockout_expiry = state.locked_symbols.get(order.ticker)
|
||||
|
||||
if lockout_expiry is not None and now < lockout_expiry:
|
||||
remaining = lockout_expiry - now
|
||||
return RiskCheckDetail(
|
||||
check_name="news_shock_lockout",
|
||||
result=RiskCheckResult.FAIL,
|
||||
message=(
|
||||
f"Symbol {order.ticker} locked out until "
|
||||
f"{lockout_expiry.isoformat()} "
|
||||
f"({remaining.total_seconds():.0f}s remaining)"
|
||||
),
|
||||
)
|
||||
|
||||
return RiskCheckDetail(
|
||||
check_name="news_shock_lockout",
|
||||
result=RiskCheckResult.PASS,
|
||||
message=f"No active lockout for {order.ticker}",
|
||||
)
|
||||
|
||||
|
||||
def _check_symbol_cooldown(
|
||||
order: ProposedOrder,
|
||||
config: PortfolioRiskConfig,
|
||||
state: AccountRiskState,
|
||||
now: datetime | None = None,
|
||||
) -> RiskCheckDetail:
|
||||
"""Enforce per-symbol cooldown between trades."""
|
||||
cooldown_cfg = config.symbol_cooldown
|
||||
now = now or datetime.now(timezone.utc)
|
||||
|
||||
last_trade = state.last_trade_times.get(order.ticker)
|
||||
if last_trade is not None:
|
||||
cooldown_end = last_trade + timedelta(minutes=cooldown_cfg.cooldown_minutes)
|
||||
if now < cooldown_end:
|
||||
remaining = cooldown_end - now
|
||||
return RiskCheckDetail(
|
||||
check_name="symbol_cooldown",
|
||||
result=RiskCheckResult.FAIL,
|
||||
message=(
|
||||
f"Symbol {order.ticker} in cooldown until "
|
||||
f"{cooldown_end.isoformat()} "
|
||||
f"({remaining.total_seconds():.0f}s remaining)"
|
||||
),
|
||||
)
|
||||
|
||||
return RiskCheckDetail(
|
||||
check_name="symbol_cooldown",
|
||||
result=RiskCheckResult.PASS,
|
||||
message=f"No active cooldown for {order.ticker}",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main evaluation entry point (Requirements 8.3, 8.4, 8.5)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def evaluate_order(
|
||||
order: ProposedOrder,
|
||||
config: PortfolioRiskConfig = DEFAULT_RISK_CONFIG,
|
||||
state: AccountRiskState | None = None,
|
||||
now: datetime | None = None,
|
||||
) -> RiskEvaluation:
|
||||
"""Evaluate a proposed order against all risk controls.
|
||||
|
||||
Runs every hard-block check and returns a RiskEvaluation capturing
|
||||
the full decision trace (Requirement 8.3). If any check fails,
|
||||
the order is rejected before broker submission (Requirement 8.4).
|
||||
|
||||
The engine fails closed: if state is missing or ambiguous, the
|
||||
order is rejected (Requirement 8.5).
|
||||
"""
|
||||
state = state or AccountRiskState()
|
||||
now = now or datetime.now(timezone.utc)
|
||||
|
||||
all_checks: list[RiskCheckDetail] = []
|
||||
rejection_reasons: list[str] = []
|
||||
|
||||
# 1. Trading mode gate
|
||||
mode_check = _check_trading_mode(config)
|
||||
all_checks.append(mode_check)
|
||||
if mode_check.result == RiskCheckResult.FAIL:
|
||||
rejection_reasons.append(mode_check.message)
|
||||
|
||||
# 2. Position size limits
|
||||
position_checks = _check_max_position_size(order, config, state)
|
||||
all_checks.extend(position_checks)
|
||||
for c in position_checks:
|
||||
if c.result == RiskCheckResult.FAIL:
|
||||
rejection_reasons.append(c.message)
|
||||
|
||||
# 3. Sector exposure
|
||||
sector_check = _check_sector_exposure(order, config, state)
|
||||
all_checks.append(sector_check)
|
||||
if sector_check.result == RiskCheckResult.FAIL:
|
||||
rejection_reasons.append(sector_check.message)
|
||||
|
||||
# 4. Daily loss limits
|
||||
daily_checks = _check_daily_loss(config, state)
|
||||
all_checks.extend(daily_checks)
|
||||
for c in daily_checks:
|
||||
if c.result == RiskCheckResult.FAIL:
|
||||
rejection_reasons.append(c.message)
|
||||
|
||||
# 5. News-shock lockout
|
||||
lockout_check = _check_news_shock_lockout(order, config, state, now)
|
||||
all_checks.append(lockout_check)
|
||||
if lockout_check.result == RiskCheckResult.FAIL:
|
||||
rejection_reasons.append(lockout_check.message)
|
||||
|
||||
# 6. Symbol cooldown
|
||||
cooldown_check = _check_symbol_cooldown(order, config, state, now)
|
||||
all_checks.append(cooldown_check)
|
||||
if cooldown_check.result == RiskCheckResult.FAIL:
|
||||
rejection_reasons.append(cooldown_check.message)
|
||||
|
||||
# Determine eligibility and allowed mode
|
||||
eligible = len(rejection_reasons) == 0
|
||||
allowed_mode = config.trading_mode if eligible else TradingMode.DISABLED
|
||||
|
||||
return RiskEvaluation(
|
||||
recommendation_id=order.recommendation_id,
|
||||
ticker=order.ticker,
|
||||
eligible=eligible,
|
||||
allowed_mode=allowed_mode,
|
||||
checks=all_checks,
|
||||
rejection_reasons=rejection_reasons,
|
||||
config_snapshot=config,
|
||||
state_snapshot=state,
|
||||
evaluated_at=now,
|
||||
)
|
||||
|
||||
+238
-43
@@ -1,14 +1,23 @@
|
||||
"""Scheduler - triggers ingestion cycles for tracked symbols and sources."""
|
||||
"""Scheduler - triggers ingestion cycles for tracked symbols and sources.
|
||||
|
||||
Polls the symbol registry for active companies and their configured sources,
|
||||
respects per-source polling cadences and backoff windows, coordinates rate
|
||||
limits across source types, and enqueues ingestion jobs for downstream workers.
|
||||
|
||||
Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
import asyncpg
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from services.shared.config import load_config
|
||||
from services.shared.db import get_pg_pool, get_redis
|
||||
from services.shared.logging import setup_logging
|
||||
from services.shared.redis_keys import (
|
||||
QUEUE_INGESTION,
|
||||
lock_key,
|
||||
@@ -16,11 +25,11 @@ from services.shared.redis_keys import (
|
||||
rate_limit_key,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("scheduler")
|
||||
|
||||
# Polling cadences by source class (seconds)
|
||||
CADENCES = {
|
||||
# Default polling cadences by source class (seconds).
|
||||
# Individual sources can override via config.polling_interval_seconds.
|
||||
DEFAULT_CADENCES: dict[str, int] = {
|
||||
"market_api": 60,
|
||||
"news_api": 300,
|
||||
"filings_api": 3600,
|
||||
@@ -28,81 +37,267 @@ CADENCES = {
|
||||
"broker": 30,
|
||||
}
|
||||
|
||||
# Default rate limits per source type (requests per minute)
|
||||
DEFAULT_RATE_LIMITS: dict[str, int] = {
|
||||
"market_api": 30,
|
||||
"news_api": 20,
|
||||
"filings_api": 10,
|
||||
"web_scrape": 10,
|
||||
"broker": 60,
|
||||
}
|
||||
|
||||
# How long to wait before retrying a failed source (seconds)
|
||||
DEFAULT_BACKOFF_BASE: int = 60
|
||||
MAX_BACKOFF: int = 3600
|
||||
MAX_RETRY_COUNT: int = 10
|
||||
|
||||
# Main loop interval (seconds)
|
||||
SCHEDULER_TICK: int = 15
|
||||
|
||||
|
||||
def get_cadence_for_source(source_type: str, config: Optional[dict[str, Any]]) -> int:
|
||||
"""Return the polling interval for a source.
|
||||
|
||||
Uses the source's config.polling_interval_seconds if set,
|
||||
otherwise falls back to the default cadence for the source type.
|
||||
"""
|
||||
if config and "polling_interval_seconds" in config:
|
||||
try:
|
||||
return max(10, int(config["polling_interval_seconds"]))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
return DEFAULT_CADENCES.get(source_type, 600)
|
||||
|
||||
|
||||
def compute_backoff(retry_count: int) -> int:
|
||||
"""Exponential backoff with a cap. Returns seconds to wait."""
|
||||
delay = DEFAULT_BACKOFF_BASE * (2 ** min(retry_count, 8))
|
||||
return min(delay, MAX_BACKOFF)
|
||||
|
||||
|
||||
def is_source_due(
|
||||
source_type: str,
|
||||
source_config: Optional[dict[str, Any]],
|
||||
last_completed_at: Optional[datetime],
|
||||
last_status: Optional[str],
|
||||
retry_count: int,
|
||||
next_retry_at: Optional[datetime],
|
||||
now: datetime,
|
||||
) -> bool:
|
||||
"""Determine whether a source is due for its next polling cycle.
|
||||
|
||||
Checks:
|
||||
- If the source has never run, it is due.
|
||||
- If the last run failed and we have a next_retry_at in the future, skip.
|
||||
- If the last run failed and retry_count exceeds max, skip (needs manual reset).
|
||||
- Otherwise, check if enough time has elapsed since the last completed run.
|
||||
"""
|
||||
# Never run before — always due
|
||||
if last_completed_at is None and last_status is None:
|
||||
return True
|
||||
|
||||
# If last run failed, respect backoff
|
||||
if last_status == "failed":
|
||||
if retry_count >= MAX_RETRY_COUNT:
|
||||
return False
|
||||
if next_retry_at and now < next_retry_at.replace(tzinfo=None):
|
||||
return False
|
||||
# Backoff elapsed or no next_retry_at set — allow retry
|
||||
return True
|
||||
|
||||
# If currently running, don't double-schedule
|
||||
if last_status == "running":
|
||||
return False
|
||||
|
||||
# Normal cadence check
|
||||
if last_completed_at is None:
|
||||
return True
|
||||
|
||||
cadence = get_cadence_for_source(source_type, source_config)
|
||||
elapsed = (now - last_completed_at.replace(tzinfo=None)).total_seconds()
|
||||
return elapsed >= cadence
|
||||
|
||||
|
||||
def build_job_payload(
|
||||
source: Any,
|
||||
aliases: list[str],
|
||||
now: datetime,
|
||||
) -> dict[str, Any]:
|
||||
"""Build the ingestion job payload for a source."""
|
||||
return {
|
||||
"source_id": str(source["source_id"]),
|
||||
"company_id": str(source["company_id"]),
|
||||
"ticker": source["ticker"],
|
||||
"legal_name": source["legal_name"],
|
||||
"aliases": aliases,
|
||||
"source_type": source["source_type"],
|
||||
"source_name": source["source_name"],
|
||||
"config": dict(source["config"]) if source["config"] else {},
|
||||
"credibility_score": float(source["credibility_score"]) if source["credibility_score"] else 0.5,
|
||||
"scheduled_at": now.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
async def acquire_lock(rds: aioredis.Redis, name: str, ttl: int = 60) -> bool:
|
||||
"""Acquire a distributed lock. Returns True if acquired."""
|
||||
return await rds.set(lock_key(name), "1", nx=True, ex=ttl)
|
||||
|
||||
|
||||
async def release_lock(rds: aioredis.Redis, name: str):
|
||||
async def release_lock(rds: aioredis.Redis, name: str) -> None:
|
||||
"""Release a distributed lock."""
|
||||
await rds.delete(lock_key(name))
|
||||
|
||||
|
||||
async def check_rate_limit(rds: aioredis.Redis, source_type: str, max_per_minute: int = 30) -> bool:
|
||||
key = rate_limit_key(source_type, datetime.utcnow().strftime("%Y%m%d%H%M"))
|
||||
async def check_rate_limit(
|
||||
rds: aioredis.Redis,
|
||||
source_type: str,
|
||||
now: datetime,
|
||||
max_per_minute: Optional[int] = None,
|
||||
) -> bool:
|
||||
"""Check whether the source type is within its rate limit window.
|
||||
|
||||
Returns True if the request is allowed, False if rate-limited.
|
||||
"""
|
||||
limit = max_per_minute or DEFAULT_RATE_LIMITS.get(source_type, 30)
|
||||
window = now.strftime("%Y%m%d%H%M")
|
||||
key = rate_limit_key(source_type, window)
|
||||
count = await rds.incr(key)
|
||||
if count == 1:
|
||||
await rds.expire(key, 120)
|
||||
return count <= max_per_minute
|
||||
return count <= limit
|
||||
|
||||
|
||||
async def schedule_cycle(pool: asyncpg.Pool, rds: aioredis.Redis):
|
||||
"""One scheduling pass: find due sources and enqueue ingestion jobs."""
|
||||
sources = await pool.fetch(
|
||||
"""SELECT s.id as source_id, s.company_id, s.source_type, s.source_name, s.config,
|
||||
c.ticker, c.legal_name
|
||||
FROM sources s JOIN companies c ON s.company_id = c.id
|
||||
async def fetch_active_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
|
||||
"""Fetch all active sources joined with their active companies."""
|
||||
return await pool.fetch(
|
||||
"""SELECT s.id AS source_id,
|
||||
s.company_id,
|
||||
s.source_type,
|
||||
s.source_name,
|
||||
s.config,
|
||||
s.credibility_score,
|
||||
c.ticker,
|
||||
c.legal_name
|
||||
FROM sources s
|
||||
JOIN companies c ON s.company_id = c.id
|
||||
WHERE s.active = TRUE AND c.active = TRUE
|
||||
ORDER BY s.source_type, c.ticker"""
|
||||
)
|
||||
|
||||
|
||||
async def fetch_aliases_for_company(pool: asyncpg.Pool, company_id: str) -> list[str]:
|
||||
"""Fetch all aliases for a company."""
|
||||
rows = await pool.fetch(
|
||||
"SELECT alias FROM company_aliases WHERE company_id = $1",
|
||||
company_id,
|
||||
)
|
||||
return [r["alias"] for r in rows]
|
||||
|
||||
|
||||
async def fetch_last_run(
|
||||
pool: asyncpg.Pool, source_id: str
|
||||
) -> Optional[asyncpg.Record]:
|
||||
"""Fetch the most recent ingestion run for a source."""
|
||||
return await pool.fetchrow(
|
||||
"""SELECT status, started_at, completed_at, retry_count, next_retry_at
|
||||
FROM ingestion_runs
|
||||
WHERE source_id = $1
|
||||
ORDER BY started_at DESC
|
||||
LIMIT 1""",
|
||||
source_id,
|
||||
)
|
||||
|
||||
|
||||
async def schedule_cycle(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
|
||||
"""One scheduling pass: find due sources and enqueue ingestion jobs.
|
||||
|
||||
Returns the number of jobs enqueued.
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
sources = await fetch_active_sources(pool)
|
||||
|
||||
enqueued = 0
|
||||
skipped_rate_limit = 0
|
||||
skipped_not_due = 0
|
||||
|
||||
for src in sources:
|
||||
source_id = src["source_id"]
|
||||
source_type = src["source_type"]
|
||||
cadence = CADENCES.get(source_type, 600)
|
||||
source_config = dict(src["config"]) if src["config"] else None
|
||||
|
||||
# Check last run
|
||||
last_run = await pool.fetchval(
|
||||
"SELECT MAX(started_at) FROM ingestion_runs WHERE source_id = $1 AND status IN ('completed', 'running')",
|
||||
src["source_id"],
|
||||
)
|
||||
if last_run and (datetime.utcnow() - last_run.replace(tzinfo=None)) < timedelta(seconds=cadence):
|
||||
# Check last run status and timing
|
||||
last_run = await fetch_last_run(pool, source_id)
|
||||
|
||||
last_completed_at = None
|
||||
last_status = None
|
||||
retry_count = 0
|
||||
next_retry_at = None
|
||||
|
||||
if last_run:
|
||||
last_status = last_run["status"]
|
||||
last_completed_at = last_run["completed_at"] or last_run["started_at"]
|
||||
retry_count = last_run["retry_count"] or 0
|
||||
next_retry_at = last_run["next_retry_at"]
|
||||
|
||||
if not is_source_due(
|
||||
source_type=source_type,
|
||||
source_config=source_config,
|
||||
last_completed_at=last_completed_at,
|
||||
last_status=last_status,
|
||||
retry_count=retry_count,
|
||||
next_retry_at=next_retry_at,
|
||||
now=now,
|
||||
):
|
||||
skipped_not_due += 1
|
||||
continue
|
||||
|
||||
if not await check_rate_limit(rds, source_type):
|
||||
logger.warning(f"Rate limit hit for {source_type}")
|
||||
# Rate limit check
|
||||
if not await check_rate_limit(rds, source_type, now):
|
||||
logger.warning(
|
||||
"Rate limit hit for %s, skipping %s/%s",
|
||||
source_type, src["ticker"], src["source_name"],
|
||||
)
|
||||
skipped_rate_limit += 1
|
||||
continue
|
||||
|
||||
job = {
|
||||
"source_id": str(src["source_id"]),
|
||||
"company_id": str(src["company_id"]),
|
||||
"ticker": src["ticker"],
|
||||
"source_type": source_type,
|
||||
"source_name": src["source_name"],
|
||||
"config": dict(src["config"]) if src["config"] else {},
|
||||
"scheduled_at": datetime.utcnow().isoformat(),
|
||||
}
|
||||
await rds.rpush(queue_key(QUEUE_INGESTION), json.dumps(job))
|
||||
# Fetch company aliases for downstream entity matching
|
||||
aliases = await fetch_aliases_for_company(pool, src["company_id"])
|
||||
|
||||
job = build_job_payload(src, aliases, now)
|
||||
await rds.rpush(queue_key(QUEUE_INGESTION), json.dumps(job)) # type: ignore[misc]
|
||||
enqueued += 1
|
||||
|
||||
if enqueued:
|
||||
logger.info(f"Enqueued {enqueued} ingestion jobs")
|
||||
logger.debug(
|
||||
"Enqueued %s job for %s (%s)",
|
||||
source_type, src["ticker"], src["source_name"],
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Cycle complete: enqueued=%d skipped_not_due=%d skipped_rate_limit=%d total_sources=%d",
|
||||
enqueued, skipped_not_due, skipped_rate_limit, len(sources),
|
||||
)
|
||||
return enqueued
|
||||
|
||||
|
||||
async def main():
|
||||
async def main() -> None:
|
||||
config = load_config()
|
||||
setup_logging("scheduler", level=config.log_level, json_output=config.json_logs)
|
||||
|
||||
pool = await get_pg_pool(config)
|
||||
rds = get_redis(config)
|
||||
|
||||
logger.info("Scheduler started")
|
||||
logger.info("Scheduler started (tick=%ds)", SCHEDULER_TICK)
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
if await acquire_lock(rds, "scheduler_cycle", ttl=30):
|
||||
await schedule_cycle(pool, rds)
|
||||
await release_lock(rds, "scheduler_cycle")
|
||||
except Exception as e:
|
||||
logger.error(f"Scheduler cycle error: {e}")
|
||||
await asyncio.sleep(15)
|
||||
try:
|
||||
await schedule_cycle(pool, rds)
|
||||
finally:
|
||||
await release_lock(rds, "scheduler_cycle")
|
||||
except Exception:
|
||||
logger.exception("Scheduler cycle error")
|
||||
await asyncio.sleep(SCHEDULER_TICK)
|
||||
finally:
|
||||
await pool.close()
|
||||
await rds.close()
|
||||
|
||||
@@ -0,0 +1,342 @@
|
||||
"""Operational alerting for Stonks Oracle pipeline health.
|
||||
|
||||
Evaluates alert rules against PostgreSQL operational state and emits
|
||||
structured log events and Prometheus metrics when thresholds are breached.
|
||||
|
||||
Alert rules:
|
||||
- source_failures: sustained source retrieval failures per source
|
||||
- schema_failure_spike: extraction validation failure rate exceeds threshold
|
||||
- analytical_lag: lake publication has not completed within threshold
|
||||
- broker_issues: consecutive broker submission errors
|
||||
|
||||
Requirements: 12.3
|
||||
Design: Section 12 (Observability and Operations)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import asyncpg
|
||||
|
||||
from services.shared.config import AlertingConfig
|
||||
from services.shared.metrics import (
|
||||
ALERT_ACTIVE,
|
||||
ALERT_CHECK_DURATION,
|
||||
ALERTS_FIRED,
|
||||
ALERTS_RESOLVED,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("alerting")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Alert:
|
||||
"""A single alert instance."""
|
||||
|
||||
rule: str
|
||||
severity: str # "warning" | "critical"
|
||||
summary: str
|
||||
details: dict[str, Any] = field(default_factory=dict)
|
||||
fired_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
@dataclass
|
||||
class AlertState:
|
||||
"""Tracks which rules are currently firing to detect transitions."""
|
||||
|
||||
active: dict[str, Alert] = field(default_factory=dict)
|
||||
|
||||
def fire(self, alert: Alert) -> bool:
|
||||
"""Record an alert firing. Returns True if this is a new firing."""
|
||||
key = f"{alert.rule}:{alert.details.get('key', '')}"
|
||||
is_new = key not in self.active
|
||||
self.active[key] = alert
|
||||
return is_new
|
||||
|
||||
def resolve(self, rule: str, key: str = "") -> bool:
|
||||
"""Resolve an alert. Returns True if it was previously active."""
|
||||
full_key = f"{rule}:{key}"
|
||||
if full_key in self.active:
|
||||
del self.active[full_key]
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_firing(self, rule: str, key: str = "") -> bool:
|
||||
return f"{rule}:{key}" in self.active
|
||||
|
||||
|
||||
async def check_source_failures(
|
||||
pool: asyncpg.Pool,
|
||||
config: AlertingConfig,
|
||||
) -> list[Alert]:
|
||||
"""Check for sources with sustained consecutive failures.
|
||||
|
||||
Queries ingestion_runs for sources where the last N runs all failed
|
||||
within the lookback window.
|
||||
"""
|
||||
rows = await pool.fetch(
|
||||
"""WITH recent_runs AS (
|
||||
SELECT source_id, status,
|
||||
ROW_NUMBER() OVER (PARTITION BY source_id ORDER BY started_at DESC) AS rn
|
||||
FROM ingestion_runs
|
||||
WHERE started_at >= NOW() - INTERVAL '1 hour' * $1
|
||||
),
|
||||
failure_streaks AS (
|
||||
SELECT source_id,
|
||||
COUNT(*) FILTER (WHERE status = 'failed') AS consecutive_failures,
|
||||
COUNT(*) AS total_runs
|
||||
FROM recent_runs
|
||||
WHERE rn <= $2
|
||||
GROUP BY source_id
|
||||
HAVING COUNT(*) FILTER (WHERE status = 'failed') = COUNT(*)
|
||||
AND COUNT(*) >= $2
|
||||
)
|
||||
SELECT fs.source_id, fs.consecutive_failures,
|
||||
s.source_type, s.source_name, c.ticker
|
||||
FROM failure_streaks fs
|
||||
JOIN sources s ON s.id = fs.source_id
|
||||
JOIN companies c ON c.id = s.company_id""",
|
||||
config.source_failure_window_hours,
|
||||
config.source_failure_threshold,
|
||||
)
|
||||
|
||||
alerts = []
|
||||
for row in rows:
|
||||
alerts.append(Alert(
|
||||
rule="source_failures",
|
||||
severity="warning",
|
||||
summary=(
|
||||
f"Source {row['source_name']} ({row['source_type']}) for "
|
||||
f"{row['ticker']} has {row['consecutive_failures']} consecutive failures"
|
||||
),
|
||||
details={
|
||||
"key": str(row["source_id"]),
|
||||
"source_id": str(row["source_id"]),
|
||||
"source_type": row["source_type"],
|
||||
"source_name": row["source_name"],
|
||||
"ticker": row["ticker"],
|
||||
"consecutive_failures": row["consecutive_failures"],
|
||||
},
|
||||
))
|
||||
return alerts
|
||||
|
||||
|
||||
async def check_schema_failure_spike(
|
||||
pool: asyncpg.Pool,
|
||||
config: AlertingConfig,
|
||||
) -> list[Alert]:
|
||||
"""Check if extraction schema validation failure rate exceeds threshold.
|
||||
|
||||
Queries model_performance_metrics for the recent window and computes
|
||||
the failure rate.
|
||||
"""
|
||||
row = await pool.fetchrow(
|
||||
"""SELECT
|
||||
COUNT(*) AS total,
|
||||
COUNT(*) FILTER (WHERE NOT success) AS failed
|
||||
FROM model_performance_metrics
|
||||
WHERE recorded_at >= NOW() - INTERVAL '1 hour' * $1""",
|
||||
config.schema_failure_window_hours,
|
||||
)
|
||||
|
||||
if not row or row["total"] == 0:
|
||||
return []
|
||||
|
||||
total = row["total"]
|
||||
failed = row["failed"]
|
||||
failure_rate = failed / total
|
||||
|
||||
if failure_rate >= config.schema_failure_rate_threshold:
|
||||
return [Alert(
|
||||
rule="schema_failure_spike",
|
||||
severity="critical" if failure_rate >= 0.5 else "warning",
|
||||
summary=(
|
||||
f"Extraction schema failure rate is {failure_rate:.1%} "
|
||||
f"({failed}/{total}) in the last {config.schema_failure_window_hours}h"
|
||||
),
|
||||
details={
|
||||
"key": "global",
|
||||
"total_extractions": total,
|
||||
"failed_extractions": failed,
|
||||
"failure_rate": round(failure_rate, 4),
|
||||
"threshold": config.schema_failure_rate_threshold,
|
||||
"window_hours": config.schema_failure_window_hours,
|
||||
},
|
||||
)]
|
||||
return []
|
||||
|
||||
|
||||
async def check_analytical_lag(
|
||||
pool: asyncpg.Pool,
|
||||
config: AlertingConfig,
|
||||
) -> list[Alert]:
|
||||
"""Check if lake publication is lagging beyond threshold.
|
||||
|
||||
Looks at the audit_events table for the most recent successful
|
||||
lake_publish events per table, and alerts if any are stale.
|
||||
"""
|
||||
rows = await pool.fetch(
|
||||
"""SELECT
|
||||
details->>'table_name' AS table_name,
|
||||
MAX(created_at) AS last_publish
|
||||
FROM audit_events
|
||||
WHERE event_type = 'lake_publish'
|
||||
AND details->>'status' = 'success'
|
||||
AND details->>'table_name' IS NOT NULL
|
||||
GROUP BY details->>'table_name'
|
||||
HAVING MAX(created_at) < NOW() - INTERVAL '1 minute' * $1""",
|
||||
config.lake_lag_threshold_minutes,
|
||||
)
|
||||
|
||||
alerts = []
|
||||
now = datetime.now(timezone.utc)
|
||||
for row in rows:
|
||||
table_name = row["table_name"]
|
||||
last_publish = row["last_publish"]
|
||||
if last_publish.tzinfo is None:
|
||||
last_publish = last_publish.replace(tzinfo=timezone.utc)
|
||||
lag_minutes = (now - last_publish).total_seconds() / 60
|
||||
|
||||
alerts.append(Alert(
|
||||
rule="analytical_lag",
|
||||
severity="warning",
|
||||
summary=(
|
||||
f"Lake table '{table_name}' last published {lag_minutes:.0f}m ago "
|
||||
f"(threshold: {config.lake_lag_threshold_minutes}m)"
|
||||
),
|
||||
details={
|
||||
"key": table_name,
|
||||
"table_name": table_name,
|
||||
"last_publish": last_publish.isoformat(),
|
||||
"lag_minutes": round(lag_minutes, 1),
|
||||
"threshold_minutes": config.lake_lag_threshold_minutes,
|
||||
},
|
||||
))
|
||||
return alerts
|
||||
|
||||
|
||||
async def check_broker_issues(
|
||||
pool: asyncpg.Pool,
|
||||
config: AlertingConfig,
|
||||
) -> list[Alert]:
|
||||
"""Check for consecutive broker submission errors.
|
||||
|
||||
Queries order_events for recent broker-level errors (rejections,
|
||||
timeouts, connection failures) within the lookback window.
|
||||
"""
|
||||
rows = await pool.fetch(
|
||||
"""WITH recent_events AS (
|
||||
SELECT order_id, event_type, created_at,
|
||||
ROW_NUMBER() OVER (ORDER BY created_at DESC) AS rn
|
||||
FROM order_events
|
||||
WHERE created_at >= NOW() - INTERVAL '1 hour' * $1
|
||||
AND event_type IN ('broker_error', 'broker_timeout', 'connection_failed')
|
||||
)
|
||||
SELECT COUNT(*) AS error_count
|
||||
FROM recent_events
|
||||
WHERE rn <= $2""",
|
||||
config.broker_error_window_hours,
|
||||
config.broker_error_threshold,
|
||||
)
|
||||
|
||||
if not rows:
|
||||
return []
|
||||
|
||||
error_count = rows[0]["error_count"]
|
||||
if error_count >= config.broker_error_threshold:
|
||||
return [Alert(
|
||||
rule="broker_issues",
|
||||
severity="critical",
|
||||
summary=(
|
||||
f"{error_count} broker errors in the last "
|
||||
f"{config.broker_error_window_hours}h"
|
||||
),
|
||||
details={
|
||||
"key": "global",
|
||||
"error_count": error_count,
|
||||
"threshold": config.broker_error_threshold,
|
||||
"window_hours": config.broker_error_window_hours,
|
||||
},
|
||||
)]
|
||||
return []
|
||||
|
||||
|
||||
async def evaluate_alerts(
|
||||
pool: asyncpg.Pool,
|
||||
config: AlertingConfig,
|
||||
state: AlertState,
|
||||
) -> list[Alert]:
|
||||
"""Run all alert rules and return newly fired alerts.
|
||||
|
||||
Updates AlertState to track firing/resolved transitions and emits
|
||||
structured log events and Prometheus metrics for each transition.
|
||||
"""
|
||||
all_alerts: list[Alert] = []
|
||||
|
||||
with ALERT_CHECK_DURATION.time():
|
||||
# Collect alerts from all rules
|
||||
try:
|
||||
all_alerts.extend(await check_source_failures(pool, config))
|
||||
except Exception:
|
||||
logger.exception("Error checking source failures")
|
||||
|
||||
try:
|
||||
all_alerts.extend(await check_schema_failure_spike(pool, config))
|
||||
except Exception:
|
||||
logger.exception("Error checking schema failure spike")
|
||||
|
||||
try:
|
||||
all_alerts.extend(await check_analytical_lag(pool, config))
|
||||
except Exception:
|
||||
logger.exception("Error checking analytical lag")
|
||||
|
||||
try:
|
||||
all_alerts.extend(await check_broker_issues(pool, config))
|
||||
except Exception:
|
||||
logger.exception("Error checking broker issues")
|
||||
|
||||
# Track which rule+key combos are currently firing
|
||||
current_keys: set[str] = set()
|
||||
newly_fired: list[Alert] = []
|
||||
|
||||
for alert in all_alerts:
|
||||
key = f"{alert.rule}:{alert.details.get('key', '')}"
|
||||
current_keys.add(key)
|
||||
|
||||
if state.fire(alert):
|
||||
# New alert firing
|
||||
ALERTS_FIRED.labels(rule=alert.rule, severity=alert.severity).inc()
|
||||
ALERT_ACTIVE.labels(rule=alert.rule).set(1)
|
||||
newly_fired.append(alert)
|
||||
logger.warning(
|
||||
"ALERT FIRING: [%s] %s",
|
||||
alert.rule,
|
||||
alert.summary,
|
||||
extra={
|
||||
"alert_rule": alert.rule,
|
||||
"alert_severity": alert.severity,
|
||||
"alert_details": alert.details,
|
||||
},
|
||||
)
|
||||
|
||||
# Check for resolved alerts
|
||||
resolved_keys = set(state.active.keys()) - current_keys
|
||||
for key in resolved_keys:
|
||||
rule = key.split(":")[0]
|
||||
detail_key = key[len(rule) + 1:]
|
||||
if state.resolve(rule, detail_key):
|
||||
ALERTS_RESOLVED.labels(rule=rule).inc()
|
||||
# Only set gauge to 0 if no more alerts for this rule
|
||||
still_firing = any(k.startswith(f"{rule}:") for k in state.active)
|
||||
if not still_firing:
|
||||
ALERT_ACTIVE.labels(rule=rule).set(0)
|
||||
logger.info(
|
||||
"ALERT RESOLVED: [%s] key=%s",
|
||||
rule,
|
||||
detail_key,
|
||||
)
|
||||
|
||||
return newly_fired
|
||||
@@ -0,0 +1,493 @@
|
||||
"""Execution audit trail - records every step from recommendation to market outcome.
|
||||
|
||||
Writes structured audit events to the audit_events table so the full
|
||||
decision chain is traceable: recommendation → risk evaluation → order
|
||||
submission → broker response → fill/rejection/cancellation.
|
||||
|
||||
Each event captures the entity type, entity ID, event type, actor,
|
||||
and a JSONB data payload with stage-specific details.
|
||||
|
||||
Requirements: 8.3, 11.3
|
||||
Design: Section 4.9 (Broker Adapter), Section 6.1 (PostgreSQL audit_events)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import asyncpg
|
||||
|
||||
logger = logging.getLogger("audit")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Event type constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Recommendation stage
|
||||
AUDIT_RECOMMENDATION_GENERATED = "recommendation.generated"
|
||||
AUDIT_RECOMMENDATION_SUPPRESSED = "recommendation.suppressed"
|
||||
|
||||
# Risk evaluation stage
|
||||
AUDIT_RISK_EVALUATED = "risk.evaluated"
|
||||
AUDIT_RISK_REJECTED = "risk.rejected"
|
||||
|
||||
# Order lifecycle
|
||||
AUDIT_ORDER_SUBMITTED = "order.submitted"
|
||||
AUDIT_ORDER_ACCEPTED = "order.accepted"
|
||||
AUDIT_ORDER_FILLED = "order.filled"
|
||||
AUDIT_ORDER_REJECTED = "order.rejected"
|
||||
AUDIT_ORDER_CANCELLED = "order.cancelled"
|
||||
AUDIT_ORDER_DUPLICATE = "order.duplicate_prevented"
|
||||
|
||||
# Position changes
|
||||
AUDIT_POSITION_OPENED = "position.opened"
|
||||
AUDIT_POSITION_CLOSED = "position.closed"
|
||||
AUDIT_POSITION_UPDATED = "position.updated"
|
||||
|
||||
# Trading mode changes
|
||||
AUDIT_TRADING_MODE_CHANGED = "trading.mode_changed"
|
||||
|
||||
# Operator approval workflow
|
||||
AUDIT_APPROVAL_REQUESTED = "approval.requested"
|
||||
AUDIT_APPROVAL_APPROVED = "approval.approved"
|
||||
AUDIT_APPROVAL_REJECTED = "approval.rejected"
|
||||
AUDIT_APPROVAL_EXPIRED = "approval.expired"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core audit writer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_INSERT_AUDIT_EVENT = """
|
||||
INSERT INTO audit_events (id, event_type, entity_type, entity_id, actor, data, created_at)
|
||||
VALUES ($1::uuid, $2, $3, $4::uuid, $5, $6::jsonb, $7)
|
||||
"""
|
||||
|
||||
|
||||
async def record_audit_event(
|
||||
pool: asyncpg.Pool,
|
||||
event_type: str,
|
||||
entity_type: str,
|
||||
entity_id: str,
|
||||
data: dict[str, Any],
|
||||
actor: str = "system",
|
||||
timestamp: datetime | None = None,
|
||||
) -> str:
|
||||
"""Write a single audit event to PostgreSQL.
|
||||
|
||||
Returns the audit event UUID.
|
||||
"""
|
||||
event_id = str(uuid.uuid4())
|
||||
ts = timestamp or datetime.now(timezone.utc)
|
||||
|
||||
try:
|
||||
await pool.execute(
|
||||
_INSERT_AUDIT_EVENT,
|
||||
event_id,
|
||||
event_type,
|
||||
entity_type,
|
||||
entity_id,
|
||||
actor,
|
||||
json.dumps(data, default=str),
|
||||
ts,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to write audit event %s for %s/%s",
|
||||
event_type, entity_type, entity_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return ""
|
||||
|
||||
return event_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Convenience helpers for each execution stage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def audit_recommendation_generated(
|
||||
pool: asyncpg.Pool,
|
||||
recommendation_id: str,
|
||||
ticker: str,
|
||||
action: str,
|
||||
mode: str,
|
||||
confidence: float,
|
||||
evidence_count: int,
|
||||
suppressed: bool = False,
|
||||
) -> str:
|
||||
"""Record that a recommendation was generated."""
|
||||
event_type = AUDIT_RECOMMENDATION_SUPPRESSED if suppressed else AUDIT_RECOMMENDATION_GENERATED
|
||||
return await record_audit_event(
|
||||
pool,
|
||||
event_type=event_type,
|
||||
entity_type="recommendation",
|
||||
entity_id=recommendation_id,
|
||||
data={
|
||||
"ticker": ticker,
|
||||
"action": action,
|
||||
"mode": mode,
|
||||
"confidence": confidence,
|
||||
"evidence_count": evidence_count,
|
||||
"suppressed": suppressed,
|
||||
},
|
||||
actor="recommendation_worker",
|
||||
)
|
||||
|
||||
|
||||
async def audit_risk_evaluated(
|
||||
pool: asyncpg.Pool,
|
||||
evaluation_id: str,
|
||||
recommendation_id: str | None,
|
||||
ticker: str,
|
||||
eligible: bool,
|
||||
allowed_mode: str,
|
||||
rejection_reasons: list[str],
|
||||
check_count: int,
|
||||
) -> str:
|
||||
"""Record a risk evaluation result."""
|
||||
event_type = AUDIT_RISK_REJECTED if not eligible else AUDIT_RISK_EVALUATED
|
||||
return await record_audit_event(
|
||||
pool,
|
||||
event_type=event_type,
|
||||
entity_type="risk_evaluation",
|
||||
entity_id=evaluation_id,
|
||||
data={
|
||||
"recommendation_id": recommendation_id,
|
||||
"ticker": ticker,
|
||||
"eligible": eligible,
|
||||
"allowed_mode": allowed_mode,
|
||||
"rejection_reasons": rejection_reasons,
|
||||
"check_count": check_count,
|
||||
},
|
||||
actor="risk_engine",
|
||||
)
|
||||
|
||||
|
||||
async def audit_order_submitted(
|
||||
pool: asyncpg.Pool,
|
||||
order_id: str,
|
||||
ticker: str,
|
||||
side: str,
|
||||
quantity: float,
|
||||
order_type: str,
|
||||
idempotency_key: str,
|
||||
recommendation_id: str | None = None,
|
||||
evaluation_id: str | None = None,
|
||||
) -> str:
|
||||
"""Record that an order was submitted to the broker."""
|
||||
return await record_audit_event(
|
||||
pool,
|
||||
event_type=AUDIT_ORDER_SUBMITTED,
|
||||
entity_type="order",
|
||||
entity_id=order_id,
|
||||
data={
|
||||
"ticker": ticker,
|
||||
"side": side,
|
||||
"quantity": quantity,
|
||||
"order_type": order_type,
|
||||
"idempotency_key": idempotency_key,
|
||||
"recommendation_id": recommendation_id,
|
||||
"evaluation_id": evaluation_id,
|
||||
},
|
||||
actor="broker_service",
|
||||
)
|
||||
|
||||
|
||||
async def audit_order_filled(
|
||||
pool: asyncpg.Pool,
|
||||
order_id: str,
|
||||
ticker: str,
|
||||
side: str,
|
||||
fill_quantity: float,
|
||||
fill_price: float | None,
|
||||
broker_order_id: str,
|
||||
) -> str:
|
||||
"""Record that an order was filled by the broker."""
|
||||
return await record_audit_event(
|
||||
pool,
|
||||
event_type=AUDIT_ORDER_FILLED,
|
||||
entity_type="order",
|
||||
entity_id=order_id,
|
||||
data={
|
||||
"ticker": ticker,
|
||||
"side": side,
|
||||
"fill_quantity": fill_quantity,
|
||||
"fill_price": fill_price,
|
||||
"broker_order_id": broker_order_id,
|
||||
},
|
||||
actor="broker_service",
|
||||
)
|
||||
|
||||
|
||||
async def audit_order_rejected(
|
||||
pool: asyncpg.Pool,
|
||||
order_id: str,
|
||||
ticker: str,
|
||||
reason: str,
|
||||
source: str = "broker",
|
||||
) -> str:
|
||||
"""Record that an order was rejected (by risk engine or broker)."""
|
||||
return await record_audit_event(
|
||||
pool,
|
||||
event_type=AUDIT_ORDER_REJECTED,
|
||||
entity_type="order",
|
||||
entity_id=order_id,
|
||||
data={
|
||||
"ticker": ticker,
|
||||
"reason": reason,
|
||||
"rejection_source": source,
|
||||
},
|
||||
actor="broker_service",
|
||||
)
|
||||
|
||||
|
||||
async def audit_order_cancelled(
|
||||
pool: asyncpg.Pool,
|
||||
order_id: str,
|
||||
ticker: str,
|
||||
broker_order_id: str,
|
||||
) -> str:
|
||||
"""Record that an order was cancelled."""
|
||||
return await record_audit_event(
|
||||
pool,
|
||||
event_type=AUDIT_ORDER_CANCELLED,
|
||||
entity_type="order",
|
||||
entity_id=order_id,
|
||||
data={
|
||||
"ticker": ticker,
|
||||
"broker_order_id": broker_order_id,
|
||||
},
|
||||
actor="broker_service",
|
||||
)
|
||||
|
||||
|
||||
async def audit_duplicate_prevented(
|
||||
pool: asyncpg.Pool,
|
||||
order_id: str,
|
||||
ticker: str,
|
||||
idempotency_key: str,
|
||||
detected_via: str,
|
||||
) -> str:
|
||||
"""Record that a duplicate order was prevented."""
|
||||
return await record_audit_event(
|
||||
pool,
|
||||
event_type=AUDIT_ORDER_DUPLICATE,
|
||||
entity_type="order",
|
||||
entity_id=order_id,
|
||||
data={
|
||||
"ticker": ticker,
|
||||
"idempotency_key": idempotency_key,
|
||||
"detected_via": detected_via,
|
||||
},
|
||||
actor="broker_service",
|
||||
)
|
||||
|
||||
|
||||
async def audit_position_change(
|
||||
pool: asyncpg.Pool,
|
||||
order_id: str,
|
||||
ticker: str,
|
||||
side: str,
|
||||
quantity_before: float,
|
||||
quantity_after: float,
|
||||
avg_entry_before: float,
|
||||
avg_entry_after: float,
|
||||
) -> str:
|
||||
"""Record a position change resulting from a fill."""
|
||||
if quantity_before == 0 and quantity_after > 0:
|
||||
event_type = AUDIT_POSITION_OPENED
|
||||
elif quantity_after == 0:
|
||||
event_type = AUDIT_POSITION_CLOSED
|
||||
else:
|
||||
event_type = AUDIT_POSITION_UPDATED
|
||||
|
||||
return await record_audit_event(
|
||||
pool,
|
||||
event_type=event_type,
|
||||
entity_type="position",
|
||||
entity_id=order_id,
|
||||
data={
|
||||
"ticker": ticker,
|
||||
"side": side,
|
||||
"quantity_before": quantity_before,
|
||||
"quantity_after": quantity_after,
|
||||
"avg_entry_before": avg_entry_before,
|
||||
"avg_entry_after": avg_entry_after,
|
||||
},
|
||||
actor="broker_service",
|
||||
)
|
||||
|
||||
|
||||
async def audit_approval_requested(
|
||||
pool: asyncpg.Pool,
|
||||
approval_id: str,
|
||||
ticker: str,
|
||||
side: str,
|
||||
quantity: float,
|
||||
estimated_value: float,
|
||||
recommendation_id: str | None = None,
|
||||
expires_at: str | None = None,
|
||||
) -> str:
|
||||
"""Record that an operator approval was requested for a live order."""
|
||||
return await record_audit_event(
|
||||
pool,
|
||||
event_type=AUDIT_APPROVAL_REQUESTED,
|
||||
entity_type="approval",
|
||||
entity_id=approval_id,
|
||||
data={
|
||||
"ticker": ticker,
|
||||
"side": side,
|
||||
"quantity": quantity,
|
||||
"estimated_value": estimated_value,
|
||||
"recommendation_id": recommendation_id,
|
||||
"expires_at": expires_at,
|
||||
},
|
||||
actor="broker_service",
|
||||
)
|
||||
|
||||
|
||||
async def audit_approval_reviewed(
|
||||
pool: asyncpg.Pool,
|
||||
approval_id: str,
|
||||
ticker: str,
|
||||
approved: bool,
|
||||
reviewed_by: str = "operator",
|
||||
review_note: str = "",
|
||||
) -> str:
|
||||
"""Record that an operator reviewed an approval request."""
|
||||
event_type = AUDIT_APPROVAL_APPROVED if approved else AUDIT_APPROVAL_REJECTED
|
||||
return await record_audit_event(
|
||||
pool,
|
||||
event_type=event_type,
|
||||
entity_type="approval",
|
||||
entity_id=approval_id,
|
||||
data={
|
||||
"ticker": ticker,
|
||||
"approved": approved,
|
||||
"reviewed_by": reviewed_by,
|
||||
"review_note": review_note,
|
||||
},
|
||||
actor=reviewed_by,
|
||||
)
|
||||
|
||||
|
||||
async def audit_approval_expired(
|
||||
pool: asyncpg.Pool,
|
||||
approval_id: str,
|
||||
ticker: str,
|
||||
) -> str:
|
||||
"""Record that an approval request expired without review."""
|
||||
return await record_audit_event(
|
||||
pool,
|
||||
event_type=AUDIT_APPROVAL_EXPIRED,
|
||||
entity_type="approval",
|
||||
entity_id=approval_id,
|
||||
data={"ticker": ticker},
|
||||
actor="system",
|
||||
)
|
||||
|
||||
|
||||
async def audit_trading_mode_changed(
|
||||
pool: asyncpg.Pool,
|
||||
config_id: str,
|
||||
old_mode: str,
|
||||
new_mode: str,
|
||||
actor: str = "operator",
|
||||
) -> str:
|
||||
"""Record a trading mode change."""
|
||||
return await record_audit_event(
|
||||
pool,
|
||||
event_type=AUDIT_TRADING_MODE_CHANGED,
|
||||
entity_type="risk_config",
|
||||
entity_id=config_id,
|
||||
data={
|
||||
"old_mode": old_mode,
|
||||
"new_mode": new_mode,
|
||||
},
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Query helpers for audit trail retrieval (Requirement 11.3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_FETCH_AUDIT_TRAIL_FOR_ORDER = """
|
||||
SELECT id, event_type, entity_type, entity_id, actor, data, created_at
|
||||
FROM audit_events
|
||||
WHERE entity_id = $1::uuid
|
||||
OR data->>'recommendation_id' = $2
|
||||
OR data->>'order_id' = $2
|
||||
ORDER BY created_at ASC
|
||||
"""
|
||||
|
||||
_FETCH_AUDIT_TRAIL_BY_ENTITY = """
|
||||
SELECT id, event_type, entity_type, entity_id, actor, data, created_at
|
||||
FROM audit_events
|
||||
WHERE entity_type = $1 AND entity_id = $2::uuid
|
||||
ORDER BY created_at ASC
|
||||
"""
|
||||
|
||||
_FETCH_FULL_EXECUTION_TRAIL = """
|
||||
SELECT id, event_type, entity_type, entity_id, actor, data, created_at
|
||||
FROM audit_events
|
||||
WHERE entity_id = $1::uuid
|
||||
OR entity_id IN (
|
||||
SELECT entity_id FROM audit_events
|
||||
WHERE data->>'recommendation_id' = $2
|
||||
)
|
||||
ORDER BY created_at ASC
|
||||
"""
|
||||
|
||||
|
||||
async def get_order_audit_trail(
|
||||
pool: asyncpg.Pool,
|
||||
order_id: str,
|
||||
recommendation_id: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch the full audit trail for an order, including related recommendation and risk events.
|
||||
|
||||
Returns events ordered chronologically so the full decision chain
|
||||
is visible: recommendation → risk → order → fill/reject.
|
||||
"""
|
||||
ref_id = recommendation_id or order_id
|
||||
rows = await pool.fetch(_FETCH_AUDIT_TRAIL_FOR_ORDER, order_id, ref_id)
|
||||
return [
|
||||
{
|
||||
"id": str(row["id"]),
|
||||
"event_type": row["event_type"],
|
||||
"entity_type": row["entity_type"],
|
||||
"entity_id": str(row["entity_id"]),
|
||||
"actor": row["actor"],
|
||||
"data": row["data"] if isinstance(row["data"], dict) else json.loads(row["data"]),
|
||||
"created_at": row["created_at"].isoformat() if row["created_at"] else None,
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
|
||||
async def get_entity_audit_trail(
|
||||
pool: asyncpg.Pool,
|
||||
entity_type: str,
|
||||
entity_id: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch all audit events for a specific entity."""
|
||||
rows = await pool.fetch(_FETCH_AUDIT_TRAIL_BY_ENTITY, entity_type, entity_id)
|
||||
return [
|
||||
{
|
||||
"id": str(row["id"]),
|
||||
"event_type": row["event_type"],
|
||||
"entity_type": row["entity_type"],
|
||||
"entity_id": str(row["entity_id"]),
|
||||
"actor": row["actor"],
|
||||
"data": row["data"] if isinstance(row["data"], dict) else json.loads(row["data"]),
|
||||
"created_at": row["created_at"].isoformat() if row["created_at"] else None,
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
@@ -43,6 +43,10 @@ class OllamaConfig:
|
||||
base_url: str = "http://localhost:11434"
|
||||
model: str = "llama3.1:8b"
|
||||
timeout: int = 120
|
||||
max_retries: int = 2
|
||||
retry_base_delay: float = 1.0
|
||||
retry_max_delay: float = 10.0
|
||||
retry_backoff_multiplier: float = 2.0
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -51,16 +55,82 @@ class TrinoConfig:
|
||||
port: int = 8080
|
||||
catalog: str = "lakehouse"
|
||||
schema: str = "stonks"
|
||||
iceberg_catalog: str = "iceberg"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarketDataConfig:
|
||||
api_key: str = ""
|
||||
base_url: str = "https://api.polygon.io"
|
||||
provider: str = "polygon"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BrokerConfig:
|
||||
mode: str = "paper" # paper | live
|
||||
provider: str = "alpaca"
|
||||
api_key: Optional[str] = None
|
||||
api_secret: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetentionConfig:
|
||||
"""Default retention periods (days) per bucket class.
|
||||
|
||||
These can be overridden per-bucket via the retention_policies DB table.
|
||||
The cleanup_interval_hours controls how often the retention worker runs.
|
||||
"""
|
||||
raw_market_days: int = 90
|
||||
raw_news_days: int = 180
|
||||
raw_filings_days: int = 365
|
||||
normalized_days: int = 180
|
||||
llm_prompts_days: int = 365
|
||||
llm_results_days: int = 365
|
||||
lakehouse_days: int = 730
|
||||
audit_days: int = 730
|
||||
cleanup_interval_hours: int = 24
|
||||
batch_size: int = 1000
|
||||
|
||||
|
||||
# Map bucket names to RetentionConfig field names
|
||||
BUCKET_RETENTION_FIELDS: dict[str, str] = {
|
||||
"stonks-raw-market": "raw_market_days",
|
||||
"stonks-raw-news": "raw_news_days",
|
||||
"stonks-raw-filings": "raw_filings_days",
|
||||
"stonks-normalized": "normalized_days",
|
||||
"stonks-llm-prompts": "llm_prompts_days",
|
||||
"stonks-llm-results": "llm_results_days",
|
||||
"stonks-lakehouse": "lakehouse_days",
|
||||
"stonks-audit": "audit_days",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class AlertingConfig:
|
||||
"""Thresholds for operational alerting rules.
|
||||
|
||||
Requirements: 12.3
|
||||
"""
|
||||
# Source failure alerting
|
||||
source_failure_threshold: int = 3 # consecutive failures before alert
|
||||
source_failure_window_hours: int = 6 # lookback window
|
||||
|
||||
# Schema/extraction failure spike
|
||||
schema_failure_rate_threshold: float = 0.3 # 30% failure rate triggers alert
|
||||
schema_failure_window_hours: int = 1
|
||||
|
||||
# Analytical (lake publication) lag
|
||||
lake_lag_threshold_minutes: int = 60 # minutes since last successful publish
|
||||
|
||||
# Broker issues
|
||||
broker_error_threshold: int = 3 # consecutive broker errors
|
||||
broker_error_window_hours: int = 1
|
||||
|
||||
# Evaluation interval
|
||||
check_interval_seconds: int = 120
|
||||
|
||||
|
||||
@dataclass
|
||||
class AppConfig:
|
||||
postgres: PostgresConfig = field(default_factory=PostgresConfig)
|
||||
@@ -68,8 +138,12 @@ class AppConfig:
|
||||
minio: MinioConfig = field(default_factory=MinioConfig)
|
||||
ollama: OllamaConfig = field(default_factory=OllamaConfig)
|
||||
trino: TrinoConfig = field(default_factory=TrinoConfig)
|
||||
market_data: MarketDataConfig = field(default_factory=MarketDataConfig)
|
||||
broker: BrokerConfig = field(default_factory=BrokerConfig)
|
||||
retention: RetentionConfig = field(default_factory=RetentionConfig)
|
||||
alerting: AlertingConfig = field(default_factory=AlertingConfig)
|
||||
log_level: str = "INFO"
|
||||
json_logs: bool = True
|
||||
|
||||
|
||||
def load_config() -> AppConfig:
|
||||
@@ -98,18 +172,52 @@ def load_config() -> AppConfig:
|
||||
base_url=os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"),
|
||||
model=os.getenv("OLLAMA_MODEL", "llama3.1:8b"),
|
||||
timeout=int(os.getenv("OLLAMA_TIMEOUT", "120")),
|
||||
max_retries=int(os.getenv("OLLAMA_MAX_RETRIES", "2")),
|
||||
retry_base_delay=float(os.getenv("OLLAMA_RETRY_BASE_DELAY", "1.0")),
|
||||
retry_max_delay=float(os.getenv("OLLAMA_RETRY_MAX_DELAY", "10.0")),
|
||||
retry_backoff_multiplier=float(os.getenv("OLLAMA_RETRY_BACKOFF_MULTIPLIER", "2.0")),
|
||||
),
|
||||
trino=TrinoConfig(
|
||||
host=os.getenv("TRINO_HOST", "localhost"),
|
||||
port=int(os.getenv("TRINO_PORT", "8080")),
|
||||
catalog=os.getenv("TRINO_CATALOG", "lakehouse"),
|
||||
schema=os.getenv("TRINO_SCHEMA", "stonks"),
|
||||
iceberg_catalog=os.getenv("TRINO_ICEBERG_CATALOG", "iceberg"),
|
||||
),
|
||||
market_data=MarketDataConfig(
|
||||
api_key=os.getenv("MARKET_DATA_API_KEY", ""),
|
||||
base_url=os.getenv("MARKET_DATA_BASE_URL", "https://api.polygon.io"),
|
||||
provider=os.getenv("MARKET_DATA_PROVIDER", "polygon"),
|
||||
),
|
||||
broker=BrokerConfig(
|
||||
mode=os.getenv("BROKER_MODE", "paper"),
|
||||
provider=os.getenv("BROKER_PROVIDER", "alpaca"),
|
||||
api_key=os.getenv("BROKER_API_KEY", None),
|
||||
api_secret=os.getenv("BROKER_API_SECRET", None),
|
||||
base_url=os.getenv("BROKER_BASE_URL", None),
|
||||
),
|
||||
retention=RetentionConfig(
|
||||
raw_market_days=int(os.getenv("RETENTION_RAW_MARKET_DAYS", "90")),
|
||||
raw_news_days=int(os.getenv("RETENTION_RAW_NEWS_DAYS", "180")),
|
||||
raw_filings_days=int(os.getenv("RETENTION_RAW_FILINGS_DAYS", "365")),
|
||||
normalized_days=int(os.getenv("RETENTION_NORMALIZED_DAYS", "180")),
|
||||
llm_prompts_days=int(os.getenv("RETENTION_LLM_PROMPTS_DAYS", "365")),
|
||||
llm_results_days=int(os.getenv("RETENTION_LLM_RESULTS_DAYS", "365")),
|
||||
lakehouse_days=int(os.getenv("RETENTION_LAKEHOUSE_DAYS", "730")),
|
||||
audit_days=int(os.getenv("RETENTION_AUDIT_DAYS", "730")),
|
||||
cleanup_interval_hours=int(os.getenv("RETENTION_CLEANUP_INTERVAL_HOURS", "24")),
|
||||
batch_size=int(os.getenv("RETENTION_BATCH_SIZE", "1000")),
|
||||
),
|
||||
alerting=AlertingConfig(
|
||||
source_failure_threshold=int(os.getenv("ALERT_SOURCE_FAILURE_THRESHOLD", "3")),
|
||||
source_failure_window_hours=int(os.getenv("ALERT_SOURCE_FAILURE_WINDOW_HOURS", "6")),
|
||||
schema_failure_rate_threshold=float(os.getenv("ALERT_SCHEMA_FAILURE_RATE_THRESHOLD", "0.3")),
|
||||
schema_failure_window_hours=int(os.getenv("ALERT_SCHEMA_FAILURE_WINDOW_HOURS", "1")),
|
||||
lake_lag_threshold_minutes=int(os.getenv("ALERT_LAKE_LAG_THRESHOLD_MINUTES", "60")),
|
||||
broker_error_threshold=int(os.getenv("ALERT_BROKER_ERROR_THRESHOLD", "3")),
|
||||
broker_error_window_hours=int(os.getenv("ALERT_BROKER_ERROR_WINDOW_HOURS", "1")),
|
||||
check_interval_seconds=int(os.getenv("ALERT_CHECK_INTERVAL_SECONDS", "120")),
|
||||
),
|
||||
log_level=os.getenv("LOG_LEVEL", "INFO"),
|
||||
json_logs=os.getenv("JSON_LOGS", "true").lower() == "true",
|
||||
)
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
"""Canonical URL normalization and content hashing utilities.
|
||||
|
||||
Provides consistent URL canonicalization and SHA-256 content hashing
|
||||
across all ingestion adapters and pipeline stages.
|
||||
|
||||
Requirements: 3.2, 3.3
|
||||
"""
|
||||
import hashlib
|
||||
from urllib.parse import parse_qsl, urlencode, urlparse
|
||||
|
||||
|
||||
def normalize_url(url: str) -> str:
|
||||
"""Canonical URL normalization.
|
||||
|
||||
- Lowercases scheme and host
|
||||
- Strips fragments
|
||||
- Strips trailing slashes from path (preserves root "/")
|
||||
- Strips default ports (80, 443)
|
||||
- Sorts query parameters for deterministic comparison
|
||||
- Defaults scheme to https if missing
|
||||
"""
|
||||
parsed = urlparse(url)
|
||||
scheme = (parsed.scheme or "https").lower()
|
||||
netloc = (parsed.hostname or "").lower()
|
||||
if parsed.port and parsed.port not in (80, 443):
|
||||
netloc = f"{netloc}:{parsed.port}"
|
||||
path = parsed.path.rstrip("/") or "/"
|
||||
# Sort query params for deterministic ordering
|
||||
query = urlencode(sorted(parse_qsl(parsed.query)))
|
||||
normalized = f"{scheme}://{netloc}{path}"
|
||||
if query:
|
||||
normalized = f"{normalized}?{query}"
|
||||
return normalized
|
||||
|
||||
|
||||
def content_hash(data: bytes) -> str:
|
||||
"""Compute a stable SHA-256 hex digest for raw content bytes."""
|
||||
return hashlib.sha256(data).hexdigest()
|
||||
|
||||
|
||||
def content_hash_str(text: str, encoding: str = "utf-8") -> str:
|
||||
"""Compute a stable SHA-256 hex digest for a text string."""
|
||||
return hashlib.sha256(text.encode(encoding)).hexdigest()
|
||||
@@ -0,0 +1,134 @@
|
||||
"""Dead-letter queue (DLQ) support and replay tooling.
|
||||
|
||||
When a worker fails to process a job after exhausting retries, the job
|
||||
is pushed to a per-queue dead-letter list in Redis. Each DLQ entry
|
||||
wraps the original payload with failure metadata (error message,
|
||||
timestamp, attempt count) so operators can inspect and replay later.
|
||||
|
||||
Replay moves items from the DLQ back to the source queue for
|
||||
reprocessing.
|
||||
|
||||
Requirements: 12.1 (observability), design section 8 (data flows)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from services.shared.redis_keys import dlq_key, queue_key
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default max attempts before a job is dead-lettered
|
||||
DEFAULT_MAX_ATTEMPTS = 3
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def wrap_dlq_entry(
|
||||
original_payload: dict[str, Any],
|
||||
queue_name: str,
|
||||
error: str,
|
||||
attempt: int = 1,
|
||||
worker: str = "",
|
||||
) -> dict[str, Any]:
|
||||
"""Wrap an original job payload with DLQ metadata."""
|
||||
return {
|
||||
"original_payload": original_payload,
|
||||
"queue": queue_name,
|
||||
"error": error,
|
||||
"attempt": attempt,
|
||||
"worker": worker,
|
||||
"dead_lettered_at": _now_iso(),
|
||||
}
|
||||
|
||||
|
||||
async def send_to_dlq(
|
||||
rds: aioredis.Redis,
|
||||
queue_name: str,
|
||||
original_payload: dict[str, Any],
|
||||
error: str,
|
||||
attempt: int = 1,
|
||||
worker: str = "",
|
||||
) -> None:
|
||||
"""Push a failed job to the dead-letter queue for *queue_name*."""
|
||||
entry = wrap_dlq_entry(original_payload, queue_name, error, attempt, worker)
|
||||
await rds.rpush(dlq_key(queue_name), json.dumps(entry, default=str))
|
||||
logger.warning(
|
||||
"Dead-lettered job on %s after %d attempts: %s",
|
||||
queue_name, attempt, error,
|
||||
extra={"queue": queue_name, "attempt": attempt},
|
||||
)
|
||||
|
||||
|
||||
async def dlq_length(rds: aioredis.Redis, queue_name: str) -> int:
|
||||
"""Return the number of items in the DLQ for *queue_name*."""
|
||||
return await rds.llen(dlq_key(queue_name))
|
||||
|
||||
|
||||
async def peek_dlq(
|
||||
rds: aioredis.Redis,
|
||||
queue_name: str,
|
||||
start: int = 0,
|
||||
count: int = 10,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return DLQ entries without removing them (for inspection)."""
|
||||
raw_items = await rds.lrange(dlq_key(queue_name), start, start + count - 1)
|
||||
return [json.loads(item) for item in raw_items]
|
||||
|
||||
|
||||
async def replay_one(rds: aioredis.Redis, queue_name: str) -> dict[str, Any] | None:
|
||||
"""Pop the oldest DLQ entry and re-enqueue its original payload.
|
||||
|
||||
Returns the replayed DLQ entry, or None if the DLQ is empty.
|
||||
"""
|
||||
raw = await rds.lpop(dlq_key(queue_name))
|
||||
if raw is None:
|
||||
return None
|
||||
entry = json.loads(raw)
|
||||
original = entry.get("original_payload", entry)
|
||||
await rds.rpush(queue_key(queue_name), json.dumps(original, default=str))
|
||||
logger.info("Replayed 1 job from DLQ back to %s", queue_name)
|
||||
return entry
|
||||
|
||||
|
||||
async def replay_all(rds: aioredis.Redis, queue_name: str) -> int:
|
||||
"""Replay every item in the DLQ back to the source queue.
|
||||
|
||||
Returns the number of items replayed.
|
||||
"""
|
||||
count = 0
|
||||
while True:
|
||||
raw = await rds.lpop(dlq_key(queue_name))
|
||||
if raw is None:
|
||||
break
|
||||
entry = json.loads(raw)
|
||||
original = entry.get("original_payload", entry)
|
||||
await rds.rpush(queue_key(queue_name), json.dumps(original, default=str))
|
||||
count += 1
|
||||
if count:
|
||||
logger.info("Replayed %d jobs from DLQ back to %s", count, queue_name)
|
||||
return count
|
||||
|
||||
|
||||
async def purge_dlq(rds: aioredis.Redis, queue_name: str) -> int:
|
||||
"""Delete all items from the DLQ for *queue_name*. Returns count removed."""
|
||||
key = dlq_key(queue_name)
|
||||
length = await rds.llen(key)
|
||||
if length:
|
||||
await rds.delete(key)
|
||||
return length
|
||||
|
||||
|
||||
async def dlq_summary(rds: aioredis.Redis, queue_names: list[str]) -> dict[str, int]:
|
||||
"""Return a mapping of queue_name -> DLQ depth for the given queues."""
|
||||
result: dict[str, int] = {}
|
||||
for name in queue_names:
|
||||
result[name] = await rds.llen(dlq_key(name))
|
||||
return result
|
||||
@@ -0,0 +1,198 @@
|
||||
"""Cross-source deduplication for articles and filings.
|
||||
|
||||
Detects duplicate documents across different source types (news_api,
|
||||
filings_api, web_scrape) using a layered approach:
|
||||
|
||||
1. Redis fast-path: check content_hash and canonical_url markers for
|
||||
recently-seen documents (TTL-bounded, cheap).
|
||||
2. PostgreSQL fallback: query the documents table by canonical_url or
|
||||
content_hash for durable cross-source matching.
|
||||
|
||||
When a duplicate is detected the caller receives the existing document_id
|
||||
so it can link additional company mentions without re-inserting the document.
|
||||
|
||||
Requirements: 3.2, 3.3
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import asyncpg
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from services.shared.content import content_hash_str, normalize_url
|
||||
from services.shared.redis_keys import DEDUPE_PREFIX
|
||||
|
||||
logger = logging.getLogger("dedupe")
|
||||
|
||||
# Redis TTL for dedupe markers (24 hours)
|
||||
DEDUPE_TTL_SECONDS: int = 86400
|
||||
|
||||
|
||||
def _url_dedupe_key(canonical_url: str) -> str:
|
||||
"""Build a Redis key for URL-based deduplication."""
|
||||
return f"{DEDUPE_PREFIX}:url:{content_hash_str(canonical_url)}"
|
||||
|
||||
|
||||
def _hash_dedupe_key(content_hash: str) -> str:
|
||||
"""Build a Redis key for content-hash-based deduplication."""
|
||||
return f"{DEDUPE_PREFIX}:{content_hash}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DedupeResult:
|
||||
"""Result of a deduplication check."""
|
||||
|
||||
is_duplicate: bool
|
||||
existing_document_id: str | None = None
|
||||
match_type: str | None = None # "content_hash" | "canonical_url" | None
|
||||
|
||||
|
||||
async def check_duplicate(
|
||||
pool: asyncpg.Pool,
|
||||
rds: aioredis.Redis,
|
||||
*,
|
||||
content_hash: str,
|
||||
url: str | None = None,
|
||||
canonical_url: str | None = None,
|
||||
) -> DedupeResult:
|
||||
"""Check whether a document is a duplicate across all source types.
|
||||
|
||||
Checks in order of cost:
|
||||
1. Redis content_hash marker (fast path)
|
||||
2. Redis canonical_url marker (fast path)
|
||||
3. PostgreSQL documents.content_hash (durable)
|
||||
4. PostgreSQL documents.canonical_url (cross-source)
|
||||
|
||||
Returns a DedupeResult indicating whether the document already exists.
|
||||
"""
|
||||
# Resolve canonical URL if only raw URL provided
|
||||
resolved_canonical = canonical_url or (normalize_url(url) if url else None)
|
||||
|
||||
# --- Redis fast path: content hash ---
|
||||
if content_hash:
|
||||
redis_key = _hash_dedupe_key(content_hash)
|
||||
cached_id = await rds.get(redis_key)
|
||||
if cached_id:
|
||||
logger.debug("Dedupe hit (redis content_hash) for %s", content_hash[:16])
|
||||
return DedupeResult(
|
||||
is_duplicate=True,
|
||||
existing_document_id=str(cached_id),
|
||||
match_type="content_hash",
|
||||
)
|
||||
|
||||
# --- Redis fast path: canonical URL ---
|
||||
if resolved_canonical:
|
||||
url_key = _url_dedupe_key(resolved_canonical)
|
||||
cached_id = await rds.get(url_key)
|
||||
if cached_id:
|
||||
logger.debug("Dedupe hit (redis canonical_url) for %s", resolved_canonical[:60])
|
||||
return DedupeResult(
|
||||
is_duplicate=True,
|
||||
existing_document_id=str(cached_id),
|
||||
match_type="canonical_url",
|
||||
)
|
||||
|
||||
# --- PostgreSQL fallback: content hash ---
|
||||
if content_hash:
|
||||
row = await pool.fetchrow(
|
||||
"SELECT id FROM documents WHERE content_hash = $1 LIMIT 1",
|
||||
content_hash,
|
||||
)
|
||||
if row:
|
||||
doc_id = str(row["id"])
|
||||
# Warm the Redis cache for future checks
|
||||
await _set_dedupe_markers(rds, content_hash, resolved_canonical, doc_id)
|
||||
logger.debug("Dedupe hit (pg content_hash) for %s", content_hash[:16])
|
||||
return DedupeResult(
|
||||
is_duplicate=True,
|
||||
existing_document_id=doc_id,
|
||||
match_type="content_hash",
|
||||
)
|
||||
|
||||
# --- PostgreSQL fallback: canonical URL ---
|
||||
if resolved_canonical:
|
||||
row = await pool.fetchrow(
|
||||
"SELECT id FROM documents WHERE canonical_url = $1 LIMIT 1",
|
||||
resolved_canonical,
|
||||
)
|
||||
if row:
|
||||
doc_id = str(row["id"])
|
||||
await _set_dedupe_markers(rds, content_hash, resolved_canonical, doc_id)
|
||||
logger.debug("Dedupe hit (pg canonical_url) for %s", resolved_canonical[:60])
|
||||
return DedupeResult(
|
||||
is_duplicate=True,
|
||||
existing_document_id=doc_id,
|
||||
match_type="canonical_url",
|
||||
)
|
||||
|
||||
return DedupeResult(is_duplicate=False)
|
||||
|
||||
|
||||
async def mark_as_seen(
|
||||
rds: aioredis.Redis,
|
||||
*,
|
||||
content_hash: str,
|
||||
canonical_url: str | None,
|
||||
document_id: str,
|
||||
) -> None:
|
||||
"""Mark a newly-persisted document in Redis for fast future dedupe checks."""
|
||||
await _set_dedupe_markers(rds, content_hash, canonical_url, document_id)
|
||||
|
||||
|
||||
async def _set_dedupe_markers(
|
||||
rds: aioredis.Redis,
|
||||
content_hash: str | None,
|
||||
canonical_url: str | None,
|
||||
document_id: str,
|
||||
) -> None:
|
||||
"""Set Redis dedupe markers for both content hash and canonical URL."""
|
||||
if content_hash:
|
||||
await rds.set(
|
||||
_hash_dedupe_key(content_hash), document_id, ex=DEDUPE_TTL_SECONDS
|
||||
)
|
||||
if canonical_url:
|
||||
await rds.set(
|
||||
_url_dedupe_key(canonical_url), document_id, ex=DEDUPE_TTL_SECONDS
|
||||
)
|
||||
|
||||
|
||||
async def dedupe_items(
|
||||
pool: asyncpg.Pool,
|
||||
rds: aioredis.Redis,
|
||||
items: list[dict[str, Any]],
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||
"""Partition a list of ingestion items into new and duplicate groups.
|
||||
|
||||
Each item is expected to have at least one of:
|
||||
- content_hash: SHA-256 of the raw content
|
||||
- url / canonical_url: the document URL
|
||||
|
||||
Returns (new_items, duplicate_items).
|
||||
"""
|
||||
new_items: list[dict[str, Any]] = []
|
||||
dup_items: list[dict[str, Any]] = []
|
||||
|
||||
for item in items:
|
||||
item_hash = item.get("content_hash", "")
|
||||
item_url = item.get("url") or item.get("link")
|
||||
item_canonical = item.get("canonical_url")
|
||||
|
||||
result = await check_duplicate(
|
||||
pool,
|
||||
rds,
|
||||
content_hash=item_hash,
|
||||
url=item_url,
|
||||
canonical_url=item_canonical,
|
||||
)
|
||||
|
||||
if result.is_duplicate:
|
||||
item["_dedupe_match_type"] = result.match_type
|
||||
item["_dedupe_existing_id"] = result.existing_document_id
|
||||
dup_items.append(item)
|
||||
else:
|
||||
new_items.append(item)
|
||||
|
||||
return new_items, dup_items
|
||||
@@ -0,0 +1,224 @@
|
||||
"""Structured logging and distributed tracing for all Stonks Oracle services.
|
||||
|
||||
Provides:
|
||||
- JSON-formatted structured log output for machine-parseable log aggregation
|
||||
- Trace context (trace_id, span_id, service) propagated through log records
|
||||
- Context manager for creating trace spans within a service
|
||||
- Helper to configure logging for any service worker or API
|
||||
|
||||
Requirements: 12.1
|
||||
Design: Section 12 (Observability and Operations)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from contextvars import ContextVar
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Trace context stored in contextvars for async-safe propagation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_trace_id: ContextVar[str] = ContextVar("trace_id", default="")
|
||||
_span_id: ContextVar[str] = ContextVar("span_id", default="")
|
||||
_service_name: ContextVar[str] = ContextVar("service_name", default="unknown")
|
||||
|
||||
|
||||
def get_trace_id() -> str:
|
||||
return _trace_id.get()
|
||||
|
||||
|
||||
def get_span_id() -> str:
|
||||
return _span_id.get()
|
||||
|
||||
|
||||
def get_service_name() -> str:
|
||||
return _service_name.get()
|
||||
|
||||
|
||||
def set_trace_context(
|
||||
trace_id: str | None = None,
|
||||
span_id: str | None = None,
|
||||
service: str | None = None,
|
||||
) -> None:
|
||||
"""Set trace context for the current async task / thread."""
|
||||
if trace_id is not None:
|
||||
_trace_id.set(trace_id)
|
||||
if span_id is not None:
|
||||
_span_id.set(span_id)
|
||||
if service is not None:
|
||||
_service_name.set(service)
|
||||
|
||||
|
||||
def new_trace_id() -> str:
|
||||
return uuid.uuid4().hex[:16]
|
||||
|
||||
|
||||
def new_span_id() -> str:
|
||||
return uuid.uuid4().hex[:8]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Span context manager for tracing within a service
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Span:
|
||||
"""Lightweight span for distributed tracing.
|
||||
|
||||
Usage::
|
||||
|
||||
with Span("process_document", ticker="AAPL") as span:
|
||||
# ... do work ...
|
||||
span.set_attribute("doc_count", 5)
|
||||
|
||||
On exit the span logs its duration and attributes as a structured event.
|
||||
"""
|
||||
|
||||
def __init__(self, operation: str, **attributes: Any) -> None:
|
||||
self.operation = operation
|
||||
self.parent_span_id = get_span_id()
|
||||
self.span_id = new_span_id()
|
||||
self.trace_id = get_trace_id() or new_trace_id()
|
||||
self.attributes: dict[str, Any] = dict(attributes)
|
||||
self.start_time: float = 0.0
|
||||
self.duration_ms: float = 0.0
|
||||
self._token_trace: Any = None
|
||||
self._token_span: Any = None
|
||||
self._logger = logging.getLogger(get_service_name() or "tracing")
|
||||
|
||||
def set_attribute(self, key: str, value: Any) -> None:
|
||||
self.attributes[key] = value
|
||||
|
||||
def __enter__(self) -> Span:
|
||||
self.start_time = time.monotonic()
|
||||
self._token_trace = _trace_id.set(self.trace_id)
|
||||
self._token_span = _span_id.set(self.span_id)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
self.duration_ms = (time.monotonic() - self.start_time) * 1000
|
||||
status = "error" if exc_type else "ok"
|
||||
|
||||
self._logger.info(
|
||||
"span.end",
|
||||
extra={
|
||||
"span_operation": self.operation,
|
||||
"span_status": status,
|
||||
"span_duration_ms": round(self.duration_ms, 2),
|
||||
"span_parent_id": self.parent_span_id,
|
||||
"span_attributes": self.attributes,
|
||||
},
|
||||
)
|
||||
|
||||
# Restore parent span context
|
||||
if self._token_span is not None:
|
||||
_span_id.reset(self._token_span)
|
||||
if self._token_trace is not None:
|
||||
_trace_id.reset(self._token_trace)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JSON log formatter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class JSONFormatter(logging.Formatter):
|
||||
"""Emit each log record as a single JSON line with trace context."""
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
log_entry: dict[str, Any] = {
|
||||
"timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
|
||||
"level": record.levelname,
|
||||
"logger": record.name,
|
||||
"message": record.getMessage(),
|
||||
"service": get_service_name(),
|
||||
"trace_id": get_trace_id(),
|
||||
"span_id": get_span_id(),
|
||||
}
|
||||
|
||||
# Merge extra fields from Span or manual extra={} usage
|
||||
for key in (
|
||||
"span_operation", "span_status", "span_duration_ms",
|
||||
"span_parent_id", "span_attributes",
|
||||
"ticker", "document_id", "source_type", "job_id",
|
||||
"duration_ms", "error", "count",
|
||||
):
|
||||
val = getattr(record, key, None)
|
||||
if val is not None:
|
||||
log_entry[key] = val
|
||||
|
||||
if record.exc_info and record.exc_info[1]:
|
||||
log_entry["exception"] = self.formatException(record.exc_info)
|
||||
|
||||
return json.dumps(log_entry, default=str)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Setup helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def setup_logging(
|
||||
service_name: str,
|
||||
level: str = "INFO",
|
||||
json_output: bool = True,
|
||||
) -> None:
|
||||
"""Configure structured logging for a service.
|
||||
|
||||
Call this once at service startup (before any log calls).
|
||||
|
||||
Args:
|
||||
service_name: Identifies this service in log output (e.g. "ingestion_worker").
|
||||
level: Log level string (DEBUG, INFO, WARNING, ERROR).
|
||||
json_output: If True, emit JSON lines. If False, use a human-readable format.
|
||||
"""
|
||||
_service_name.set(service_name)
|
||||
|
||||
root = logging.getLogger()
|
||||
root.setLevel(getattr(logging, level.upper(), logging.INFO))
|
||||
|
||||
# Remove existing handlers to avoid duplicate output
|
||||
root.handlers.clear()
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
if json_output:
|
||||
handler.setFormatter(JSONFormatter())
|
||||
else:
|
||||
handler.setFormatter(logging.Formatter(
|
||||
"%(asctime)s [%(levelname)s] %(name)s (%(service)s) "
|
||||
"trace=%(trace_id)s span=%(span_id)s — %(message)s",
|
||||
defaults={"service": service_name, "trace_id": "", "span_id": ""},
|
||||
))
|
||||
root.addHandler(handler)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Trace context propagation through job payloads
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def inject_trace_context(payload: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Inject current trace context into a job payload dict.
|
||||
|
||||
Call this before enqueuing a job to Redis so the downstream
|
||||
worker can continue the same trace.
|
||||
"""
|
||||
trace_id = get_trace_id()
|
||||
if trace_id:
|
||||
payload["_trace_id"] = trace_id
|
||||
return payload
|
||||
|
||||
|
||||
def extract_trace_context(payload: dict[str, Any]) -> None:
|
||||
"""Extract and set trace context from an incoming job payload.
|
||||
|
||||
Call this at the start of job processing. If no trace context
|
||||
is present, generates a new trace_id.
|
||||
"""
|
||||
trace_id = payload.get("_trace_id") or new_trace_id()
|
||||
set_trace_context(trace_id=trace_id, span_id=new_span_id())
|
||||
@@ -0,0 +1,696 @@
|
||||
"""Metadata persistence for market payloads, documents, and broker events.
|
||||
|
||||
Persists structured metadata records to PostgreSQL for all ingested artifacts.
|
||||
Each source type has its own persistence path:
|
||||
- market_api → market_snapshots table
|
||||
- news_api / filings_api / web_scrape → documents + document_company_mentions
|
||||
- broker → order_events or market_snapshots (for position/account snapshots)
|
||||
|
||||
Requirements: 3.3, 3.4, 8.3, 9.2
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import asyncpg
|
||||
|
||||
from services.shared.content import content_hash_str, normalize_url
|
||||
|
||||
logger = logging.getLogger("metadata")
|
||||
|
||||
|
||||
async def persist_market_snapshot(
|
||||
pool: asyncpg.Pool,
|
||||
*,
|
||||
company_id: str | None,
|
||||
ticker: str,
|
||||
snapshot_type: str,
|
||||
data: dict[str, Any],
|
||||
source_provider: str,
|
||||
storage_ref: str,
|
||||
content_hash: str,
|
||||
captured_at: datetime | None = None,
|
||||
) -> str:
|
||||
"""Persist a market data snapshot to PostgreSQL.
|
||||
|
||||
Returns the snapshot row UUID.
|
||||
"""
|
||||
ts = captured_at or datetime.now(timezone.utc)
|
||||
row_id = await pool.fetchval(
|
||||
"""INSERT INTO market_snapshots
|
||||
(company_id, ticker, snapshot_type, data, source_provider,
|
||||
captured_at, storage_ref, content_hash)
|
||||
VALUES ($1, $2, $3, $4::jsonb, $5, $6, $7, $8)
|
||||
RETURNING id""",
|
||||
company_id,
|
||||
ticker,
|
||||
snapshot_type,
|
||||
json.dumps(data),
|
||||
source_provider,
|
||||
ts,
|
||||
storage_ref,
|
||||
content_hash,
|
||||
)
|
||||
logger.debug("Persisted market snapshot %s for %s", row_id, ticker)
|
||||
return str(row_id)
|
||||
|
||||
|
||||
async def persist_document(
|
||||
pool: asyncpg.Pool,
|
||||
*,
|
||||
document_type: str,
|
||||
source_type: str,
|
||||
publisher: str,
|
||||
url: str | None,
|
||||
canonical_url: str | None,
|
||||
title: str,
|
||||
published_at: datetime | None,
|
||||
content_hash: str,
|
||||
storage_ref: str,
|
||||
language: str = "en",
|
||||
) -> str | None:
|
||||
"""Persist a document metadata record to PostgreSQL.
|
||||
|
||||
Returns the document row UUID, or None if a duplicate content_hash exists.
|
||||
"""
|
||||
exists = await pool.fetchval(
|
||||
"SELECT 1 FROM documents WHERE content_hash = $1", content_hash
|
||||
)
|
||||
if exists:
|
||||
return None
|
||||
|
||||
doc_id = await pool.fetchval(
|
||||
"""INSERT INTO documents
|
||||
(document_type, source_type, publisher, url, canonical_url,
|
||||
title, published_at, content_hash, raw_storage_ref,
|
||||
language, status)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, 'ingested')
|
||||
RETURNING id""",
|
||||
document_type,
|
||||
source_type,
|
||||
publisher,
|
||||
url,
|
||||
canonical_url,
|
||||
title,
|
||||
published_at,
|
||||
content_hash,
|
||||
storage_ref,
|
||||
language,
|
||||
)
|
||||
logger.debug("Persisted document %s (%s)", doc_id, title[:60] if title else "")
|
||||
return str(doc_id)
|
||||
|
||||
|
||||
async def update_document_parse_results(
|
||||
pool: asyncpg.Pool,
|
||||
*,
|
||||
document_id: str,
|
||||
normalized_storage_ref: str | None,
|
||||
parser_output_ref: str | None,
|
||||
parse_quality_score: float,
|
||||
parse_confidence: str,
|
||||
status: str,
|
||||
) -> None:
|
||||
"""Update a document row with parser output references and quality scores.
|
||||
|
||||
Called after the parsing stage to persist normalized text location,
|
||||
structured parser output location, quality score, and confidence.
|
||||
|
||||
Requirements: 4.1, 4.3, 9.1
|
||||
"""
|
||||
await pool.execute(
|
||||
"""UPDATE documents SET
|
||||
normalized_storage_ref = $2,
|
||||
parser_output_ref = $3,
|
||||
parse_quality_score = $4,
|
||||
parse_confidence = $5,
|
||||
status = $6,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1""",
|
||||
document_id,
|
||||
normalized_storage_ref,
|
||||
parser_output_ref,
|
||||
parse_quality_score,
|
||||
parse_confidence,
|
||||
status,
|
||||
)
|
||||
logger.debug(
|
||||
"Updated document %s parse results: quality=%.2f confidence=%s status=%s",
|
||||
document_id, parse_quality_score, parse_confidence, status,
|
||||
)
|
||||
|
||||
|
||||
async def persist_document_company_mention(
|
||||
pool: asyncpg.Pool,
|
||||
*,
|
||||
document_id: str,
|
||||
company_id: str,
|
||||
ticker: str,
|
||||
mention_type: str = "direct",
|
||||
confidence: float = 1.0,
|
||||
) -> str:
|
||||
"""Link a document to a company via document_company_mentions.
|
||||
|
||||
Returns the mention row UUID.
|
||||
"""
|
||||
mention_id = await pool.fetchval(
|
||||
"""INSERT INTO document_company_mentions
|
||||
(document_id, company_id, ticker, mention_type, confidence)
|
||||
VALUES ($1::uuid, $2::uuid, $3, $4, $5)
|
||||
RETURNING id""",
|
||||
document_id,
|
||||
company_id,
|
||||
ticker,
|
||||
mention_type,
|
||||
confidence,
|
||||
)
|
||||
return str(mention_id)
|
||||
|
||||
|
||||
async def persist_broker_event(
|
||||
pool: asyncpg.Pool,
|
||||
*,
|
||||
ticker: str,
|
||||
event_type: str,
|
||||
data: dict[str, Any],
|
||||
source_provider: str,
|
||||
storage_ref: str,
|
||||
content_hash: str,
|
||||
captured_at: datetime | None = None,
|
||||
) -> str:
|
||||
"""Persist a broker event snapshot to market_snapshots.
|
||||
|
||||
Broker position/account snapshots are stored as market_snapshots
|
||||
with snapshot_type prefixed by 'broker_' (e.g. broker_positions,
|
||||
broker_account, broker_orders).
|
||||
|
||||
Returns the snapshot row UUID.
|
||||
"""
|
||||
ts = captured_at or datetime.now(timezone.utc)
|
||||
row_id = await pool.fetchval(
|
||||
"""INSERT INTO market_snapshots
|
||||
(ticker, snapshot_type, data, source_provider,
|
||||
captured_at, storage_ref, content_hash)
|
||||
VALUES ($1, $2, $3::jsonb, $4, $5, $6, $7)
|
||||
RETURNING id""",
|
||||
ticker,
|
||||
f"broker_{event_type}",
|
||||
json.dumps(data),
|
||||
source_provider,
|
||||
ts,
|
||||
storage_ref,
|
||||
content_hash,
|
||||
)
|
||||
logger.debug("Persisted broker event %s for %s", row_id, ticker)
|
||||
return str(row_id)
|
||||
|
||||
|
||||
def _resolve_document_type(source_type: str) -> str:
|
||||
"""Map source_type to a document_type value."""
|
||||
mapping = {
|
||||
"news_api": "article",
|
||||
"filings_api": "filing",
|
||||
"web_scrape": "press_release",
|
||||
}
|
||||
return mapping.get(source_type, "article")
|
||||
|
||||
|
||||
def _extract_publisher(item: dict[str, Any]) -> str:
|
||||
"""Extract publisher name from an adapter item dict."""
|
||||
if item.get("publisher"):
|
||||
return str(item["publisher"])
|
||||
source = item.get("source")
|
||||
if isinstance(source, dict):
|
||||
return source.get("name", "")
|
||||
if source:
|
||||
return str(source)
|
||||
return ""
|
||||
|
||||
|
||||
def _parse_published_at(item: dict[str, Any]) -> datetime | None:
|
||||
"""Parse published_at from various adapter item formats."""
|
||||
raw = item.get("publishedAt") or item.get("published_at")
|
||||
if not raw:
|
||||
return None
|
||||
if isinstance(raw, datetime):
|
||||
return raw
|
||||
try:
|
||||
return datetime.fromisoformat(str(raw).replace("Z", "+00:00"))
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
async def persist_ingestion_items(
|
||||
pool: asyncpg.Pool,
|
||||
*,
|
||||
source_type: str,
|
||||
ticker: str,
|
||||
company_id: str | None,
|
||||
items: list[dict[str, Any]],
|
||||
storage_ref: str,
|
||||
adapter_metadata: dict[str, Any],
|
||||
content_hash: str,
|
||||
) -> tuple[int, list[str]]:
|
||||
"""Route ingestion items to the correct persistence path.
|
||||
|
||||
Returns (new_item_count, list_of_new_ids).
|
||||
"""
|
||||
if source_type == "market_api":
|
||||
return await _persist_market_items(
|
||||
pool,
|
||||
ticker=ticker,
|
||||
company_id=company_id,
|
||||
items=items,
|
||||
storage_ref=storage_ref,
|
||||
provider=adapter_metadata.get("provider", "unknown"),
|
||||
content_hash=content_hash,
|
||||
)
|
||||
|
||||
if source_type == "broker":
|
||||
return await _persist_broker_items(
|
||||
pool,
|
||||
ticker=ticker,
|
||||
items=items,
|
||||
storage_ref=storage_ref,
|
||||
provider=adapter_metadata.get("provider", "unknown"),
|
||||
endpoint=adapter_metadata.get("endpoint", "positions"),
|
||||
content_hash=content_hash,
|
||||
)
|
||||
|
||||
# Document types: news_api, filings_api, web_scrape
|
||||
return await _persist_document_items(
|
||||
pool,
|
||||
source_type=source_type,
|
||||
ticker=ticker,
|
||||
company_id=company_id,
|
||||
items=items,
|
||||
storage_ref=storage_ref,
|
||||
)
|
||||
|
||||
|
||||
async def _persist_market_items(
|
||||
pool: asyncpg.Pool,
|
||||
*,
|
||||
ticker: str,
|
||||
company_id: str | None,
|
||||
items: list[dict[str, Any]],
|
||||
storage_ref: str,
|
||||
provider: str,
|
||||
content_hash: str,
|
||||
) -> tuple[int, list[str]]:
|
||||
"""Persist market data items as market_snapshots rows."""
|
||||
ids: list[str] = []
|
||||
for item in items:
|
||||
item_hash = content_hash_str(json.dumps(item, sort_keys=True))
|
||||
# Skip duplicates
|
||||
exists = await pool.fetchval(
|
||||
"SELECT 1 FROM market_snapshots WHERE content_hash = $1", item_hash
|
||||
)
|
||||
if exists:
|
||||
continue
|
||||
|
||||
snapshot_type = _infer_market_snapshot_type(item)
|
||||
row_id = await persist_market_snapshot(
|
||||
pool,
|
||||
company_id=company_id,
|
||||
ticker=ticker,
|
||||
snapshot_type=snapshot_type,
|
||||
data=item,
|
||||
source_provider=provider,
|
||||
storage_ref=storage_ref,
|
||||
content_hash=item_hash,
|
||||
)
|
||||
ids.append(row_id)
|
||||
return len(ids), ids
|
||||
|
||||
|
||||
def _infer_market_snapshot_type(item: dict[str, Any]) -> str:
|
||||
"""Infer snapshot_type from market data item fields."""
|
||||
# Polygon aggregate bars have 'o', 'h', 'l', 'c' fields
|
||||
if all(k in item for k in ("o", "h", "l", "c")):
|
||||
return "bar"
|
||||
# Ticker details have 'market_cap' or 'sic_code'
|
||||
if "market_cap" in item or "sic_code" in item:
|
||||
return "ticker_details"
|
||||
# Quote snapshots
|
||||
if "ask" in item or "bid" in item:
|
||||
return "quote"
|
||||
return "snapshot"
|
||||
|
||||
|
||||
async def _persist_broker_items(
|
||||
pool: asyncpg.Pool,
|
||||
*,
|
||||
ticker: str,
|
||||
items: list[dict[str, Any]],
|
||||
storage_ref: str,
|
||||
provider: str,
|
||||
endpoint: str,
|
||||
content_hash: str,
|
||||
) -> tuple[int, list[str]]:
|
||||
"""Persist broker fetch items as market_snapshots with broker_ prefix."""
|
||||
ids: list[str] = []
|
||||
for item in items:
|
||||
item_hash = content_hash_str(json.dumps(item, sort_keys=True))
|
||||
exists = await pool.fetchval(
|
||||
"SELECT 1 FROM market_snapshots WHERE content_hash = $1", item_hash
|
||||
)
|
||||
if exists:
|
||||
continue
|
||||
|
||||
row_id = await persist_broker_event(
|
||||
pool,
|
||||
ticker=ticker,
|
||||
event_type=endpoint,
|
||||
data=item,
|
||||
source_provider=provider,
|
||||
storage_ref=storage_ref,
|
||||
content_hash=item_hash,
|
||||
)
|
||||
ids.append(row_id)
|
||||
return len(ids), ids
|
||||
|
||||
|
||||
async def _persist_document_items(
|
||||
pool: asyncpg.Pool,
|
||||
*,
|
||||
source_type: str,
|
||||
ticker: str,
|
||||
company_id: str | None,
|
||||
items: list[dict[str, Any]],
|
||||
storage_ref: str,
|
||||
) -> tuple[int, list[str]]:
|
||||
"""Persist document items (news, filings, web scrape) to documents table."""
|
||||
doc_type = _resolve_document_type(source_type)
|
||||
ids: list[str] = []
|
||||
|
||||
for item in items:
|
||||
item_hash = item.get("content_hash") or content_hash_str(
|
||||
json.dumps(item, sort_keys=True)
|
||||
)
|
||||
title = item.get("title", item.get("name", ""))
|
||||
url = item.get("url", item.get("link", ""))
|
||||
canonical_url = item.get("canonical_url") or (
|
||||
normalize_url(url) if url else None
|
||||
)
|
||||
published_at = _parse_published_at(item)
|
||||
publisher = _extract_publisher(item)
|
||||
|
||||
doc_id = await persist_document(
|
||||
pool,
|
||||
document_type=doc_type,
|
||||
source_type=source_type,
|
||||
publisher=publisher,
|
||||
url=url or None,
|
||||
canonical_url=canonical_url,
|
||||
title=title,
|
||||
published_at=published_at,
|
||||
content_hash=item_hash,
|
||||
storage_ref=storage_ref,
|
||||
)
|
||||
if doc_id is None:
|
||||
continue
|
||||
|
||||
# Link document to company if we have a company_id
|
||||
if company_id:
|
||||
await persist_document_company_mention(
|
||||
pool,
|
||||
document_id=doc_id,
|
||||
company_id=company_id,
|
||||
ticker=ticker,
|
||||
)
|
||||
|
||||
ids.append(doc_id)
|
||||
|
||||
return len(ids), ids
|
||||
|
||||
|
||||
# --- Retry and failure tracking (Requirement 3.4) ---
|
||||
|
||||
# Backoff constants — match scheduler defaults for consistency
|
||||
RETRY_BACKOFF_BASE: int = 60
|
||||
RETRY_BACKOFF_MAX: int = 3600
|
||||
RETRY_MAX_COUNT: int = 10
|
||||
|
||||
|
||||
def compute_next_retry_at(
|
||||
retry_count: int,
|
||||
now: datetime | None = None,
|
||||
base: int = RETRY_BACKOFF_BASE,
|
||||
cap: int = RETRY_BACKOFF_MAX,
|
||||
) -> datetime:
|
||||
"""Compute the next eligible retry time using exponential backoff.
|
||||
|
||||
Args:
|
||||
retry_count: Current retry count (before incrementing).
|
||||
now: Reference timestamp (defaults to UTC now).
|
||||
base: Base delay in seconds.
|
||||
cap: Maximum delay in seconds.
|
||||
|
||||
Returns:
|
||||
Datetime of the next eligible retry.
|
||||
"""
|
||||
ts = now or datetime.now(timezone.utc)
|
||||
delay = min(base * (2 ** min(retry_count, 8)), cap)
|
||||
return ts + timedelta(seconds=delay)
|
||||
|
||||
|
||||
async def get_source_retry_count(
|
||||
pool: asyncpg.Pool,
|
||||
source_id: str,
|
||||
) -> int:
|
||||
"""Return the retry count from the most recent failed run for a source.
|
||||
|
||||
If the last run succeeded or no runs exist, returns 0.
|
||||
"""
|
||||
row = await pool.fetchrow(
|
||||
"""SELECT status, retry_count
|
||||
FROM ingestion_runs
|
||||
WHERE source_id = $1::uuid
|
||||
ORDER BY started_at DESC
|
||||
LIMIT 1""",
|
||||
source_id,
|
||||
)
|
||||
if row and row["status"] == "failed":
|
||||
return row["retry_count"] or 0
|
||||
return 0
|
||||
|
||||
|
||||
async def record_retrieval_failure(
|
||||
pool: asyncpg.Pool,
|
||||
run_id: str,
|
||||
source_id: str,
|
||||
error_message: str,
|
||||
retry_count: int | None = None,
|
||||
now: datetime | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Record a source retrieval failure with retry policy state.
|
||||
|
||||
Updates the ingestion_runs row with:
|
||||
- error_message: the failure reason
|
||||
- retry_count: incremented from the previous failed run (or provided)
|
||||
- next_retry_at: computed via exponential backoff
|
||||
- status: 'failed'
|
||||
|
||||
If retry_count is not provided, it is looked up from the most recent
|
||||
failed run for the same source and incremented.
|
||||
|
||||
Returns a dict with the recorded retry state for observability.
|
||||
|
||||
Requirement 3.4
|
||||
"""
|
||||
ts = now or datetime.now(timezone.utc)
|
||||
|
||||
if retry_count is None:
|
||||
prev_count = await get_source_retry_count(pool, source_id)
|
||||
retry_count = prev_count + 1
|
||||
else:
|
||||
retry_count = retry_count + 1
|
||||
|
||||
next_retry = compute_next_retry_at(retry_count - 1, now=ts)
|
||||
exhausted = retry_count >= RETRY_MAX_COUNT
|
||||
|
||||
await pool.execute(
|
||||
"""UPDATE ingestion_runs
|
||||
SET status = 'failed',
|
||||
error_message = $2,
|
||||
retry_count = $3,
|
||||
next_retry_at = $4,
|
||||
completed_at = $5
|
||||
WHERE id = $1""",
|
||||
run_id,
|
||||
error_message,
|
||||
retry_count,
|
||||
next_retry,
|
||||
ts,
|
||||
)
|
||||
|
||||
state = {
|
||||
"run_id": run_id,
|
||||
"source_id": source_id,
|
||||
"retry_count": retry_count,
|
||||
"next_retry_at": next_retry.isoformat(),
|
||||
"exhausted": exhausted,
|
||||
"error_message": error_message,
|
||||
}
|
||||
|
||||
if exhausted:
|
||||
logger.warning(
|
||||
"Source %s exhausted retries (%d/%d): %s",
|
||||
source_id, retry_count, RETRY_MAX_COUNT, error_message,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Source %s failed (retry %d/%d), next retry at %s: %s",
|
||||
source_id, retry_count, RETRY_MAX_COUNT,
|
||||
next_retry.isoformat(), error_message,
|
||||
)
|
||||
|
||||
return state
|
||||
|
||||
|
||||
async def persist_document_intelligence(
|
||||
pool: asyncpg.Pool,
|
||||
*,
|
||||
document_id: str,
|
||||
summary: str,
|
||||
macro_themes: list[str],
|
||||
novelty_score: float,
|
||||
source_credibility: float,
|
||||
extraction_warnings: list[str],
|
||||
confidence: float,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
prompt_version: str,
|
||||
schema_version: str,
|
||||
raw_output_ref: str | None = None,
|
||||
prompt_ref: str | None = None,
|
||||
validation_status: str = "valid",
|
||||
validation_errors: list[str] | None = None,
|
||||
retry_count: int = 0,
|
||||
) -> str:
|
||||
"""Persist a document intelligence record to PostgreSQL.
|
||||
|
||||
Returns the intelligence row UUID.
|
||||
|
||||
Requirements: 5.3, 5.4, 9.2
|
||||
"""
|
||||
intel_id = await pool.fetchval(
|
||||
"""INSERT INTO document_intelligence
|
||||
(document_id, summary, macro_themes, novelty_score,
|
||||
source_credibility, extraction_warnings, confidence,
|
||||
model_provider, model_name, prompt_version, schema_version,
|
||||
raw_output_ref, prompt_ref, validation_status,
|
||||
validation_errors, retry_count)
|
||||
VALUES ($1::uuid, $2, $3::jsonb, $4, $5, $6::jsonb, $7,
|
||||
$8, $9, $10, $11, $12, $13, $14, $15::jsonb, $16)
|
||||
RETURNING id""",
|
||||
document_id,
|
||||
summary,
|
||||
json.dumps(macro_themes),
|
||||
novelty_score,
|
||||
source_credibility,
|
||||
json.dumps(extraction_warnings),
|
||||
confidence,
|
||||
model_provider,
|
||||
model_name,
|
||||
prompt_version,
|
||||
schema_version,
|
||||
raw_output_ref,
|
||||
prompt_ref,
|
||||
validation_status,
|
||||
json.dumps(validation_errors or []),
|
||||
retry_count,
|
||||
)
|
||||
logger.debug("Persisted document intelligence %s for doc %s", intel_id, document_id)
|
||||
return str(intel_id)
|
||||
|
||||
|
||||
async def persist_document_impact(
|
||||
pool: asyncpg.Pool,
|
||||
*,
|
||||
intelligence_id: str,
|
||||
company_id: str,
|
||||
ticker: str,
|
||||
relevance: float,
|
||||
sentiment: str,
|
||||
impact_score: float,
|
||||
impact_horizon: str,
|
||||
catalyst_type: str,
|
||||
key_facts: list[str],
|
||||
risks: list[str],
|
||||
evidence_spans: list[str],
|
||||
) -> str:
|
||||
"""Persist a per-company impact record linked to a document intelligence row.
|
||||
|
||||
Returns the impact record UUID.
|
||||
|
||||
Requirements: 5.3, 5.5, 9.2
|
||||
"""
|
||||
impact_id = await pool.fetchval(
|
||||
"""INSERT INTO document_impact_records
|
||||
(intelligence_id, company_id, ticker, relevance, sentiment,
|
||||
impact_score, impact_horizon, catalyst_type,
|
||||
key_facts, risks, evidence_spans)
|
||||
VALUES ($1::uuid, $2::uuid, $3, $4, $5, $6, $7, $8,
|
||||
$9::jsonb, $10::jsonb, $11::jsonb)
|
||||
RETURNING id""",
|
||||
intelligence_id,
|
||||
company_id,
|
||||
ticker,
|
||||
relevance,
|
||||
sentiment,
|
||||
impact_score,
|
||||
impact_horizon,
|
||||
catalyst_type,
|
||||
json.dumps(key_facts),
|
||||
json.dumps(risks),
|
||||
json.dumps(evidence_spans),
|
||||
)
|
||||
logger.debug("Persisted impact record %s for %s", impact_id, ticker)
|
||||
return str(impact_id)
|
||||
|
||||
|
||||
async def update_document_status(
|
||||
pool: asyncpg.Pool,
|
||||
*,
|
||||
document_id: str,
|
||||
status: str,
|
||||
) -> None:
|
||||
"""Update the status field on a document row.
|
||||
|
||||
Used to advance documents through the pipeline: ingested → parsed → extracted → failed.
|
||||
|
||||
Requirements: 5.4
|
||||
"""
|
||||
await pool.execute(
|
||||
"""UPDATE documents SET status = $2, updated_at = NOW() WHERE id = $1::uuid""",
|
||||
document_id,
|
||||
status,
|
||||
)
|
||||
logger.debug("Updated document %s status to %s", document_id, status)
|
||||
|
||||
|
||||
async def reset_source_retry_state(
|
||||
pool: asyncpg.Pool,
|
||||
source_id: str,
|
||||
) -> None:
|
||||
"""Reset retry state for a source after a successful run.
|
||||
|
||||
Sets retry_count=0 and next_retry_at=NULL on the most recent run.
|
||||
Called after a successful ingestion to clear any accumulated backoff.
|
||||
"""
|
||||
await pool.execute(
|
||||
"""UPDATE ingestion_runs
|
||||
SET retry_count = 0, next_retry_at = NULL
|
||||
WHERE id = (
|
||||
SELECT id FROM ingestion_runs
|
||||
WHERE source_id = $1::uuid
|
||||
ORDER BY started_at DESC
|
||||
LIMIT 1
|
||||
)""",
|
||||
source_id,
|
||||
)
|
||||
@@ -0,0 +1,317 @@
|
||||
"""Prometheus metrics for all Stonks Oracle pipeline stages.
|
||||
|
||||
Provides counters, histograms, and gauges covering:
|
||||
- Ingestion: items fetched, new items, errors, adapter latency
|
||||
- Parsing: documents parsed, quality scores, low-quality flags
|
||||
- Extraction: attempts, successes, failures, latency, confidence, retries
|
||||
- Aggregation: trend windows computed, signal counts, contradiction scores
|
||||
- Lake publication: facts published per table, write latency
|
||||
- Trading: orders submitted, rejected, filled, risk evaluations
|
||||
|
||||
Requirements: 12.1, 12.2
|
||||
Design: Section 12 (Observability and Operations)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from prometheus_client import Counter, Gauge, Histogram, Info
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Service info
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SERVICE_INFO = Info("stonks_oracle", "Stonks Oracle service metadata")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ingestion metrics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
INGESTION_JOBS_TOTAL = Counter(
|
||||
"stonks_ingestion_jobs_total",
|
||||
"Total ingestion jobs processed",
|
||||
["source_type", "status"],
|
||||
)
|
||||
|
||||
INGESTION_ITEMS_FETCHED = Counter(
|
||||
"stonks_ingestion_items_fetched_total",
|
||||
"Total items fetched from external sources",
|
||||
["source_type"],
|
||||
)
|
||||
|
||||
INGESTION_ITEMS_NEW = Counter(
|
||||
"stonks_ingestion_items_new_total",
|
||||
"New (non-duplicate) items ingested",
|
||||
["source_type"],
|
||||
)
|
||||
|
||||
INGESTION_ITEMS_DEDUPED = Counter(
|
||||
"stonks_ingestion_items_deduped_total",
|
||||
"Items skipped due to deduplication",
|
||||
["source_type"],
|
||||
)
|
||||
|
||||
INGESTION_ERRORS = Counter(
|
||||
"stonks_ingestion_errors_total",
|
||||
"Ingestion errors by source type",
|
||||
["source_type"],
|
||||
)
|
||||
|
||||
INGESTION_ADAPTER_DURATION = Histogram(
|
||||
"stonks_ingestion_adapter_duration_seconds",
|
||||
"Adapter fetch latency in seconds",
|
||||
["source_type"],
|
||||
buckets=(0.1, 0.5, 1, 2, 5, 10, 30, 60),
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parsing metrics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
PARSE_JOBS_TOTAL = Counter(
|
||||
"stonks_parse_jobs_total",
|
||||
"Total parse jobs processed",
|
||||
["status"],
|
||||
)
|
||||
|
||||
PARSE_QUALITY_SCORE = Histogram(
|
||||
"stonks_parse_quality_score",
|
||||
"Distribution of parser quality scores",
|
||||
buckets=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0),
|
||||
)
|
||||
|
||||
PARSE_LOW_QUALITY_TOTAL = Counter(
|
||||
"stonks_parse_low_quality_total",
|
||||
"Documents flagged as low quality by the parser",
|
||||
)
|
||||
|
||||
PARSE_DURATION = Histogram(
|
||||
"stonks_parse_duration_seconds",
|
||||
"Parse job duration in seconds",
|
||||
buckets=(0.05, 0.1, 0.25, 0.5, 1, 2, 5, 10),
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Extraction metrics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
EXTRACTION_JOBS_TOTAL = Counter(
|
||||
"stonks_extraction_jobs_total",
|
||||
"Total extraction jobs processed",
|
||||
["status"],
|
||||
)
|
||||
|
||||
EXTRACTION_ATTEMPTS = Counter(
|
||||
"stonks_extraction_attempts_total",
|
||||
"Total Ollama extraction attempts (including retries)",
|
||||
)
|
||||
|
||||
EXTRACTION_RETRIES = Counter(
|
||||
"stonks_extraction_retries_total",
|
||||
"Extraction retry count",
|
||||
)
|
||||
|
||||
EXTRACTION_DURATION = Histogram(
|
||||
"stonks_extraction_duration_seconds",
|
||||
"Extraction total duration in seconds",
|
||||
buckets=(1, 2, 5, 10, 20, 30, 60, 120),
|
||||
)
|
||||
|
||||
EXTRACTION_CONFIDENCE = Histogram(
|
||||
"stonks_extraction_confidence",
|
||||
"Distribution of extraction confidence scores",
|
||||
buckets=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0),
|
||||
)
|
||||
|
||||
EXTRACTION_VALIDATION_ERRORS = Counter(
|
||||
"stonks_extraction_validation_errors_total",
|
||||
"Total validation errors across extractions",
|
||||
)
|
||||
|
||||
EXTRACTION_TOKEN_ESTIMATE = Counter(
|
||||
"stonks_extraction_tokens_total",
|
||||
"Estimated token usage",
|
||||
["direction"],
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Aggregation metrics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
AGGREGATION_WINDOWS_COMPUTED = Counter(
|
||||
"stonks_aggregation_windows_total",
|
||||
"Trend windows computed",
|
||||
["window"],
|
||||
)
|
||||
|
||||
AGGREGATION_SIGNALS_PROCESSED = Counter(
|
||||
"stonks_aggregation_signals_total",
|
||||
"Signals processed during aggregation",
|
||||
["window"],
|
||||
)
|
||||
|
||||
AGGREGATION_CONTRADICTION_SCORE = Histogram(
|
||||
"stonks_aggregation_contradiction_score",
|
||||
"Distribution of contradiction scores in trend windows",
|
||||
buckets=(0.0, 0.05, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0),
|
||||
)
|
||||
|
||||
AGGREGATION_DURATION = Histogram(
|
||||
"stonks_aggregation_duration_seconds",
|
||||
"Aggregation job duration in seconds",
|
||||
["window"],
|
||||
buckets=(0.05, 0.1, 0.25, 0.5, 1, 2, 5, 10),
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Recommendation metrics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
RECOMMENDATION_GENERATED = Counter(
|
||||
"stonks_recommendations_total",
|
||||
"Recommendations generated",
|
||||
["action", "mode"],
|
||||
)
|
||||
|
||||
RECOMMENDATION_SUPPRESSED = Counter(
|
||||
"stonks_recommendations_suppressed_total",
|
||||
"Recommendations suppressed due to low data quality",
|
||||
)
|
||||
|
||||
RECOMMENDATION_CONFIDENCE = Histogram(
|
||||
"stonks_recommendation_confidence",
|
||||
"Distribution of recommendation confidence scores",
|
||||
buckets=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0),
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lake publication metrics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
LAKE_FACTS_PUBLISHED = Counter(
|
||||
"stonks_lake_facts_published_total",
|
||||
"Analytical facts published to the lakehouse",
|
||||
["table_name"],
|
||||
)
|
||||
|
||||
LAKE_PUBLISH_DURATION = Histogram(
|
||||
"stonks_lake_publish_duration_seconds",
|
||||
"Lake publication write latency in seconds",
|
||||
["table_name"],
|
||||
buckets=(0.01, 0.05, 0.1, 0.25, 0.5, 1, 2, 5),
|
||||
)
|
||||
|
||||
LAKE_PUBLISH_ERRORS = Counter(
|
||||
"stonks_lake_publish_errors_total",
|
||||
"Lake publication errors",
|
||||
["table_name"],
|
||||
)
|
||||
|
||||
LAKE_PUBLISH_BYTES = Counter(
|
||||
"stonks_lake_publish_bytes_total",
|
||||
"Total bytes written to the lakehouse",
|
||||
["table_name"],
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Trading / broker metrics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
ORDERS_SUBMITTED = Counter(
|
||||
"stonks_orders_submitted_total",
|
||||
"Orders submitted to broker",
|
||||
["side", "order_type", "mode"],
|
||||
)
|
||||
|
||||
ORDERS_REJECTED = Counter(
|
||||
"stonks_orders_rejected_total",
|
||||
"Orders rejected before broker submission",
|
||||
["reason_category"],
|
||||
)
|
||||
|
||||
ORDERS_FILLED = Counter(
|
||||
"stonks_orders_filled_total",
|
||||
"Orders filled by broker",
|
||||
["side"],
|
||||
)
|
||||
|
||||
ORDERS_DUPLICATES_PREVENTED = Counter(
|
||||
"stonks_orders_duplicates_prevented_total",
|
||||
"Duplicate orders prevented by idempotency checks",
|
||||
["detected_via"],
|
||||
)
|
||||
|
||||
RISK_EVALUATIONS_TOTAL = Counter(
|
||||
"stonks_risk_evaluations_total",
|
||||
"Risk evaluations performed",
|
||||
["result"],
|
||||
)
|
||||
|
||||
RISK_CHECK_FAILURES = Counter(
|
||||
"stonks_risk_check_failures_total",
|
||||
"Individual risk check failures",
|
||||
["check_name"],
|
||||
)
|
||||
|
||||
POSITIONS_SYNCED = Counter(
|
||||
"stonks_positions_synced_total",
|
||||
"Position sync operations completed",
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Active gauges
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
ACTIVE_JOBS = Gauge(
|
||||
"stonks_active_jobs",
|
||||
"Currently processing jobs by stage",
|
||||
["stage"],
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Alerting metrics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
ALERTS_FIRED = Counter(
|
||||
"stonks_alerts_fired_total",
|
||||
"Total alerts fired by rule",
|
||||
["rule", "severity"],
|
||||
)
|
||||
|
||||
ALERTS_RESOLVED = Counter(
|
||||
"stonks_alerts_resolved_total",
|
||||
"Total alerts resolved by rule",
|
||||
["rule"],
|
||||
)
|
||||
|
||||
ALERT_CHECK_DURATION = Histogram(
|
||||
"stonks_alert_check_duration_seconds",
|
||||
"Duration of alert evaluation cycle",
|
||||
buckets=(0.01, 0.05, 0.1, 0.25, 0.5, 1, 2, 5),
|
||||
)
|
||||
|
||||
ALERT_ACTIVE = Gauge(
|
||||
"stonks_alert_active",
|
||||
"Whether an alert rule is currently firing (1) or resolved (0)",
|
||||
["rule"],
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dead-letter queue metrics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DLQ_ITEMS_TOTAL = Counter(
|
||||
"stonks_dlq_items_total",
|
||||
"Jobs sent to dead-letter queues",
|
||||
["queue"],
|
||||
)
|
||||
|
||||
DLQ_REPLAYED_TOTAL = Counter(
|
||||
"stonks_dlq_replayed_total",
|
||||
"Jobs replayed from dead-letter queues",
|
||||
["queue"],
|
||||
)
|
||||
|
||||
DLQ_DEPTH = Gauge(
|
||||
"stonks_dlq_depth",
|
||||
"Current dead-letter queue depth",
|
||||
["queue"],
|
||||
)
|
||||
@@ -46,6 +46,15 @@ def retry_key(job_id: str) -> str:
|
||||
return f"{RETRY_PREFIX}:{job_id}"
|
||||
|
||||
|
||||
# Dead-letter queues
|
||||
DLQ_PREFIX = f"{PREFIX}:dlq"
|
||||
|
||||
|
||||
def dlq_key(queue_name: str) -> str:
|
||||
"""Return the dead-letter queue key for a given source queue."""
|
||||
return f"{DLQ_PREFIX}:{queue_name}"
|
||||
|
||||
|
||||
# --- Queue names ---
|
||||
QUEUE_INGESTION = "ingestion"
|
||||
QUEUE_PARSING = "parsing"
|
||||
@@ -54,3 +63,4 @@ QUEUE_AGGREGATION = "aggregation"
|
||||
QUEUE_RECOMMENDATION = "recommendation"
|
||||
QUEUE_LAKE_PUBLISH = "lake_publish"
|
||||
QUEUE_TRADE = "trade"
|
||||
QUEUE_BROKER = "broker_orders"
|
||||
|
||||
@@ -0,0 +1,306 @@
|
||||
"""Data retention and lifecycle controls for raw and derived artifacts.
|
||||
|
||||
Provides configurable per-bucket retention policies, expired object cleanup
|
||||
from MinIO, and expired metadata cleanup from PostgreSQL.
|
||||
|
||||
Requirements: N3 (preserve source metadata, access policy, and retention policy)
|
||||
Design ref: Section 5.2 (MinIO bucket layout), Section 10 (Reliability and Safety)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import asyncpg
|
||||
from minio import Minio
|
||||
|
||||
from services.shared.config import BUCKET_RETENTION_FIELDS, RetentionConfig
|
||||
from services.shared.storage import ALL_BUCKETS
|
||||
|
||||
logger = logging.getLogger("retention")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetentionPolicy:
|
||||
"""Resolved retention policy for a single bucket."""
|
||||
bucket_name: str
|
||||
retention_days: int
|
||||
archive_before_delete: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class CleanupResult:
|
||||
"""Result of a single bucket cleanup run."""
|
||||
bucket_name: str
|
||||
objects_scanned: int = 0
|
||||
objects_deleted: int = 0
|
||||
bytes_freed: int = 0
|
||||
db_rows_deleted: int = 0
|
||||
|
||||
|
||||
def default_retention_days(bucket: str, config: RetentionConfig) -> int:
|
||||
"""Get the default retention days for a bucket from config."""
|
||||
field_name = BUCKET_RETENTION_FIELDS.get(bucket)
|
||||
if field_name:
|
||||
return getattr(config, field_name, 365)
|
||||
return 365
|
||||
|
||||
|
||||
def resolve_policies(config: RetentionConfig) -> list[RetentionPolicy]:
|
||||
"""Build retention policies for all known buckets from config defaults."""
|
||||
return [
|
||||
RetentionPolicy(
|
||||
bucket_name=bucket,
|
||||
retention_days=default_retention_days(bucket, config),
|
||||
)
|
||||
for bucket in ALL_BUCKETS
|
||||
]
|
||||
|
||||
|
||||
async def load_db_policies(pool: asyncpg.Pool) -> dict[str, RetentionPolicy]:
|
||||
"""Load retention policy overrides from the database.
|
||||
|
||||
Returns a dict keyed by bucket_name. DB policies take precedence
|
||||
over config defaults when active.
|
||||
"""
|
||||
rows = await pool.fetch(
|
||||
"""SELECT bucket_name, retention_days, archive_before_delete
|
||||
FROM retention_policies
|
||||
WHERE active = TRUE AND artifact_class = 'default'"""
|
||||
)
|
||||
return {
|
||||
row["bucket_name"]: RetentionPolicy(
|
||||
bucket_name=row["bucket_name"],
|
||||
retention_days=row["retention_days"],
|
||||
archive_before_delete=row["archive_before_delete"],
|
||||
)
|
||||
for row in rows
|
||||
}
|
||||
|
||||
|
||||
def merge_policies(
|
||||
config_policies: list[RetentionPolicy],
|
||||
db_policies: dict[str, RetentionPolicy],
|
||||
) -> list[RetentionPolicy]:
|
||||
"""Merge config defaults with DB overrides. DB wins on conflict."""
|
||||
merged: list[RetentionPolicy] = []
|
||||
for policy in config_policies:
|
||||
if policy.bucket_name in db_policies:
|
||||
merged.append(db_policies[policy.bucket_name])
|
||||
else:
|
||||
merged.append(policy)
|
||||
return merged
|
||||
|
||||
|
||||
def cutoff_date(retention_days: int, now: datetime | None = None) -> datetime:
|
||||
"""Calculate the cutoff datetime. Objects older than this are expired."""
|
||||
ref = now or datetime.now(timezone.utc)
|
||||
return ref - timedelta(days=retention_days)
|
||||
|
||||
|
||||
def list_expired_objects(
|
||||
client: Minio,
|
||||
bucket: str,
|
||||
retention_days: int,
|
||||
batch_size: int = 1000,
|
||||
now: datetime | None = None,
|
||||
) -> list[str]:
|
||||
"""List object names in a bucket that are older than the retention cutoff.
|
||||
|
||||
Uses the object's last_modified timestamp from MinIO metadata.
|
||||
Returns at most batch_size object names.
|
||||
"""
|
||||
cutoff = cutoff_date(retention_days, now)
|
||||
expired: list[str] = []
|
||||
|
||||
try:
|
||||
objects = client.list_objects(bucket, recursive=True)
|
||||
for obj in objects:
|
||||
if obj.last_modified and obj.last_modified < cutoff:
|
||||
if obj.object_name:
|
||||
expired.append(obj.object_name)
|
||||
if len(expired) >= batch_size:
|
||||
break
|
||||
except Exception:
|
||||
logger.exception("Error listing objects in bucket %s", bucket)
|
||||
|
||||
return expired
|
||||
|
||||
|
||||
def delete_expired_objects(
|
||||
client: Minio,
|
||||
bucket: str,
|
||||
object_names: list[str],
|
||||
) -> int:
|
||||
"""Delete a list of objects from a MinIO bucket.
|
||||
|
||||
Returns the count of successfully deleted objects.
|
||||
"""
|
||||
deleted = 0
|
||||
for name in object_names:
|
||||
try:
|
||||
client.remove_object(bucket, name)
|
||||
deleted += 1
|
||||
except Exception:
|
||||
logger.warning("Failed to delete %s/%s", bucket, name, exc_info=True)
|
||||
return deleted
|
||||
|
||||
|
||||
def cleanup_bucket(
|
||||
client: Minio,
|
||||
policy: RetentionPolicy,
|
||||
batch_size: int = 1000,
|
||||
now: datetime | None = None,
|
||||
) -> CleanupResult:
|
||||
"""Run retention cleanup for a single bucket.
|
||||
|
||||
Lists expired objects and deletes them in batches.
|
||||
Returns a CleanupResult with counts.
|
||||
"""
|
||||
result = CleanupResult(bucket_name=policy.bucket_name)
|
||||
|
||||
expired = list_expired_objects(
|
||||
client, policy.bucket_name, policy.retention_days,
|
||||
batch_size=batch_size, now=now,
|
||||
)
|
||||
result.objects_scanned = len(expired)
|
||||
|
||||
if expired:
|
||||
result.objects_deleted = delete_expired_objects(client, policy.bucket_name, expired)
|
||||
logger.info(
|
||||
"Bucket %s: scanned=%d deleted=%d (retention=%dd)",
|
||||
policy.bucket_name, result.objects_scanned,
|
||||
result.objects_deleted, policy.retention_days,
|
||||
)
|
||||
else:
|
||||
logger.debug("Bucket %s: no expired objects (retention=%dd)",
|
||||
policy.bucket_name, policy.retention_days)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# --- PostgreSQL metadata cleanup ---
|
||||
|
||||
# Tables with a created_at or retrieved_at column that should be cleaned up
|
||||
# when the corresponding MinIO artifacts are expired.
|
||||
DB_CLEANUP_QUERIES: list[tuple[str, str]] = [
|
||||
(
|
||||
"ingestion_runs",
|
||||
"DELETE FROM ingestion_runs WHERE started_at < $1",
|
||||
),
|
||||
(
|
||||
"market_snapshots",
|
||||
"DELETE FROM market_snapshots WHERE captured_at < $1",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
async def cleanup_expired_db_records(
|
||||
pool: asyncpg.Pool,
|
||||
retention_days: int,
|
||||
now: datetime | None = None,
|
||||
) -> int:
|
||||
"""Delete expired operational metadata from PostgreSQL.
|
||||
|
||||
Uses the shortest raw retention period to clean up ingestion tracking
|
||||
and market snapshot records that are past their useful life.
|
||||
|
||||
Returns total rows deleted.
|
||||
"""
|
||||
cutoff = cutoff_date(retention_days, now)
|
||||
total_deleted = 0
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
for table_name, query in DB_CLEANUP_QUERIES:
|
||||
try:
|
||||
result = await conn.execute(query, cutoff)
|
||||
# asyncpg returns "DELETE N"
|
||||
count = int(result.split()[-1]) if result else 0
|
||||
total_deleted += count
|
||||
if count > 0:
|
||||
logger.info("Cleaned %d expired rows from %s (cutoff=%s)",
|
||||
count, table_name, cutoff.isoformat())
|
||||
except Exception:
|
||||
logger.exception("Error cleaning table %s", table_name)
|
||||
|
||||
return total_deleted
|
||||
|
||||
|
||||
async def record_retention_run(
|
||||
pool: asyncpg.Pool,
|
||||
bucket_name: str,
|
||||
result: CleanupResult,
|
||||
status: str = "completed",
|
||||
error_message: str | None = None,
|
||||
) -> None:
|
||||
"""Record a retention cleanup run in the retention_runs table."""
|
||||
await pool.execute(
|
||||
"""INSERT INTO retention_runs
|
||||
(bucket_name, objects_scanned, objects_deleted, bytes_freed,
|
||||
db_rows_deleted, completed_at, status, error_message)
|
||||
VALUES ($1, $2, $3, $4, $5, NOW(), $6, $7)""",
|
||||
bucket_name,
|
||||
result.objects_scanned,
|
||||
result.objects_deleted,
|
||||
result.bytes_freed,
|
||||
result.db_rows_deleted,
|
||||
status,
|
||||
error_message,
|
||||
)
|
||||
|
||||
|
||||
async def run_retention_cleanup(
|
||||
minio_client: Minio,
|
||||
pool: asyncpg.Pool,
|
||||
config: RetentionConfig,
|
||||
now: datetime | None = None,
|
||||
) -> list[CleanupResult]:
|
||||
"""Run the full retention cleanup cycle.
|
||||
|
||||
1. Resolve policies from config defaults + DB overrides
|
||||
2. Clean up expired MinIO objects per bucket
|
||||
3. Clean up expired PostgreSQL metadata
|
||||
4. Record each run for observability
|
||||
|
||||
Returns a list of CleanupResult for each bucket processed.
|
||||
"""
|
||||
# Resolve policies
|
||||
config_policies = resolve_policies(config)
|
||||
try:
|
||||
db_policies = await load_db_policies(pool)
|
||||
except Exception:
|
||||
logger.warning("Could not load DB retention policies, using config defaults")
|
||||
db_policies = {}
|
||||
|
||||
policies = merge_policies(config_policies, db_policies)
|
||||
results: list[CleanupResult] = []
|
||||
|
||||
# Clean up MinIO objects per bucket
|
||||
for policy in policies:
|
||||
try:
|
||||
result = cleanup_bucket(
|
||||
minio_client, policy,
|
||||
batch_size=config.batch_size, now=now,
|
||||
)
|
||||
results.append(result)
|
||||
await record_retention_run(pool, policy.bucket_name, result)
|
||||
except Exception:
|
||||
logger.exception("Retention cleanup failed for bucket %s", policy.bucket_name)
|
||||
empty = CleanupResult(bucket_name=policy.bucket_name)
|
||||
await record_retention_run(
|
||||
pool, policy.bucket_name, empty,
|
||||
status="failed", error_message="See logs",
|
||||
)
|
||||
results.append(empty)
|
||||
|
||||
# Clean up expired DB records using the shortest raw retention period
|
||||
min_retention = min(p.retention_days for p in policies)
|
||||
try:
|
||||
db_deleted = await cleanup_expired_db_records(pool, min_retention, now=now)
|
||||
if db_deleted > 0:
|
||||
logger.info("Total DB rows cleaned: %d", db_deleted)
|
||||
except Exception:
|
||||
logger.exception("DB retention cleanup failed")
|
||||
|
||||
return results
|
||||
@@ -108,6 +108,41 @@ class DocumentIntelligence(BaseModel):
|
||||
|
||||
# --- Trend Summary ---
|
||||
|
||||
class MarketContext(BaseModel):
|
||||
"""Recent market data features for a symbol, used to enrich aggregation."""
|
||||
|
||||
ticker: str = ""
|
||||
price_change_pct: Optional[float] = None # % change over the window
|
||||
avg_volume: Optional[float] = None # average daily volume
|
||||
volume_change_pct: Optional[float] = None # volume vs prior period
|
||||
volatility: Optional[float] = None # intra-window price std dev
|
||||
latest_close: Optional[float] = None
|
||||
latest_bar_at: Optional[datetime] = None
|
||||
bars_available: int = 0
|
||||
|
||||
@property
|
||||
def has_data(self) -> bool:
|
||||
return self.bars_available > 0
|
||||
|
||||
|
||||
class DisagreementDetail(BaseModel):
|
||||
"""Represents an explicit disagreement between document signals.
|
||||
|
||||
Rather than collapsing contradictory signals into a single score,
|
||||
this captures the nature of the disagreement so downstream consumers
|
||||
can inspect *why* signals conflict.
|
||||
|
||||
Requirements: 6.4
|
||||
"""
|
||||
|
||||
dimension: str = "" # e.g. "sentiment", "catalyst", "impact_horizon"
|
||||
positive_doc_ids: List[str] = Field(default_factory=list)
|
||||
negative_doc_ids: List[str] = Field(default_factory=list)
|
||||
positive_weight: float = 0.0
|
||||
negative_weight: float = 0.0
|
||||
description: str = ""
|
||||
|
||||
|
||||
class TrendSummary(BaseModel):
|
||||
entity_type: str = "company"
|
||||
entity_id: str = ""
|
||||
@@ -120,6 +155,8 @@ class TrendSummary(BaseModel):
|
||||
dominant_catalysts: List[str] = Field(default_factory=list)
|
||||
material_risks: List[str] = Field(default_factory=list)
|
||||
contradiction_score: float = Field(ge=0, le=1, default=0.0)
|
||||
disagreement_details: List[DisagreementDetail] = Field(default_factory=list)
|
||||
market_context: Optional[MarketContext] = None
|
||||
generated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,352 @@
|
||||
"""Raw artifact upload to MinIO.
|
||||
|
||||
Provides a reusable storage layer for uploading raw artifacts (API payloads,
|
||||
HTML, normalized text, model outputs) to MinIO with consistent path conventions,
|
||||
bucket management, and content-type handling.
|
||||
|
||||
Bucket layout follows the design spec:
|
||||
- stonks-raw-market — raw market API payloads
|
||||
- stonks-raw-news — raw news API payloads and article HTML
|
||||
- stonks-raw-filings — raw filings and issuer event payloads
|
||||
- stonks-normalized — cleaned text and parser outputs
|
||||
- stonks-llm-prompts — prompts and schemas used
|
||||
- stonks-llm-results — raw model outputs and validation reports
|
||||
- stonks-lakehouse — partitioned analytical datasets and table metadata
|
||||
- stonks-audit — execution traces and exported reports
|
||||
|
||||
Object path pattern:
|
||||
/{stage}/{symbol}/{yyyy}/{mm}/{dd}/{document_id}/{artifact_type}.{ext}
|
||||
|
||||
Requirements: 3.1, 3.2, 3.3, 9.1
|
||||
"""
|
||||
import io
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Mapping
|
||||
|
||||
from minio import Minio
|
||||
from minio.error import S3Error
|
||||
|
||||
logger = logging.getLogger("storage")
|
||||
|
||||
# All known buckets the platform uses
|
||||
ALL_BUCKETS = [
|
||||
"stonks-raw-market",
|
||||
"stonks-raw-news",
|
||||
"stonks-raw-filings",
|
||||
"stonks-normalized",
|
||||
"stonks-llm-prompts",
|
||||
"stonks-llm-results",
|
||||
"stonks-lakehouse",
|
||||
"stonks-audit",
|
||||
]
|
||||
|
||||
# Map source_type to the correct raw bucket
|
||||
SOURCE_BUCKET_MAP: dict[str, str] = {
|
||||
"market_api": "stonks-raw-market",
|
||||
"news_api": "stonks-raw-news",
|
||||
"filings_api": "stonks-raw-filings",
|
||||
"web_scrape": "stonks-raw-news",
|
||||
"broker": "stonks-raw-market",
|
||||
}
|
||||
|
||||
# Map artifact type to content type and file extension
|
||||
ARTIFACT_CONTENT_TYPES: dict[str, tuple[str, str]] = {
|
||||
"raw_json": ("application/json", "json"),
|
||||
"raw_html": ("text/html", "html"),
|
||||
"raw_text": ("text/plain", "txt"),
|
||||
"raw_payload": ("application/octet-stream", "bin"),
|
||||
}
|
||||
|
||||
|
||||
def bucket_for_source(source_type: str) -> str:
|
||||
"""Return the MinIO bucket name for a given source type."""
|
||||
return SOURCE_BUCKET_MAP.get(source_type, "stonks-raw-market")
|
||||
|
||||
|
||||
def build_artifact_path(
|
||||
source_type: str,
|
||||
ticker: str,
|
||||
document_id: str,
|
||||
artifact_name: str = "raw",
|
||||
ext: str = "json",
|
||||
timestamp: datetime | None = None,
|
||||
) -> str:
|
||||
"""Build a MinIO object path following the design convention.
|
||||
|
||||
Pattern: {source_type}/{ticker}/{yyyy}/{mm}/{dd}/{document_id}/{artifact_name}.{ext}
|
||||
"""
|
||||
ts = timestamp or datetime.now(timezone.utc)
|
||||
return (
|
||||
f"{source_type}/{ticker}/"
|
||||
f"{ts.year}/{ts.month:02d}/{ts.day:02d}/"
|
||||
f"{document_id}/{artifact_name}.{ext}"
|
||||
)
|
||||
|
||||
|
||||
def storage_ref(bucket: str, path: str) -> str:
|
||||
"""Build an s3:// URI for a stored artifact."""
|
||||
return f"s3://{bucket}/{path}"
|
||||
|
||||
|
||||
def ensure_buckets(client: Minio, buckets: list[str] | None = None) -> list[str]:
|
||||
"""Create any missing buckets. Returns list of buckets that were created."""
|
||||
target_buckets = buckets or ALL_BUCKETS
|
||||
created: list[str] = []
|
||||
for bucket in target_buckets:
|
||||
try:
|
||||
if not client.bucket_exists(bucket):
|
||||
client.make_bucket(bucket)
|
||||
created.append(bucket)
|
||||
logger.info("Created bucket: %s", bucket)
|
||||
except S3Error as e:
|
||||
logger.error("Failed to ensure bucket %s: %s", bucket, e)
|
||||
raise
|
||||
return created
|
||||
|
||||
|
||||
def upload_artifact(
|
||||
client: Minio,
|
||||
bucket: str,
|
||||
path: str,
|
||||
data: bytes,
|
||||
content_type: str = "application/json",
|
||||
metadata: Mapping[str, str] | None = None,
|
||||
) -> str:
|
||||
"""Upload raw bytes to MinIO and return the s3:// storage reference.
|
||||
|
||||
Args:
|
||||
client: MinIO client instance.
|
||||
bucket: Target bucket name.
|
||||
path: Object path within the bucket.
|
||||
data: Raw bytes to upload.
|
||||
content_type: MIME type for the object.
|
||||
metadata: Optional user metadata to attach to the object.
|
||||
|
||||
Returns:
|
||||
s3:// URI pointing to the uploaded object.
|
||||
"""
|
||||
_result = client.put_object(
|
||||
bucket,
|
||||
path,
|
||||
io.BytesIO(data),
|
||||
length=len(data),
|
||||
content_type=content_type,
|
||||
metadata=metadata,
|
||||
)
|
||||
ref = storage_ref(bucket, path)
|
||||
logger.debug("Uploaded %d bytes to %s", len(data), ref)
|
||||
return ref
|
||||
|
||||
|
||||
def upload_raw_artifact(
|
||||
client: Minio,
|
||||
source_type: str,
|
||||
ticker: str,
|
||||
document_id: str,
|
||||
data: bytes,
|
||||
artifact_type: str = "raw_json",
|
||||
timestamp: datetime | None = None,
|
||||
metadata: Mapping[str, str] | None = None,
|
||||
) -> str:
|
||||
"""Upload a raw artifact using standard conventions for bucket, path, and content type.
|
||||
|
||||
This is the primary entry point for ingestion workers to store raw payloads.
|
||||
|
||||
Args:
|
||||
client: MinIO client instance.
|
||||
source_type: One of market_api, news_api, filings_api, web_scrape, broker.
|
||||
ticker: Company ticker symbol.
|
||||
document_id: Unique document or run identifier.
|
||||
data: Raw bytes to upload.
|
||||
artifact_type: One of raw_json, raw_html, raw_text, raw_payload.
|
||||
timestamp: Override timestamp for path generation (defaults to now UTC).
|
||||
metadata: Optional user metadata dict.
|
||||
|
||||
Returns:
|
||||
s3:// URI pointing to the uploaded object.
|
||||
"""
|
||||
bucket = bucket_for_source(source_type)
|
||||
ct, ext = ARTIFACT_CONTENT_TYPES.get(artifact_type, ("application/octet-stream", "bin"))
|
||||
path = build_artifact_path(
|
||||
source_type=source_type,
|
||||
ticker=ticker,
|
||||
document_id=document_id,
|
||||
artifact_name="raw",
|
||||
ext=ext,
|
||||
timestamp=timestamp,
|
||||
)
|
||||
return upload_artifact(client, bucket, path, data, content_type=ct, metadata=metadata)
|
||||
|
||||
|
||||
def upload_html_artifact(
|
||||
client: Minio,
|
||||
ticker: str,
|
||||
document_id: str,
|
||||
html_bytes: bytes,
|
||||
timestamp: datetime | None = None,
|
||||
metadata: Mapping[str, str] | None = None,
|
||||
) -> str:
|
||||
"""Upload raw HTML for a scraped web page.
|
||||
|
||||
Stores in stonks-raw-news under the web_scrape source path.
|
||||
"""
|
||||
bucket = bucket_for_source("web_scrape")
|
||||
path = build_artifact_path(
|
||||
source_type="web_scrape",
|
||||
ticker=ticker,
|
||||
document_id=document_id,
|
||||
artifact_name="raw",
|
||||
ext="html",
|
||||
timestamp=timestamp,
|
||||
)
|
||||
return upload_artifact(client, bucket, path, html_bytes, content_type="text/html", metadata=metadata)
|
||||
|
||||
|
||||
def upload_normalized_text(
|
||||
client: Minio,
|
||||
ticker: str,
|
||||
document_id: str,
|
||||
text_bytes: bytes,
|
||||
timestamp: datetime | None = None,
|
||||
metadata: Mapping[str, str] | None = None,
|
||||
) -> str:
|
||||
"""Upload normalized (parsed) text to the stonks-normalized bucket.
|
||||
|
||||
Stores under parsed/{ticker}/{yyyy}/{mm}/{dd}/{document_id}/normalized.txt
|
||||
"""
|
||||
ts = timestamp or datetime.now(timezone.utc)
|
||||
path = (
|
||||
f"parsed/{ticker}/{ts.year}/{ts.month:02d}/{ts.day:02d}/"
|
||||
f"{document_id}/normalized.txt"
|
||||
)
|
||||
return upload_artifact(
|
||||
client, "stonks-normalized", path, text_bytes,
|
||||
content_type="text/plain", metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
def upload_parser_output(
|
||||
client: Minio,
|
||||
ticker: str,
|
||||
document_id: str,
|
||||
output_bytes: bytes,
|
||||
timestamp: datetime | None = None,
|
||||
metadata: Mapping[str, str] | None = None,
|
||||
) -> str:
|
||||
"""Upload structured parser output JSON to the stonks-normalized bucket.
|
||||
|
||||
Stores under parsed/{ticker}/{yyyy}/{mm}/{dd}/{document_id}/parser_output.json
|
||||
"""
|
||||
ts = timestamp or datetime.now(timezone.utc)
|
||||
path = (
|
||||
f"parsed/{ticker}/{ts.year}/{ts.month:02d}/{ts.day:02d}/"
|
||||
f"{document_id}/parser_output.json"
|
||||
)
|
||||
return upload_artifact(
|
||||
client, "stonks-normalized", path, output_bytes,
|
||||
content_type="application/json", metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
def upload_extraction_prompt(
|
||||
client: Minio,
|
||||
ticker: str,
|
||||
document_id: str,
|
||||
prompt_data: bytes,
|
||||
timestamp: datetime | None = None,
|
||||
metadata: Mapping[str, str] | None = None,
|
||||
) -> str:
|
||||
"""Upload the extraction prompt and schema to stonks-llm-prompts.
|
||||
|
||||
Stores under extraction/{ticker}/{yyyy}/{mm}/{dd}/{document_id}/prompt.json
|
||||
"""
|
||||
ts = timestamp or datetime.now(timezone.utc)
|
||||
path = (
|
||||
f"extraction/{ticker}/{ts.year}/{ts.month:02d}/{ts.day:02d}/"
|
||||
f"{document_id}/prompt.json"
|
||||
)
|
||||
return upload_artifact(
|
||||
client, "stonks-llm-prompts", path, prompt_data,
|
||||
content_type="application/json", metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
def upload_extraction_raw_output(
|
||||
client: Minio,
|
||||
ticker: str,
|
||||
document_id: str,
|
||||
output_data: bytes,
|
||||
attempt_index: int = 0,
|
||||
timestamp: datetime | None = None,
|
||||
metadata: Mapping[str, str] | None = None,
|
||||
) -> str:
|
||||
"""Upload a raw model output to stonks-llm-results.
|
||||
|
||||
Stores under extraction/{ticker}/{yyyy}/{mm}/{dd}/{document_id}/raw_output_{attempt}.json
|
||||
"""
|
||||
ts = timestamp or datetime.now(timezone.utc)
|
||||
path = (
|
||||
f"extraction/{ticker}/{ts.year}/{ts.month:02d}/{ts.day:02d}/"
|
||||
f"{document_id}/raw_output_{attempt_index}.json"
|
||||
)
|
||||
return upload_artifact(
|
||||
client, "stonks-llm-results", path, output_data,
|
||||
content_type="application/json", metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
def upload_extraction_validation(
|
||||
client: Minio,
|
||||
ticker: str,
|
||||
document_id: str,
|
||||
validation_data: bytes,
|
||||
timestamp: datetime | None = None,
|
||||
metadata: Mapping[str, str] | None = None,
|
||||
) -> str:
|
||||
"""Upload a validation report to stonks-llm-results.
|
||||
|
||||
Stores under extraction/{ticker}/{yyyy}/{mm}/{dd}/{document_id}/validation.json
|
||||
"""
|
||||
ts = timestamp or datetime.now(timezone.utc)
|
||||
path = (
|
||||
f"extraction/{ticker}/{ts.year}/{ts.month:02d}/{ts.day:02d}/"
|
||||
f"{document_id}/validation.json"
|
||||
)
|
||||
return upload_artifact(
|
||||
client, "stonks-llm-results", path, validation_data,
|
||||
content_type="application/json", metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
def upload_extraction_intelligence(
|
||||
client: Minio,
|
||||
ticker: str,
|
||||
document_id: str,
|
||||
intelligence_data: bytes,
|
||||
timestamp: datetime | None = None,
|
||||
metadata: Mapping[str, str] | None = None,
|
||||
) -> str:
|
||||
"""Upload the final intelligence object to stonks-llm-results.
|
||||
|
||||
Stores under extraction/{ticker}/{yyyy}/{mm}/{dd}/{document_id}/intelligence.json
|
||||
"""
|
||||
ts = timestamp or datetime.now(timezone.utc)
|
||||
path = (
|
||||
f"extraction/{ticker}/{ts.year}/{ts.month:02d}/{ts.day:02d}/"
|
||||
f"{document_id}/intelligence.json"
|
||||
)
|
||||
return upload_artifact(
|
||||
client, "stonks-llm-results", path, intelligence_data,
|
||||
content_type="application/json", metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
def download_artifact(client: Minio, bucket: str, path: str) -> bytes:
|
||||
"""Download an artifact from MinIO and return its bytes."""
|
||||
response = client.get_object(bucket, path)
|
||||
try:
|
||||
return response.read()
|
||||
finally:
|
||||
response.close()
|
||||
response.release_conn()
|
||||
@@ -10,6 +10,7 @@ from pydantic import BaseModel, field_validator
|
||||
|
||||
from services.shared.config import load_config
|
||||
from services.shared.db import get_pg_pool
|
||||
from services.shared.logging import setup_logging
|
||||
|
||||
config = load_config()
|
||||
pool: Optional[asyncpg.Pool] = None
|
||||
@@ -18,6 +19,7 @@ pool: Optional[asyncpg.Pool] = None
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global pool
|
||||
setup_logging("symbol_registry", level=config.log_level, json_output=config.json_logs)
|
||||
pool = await get_pg_pool(config)
|
||||
yield
|
||||
await pool.close()
|
||||
|
||||
@@ -13,8 +13,8 @@ import asyncpg
|
||||
|
||||
from services.shared.config import load_config
|
||||
from services.shared.db import get_pg_pool
|
||||
from services.shared.logging import setup_logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("seed")
|
||||
|
||||
# --- Seed Companies ---
|
||||
@@ -173,6 +173,7 @@ async def seed(pool: asyncpg.Pool) -> None:
|
||||
|
||||
async def main() -> None:
|
||||
config = load_config()
|
||||
setup_logging("seed", level=config.log_level, json_output=config.json_logs)
|
||||
pool = await get_pg_pool(config)
|
||||
try:
|
||||
await seed(pool)
|
||||
|
||||
Reference in New Issue
Block a user