phase 14-15: docker build validation and helm deployment

This commit is contained in:
Celes Renata
2026-04-11 11:59:45 -07:00
parent 7394d241c9
commit ce10afa034
179 changed files with 32559 additions and 576 deletions
+44
View File
@@ -1 +1,45 @@
# Ingestion Adapters
from .base import AdapterResult, BaseAdapter
from .resilient import ResilientAdapter, RetryConfig, RetryStats, compute_delay
from .broker_adapter import (
AccountInfo,
AlpacaBrokerAdapter,
BrokerDataAdapter,
OrderEventType,
OrderRequest,
OrderResponse,
OrderSide,
OrderStatus,
OrderType,
PositionInfo,
TradingMode,
)
from .filings_adapter import FilingsDataAdapter, SECEdgarAdapter
from .market_adapter import MarketDataAdapter, PolygonMarketAdapter
from .news_adapter import NewsDataAdapter, PolygonNewsAdapter
__all__ = [
"AccountInfo",
"AdapterResult",
"AlpacaBrokerAdapter",
"BaseAdapter",
"BrokerDataAdapter",
"FilingsDataAdapter",
"MarketDataAdapter",
"NewsDataAdapter",
"OrderEventType",
"OrderRequest",
"OrderResponse",
"OrderSide",
"OrderStatus",
"OrderType",
"PolygonMarketAdapter",
"PolygonNewsAdapter",
"PositionInfo",
"ResilientAdapter",
"RetryConfig",
"RetryStats",
"SECEdgarAdapter",
"TradingMode",
"compute_delay",
]
+63 -8
View File
@@ -1,29 +1,84 @@
"""Base adapter interface for all external API integrations."""
"""Base adapter interface for all external API integrations.
All ingestion adapters follow the same contract:
1. Fetch external payloads for a given ticker/source config.
2. Return a structured result with raw bytes, parsed items, and metadata.
3. The ingestion worker handles MinIO upload, PostgreSQL metadata, and downstream job emission.
Requirements: 2.1, 2.2, 2.3, 2.4, 2.5, 3.1, 3.2, 3.3, 3.4
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Any
# Container returned by an adapter fetch: the parsed items together with the
# raw response bytes, a content hash, the fetch time and optional error/HTTP
# metadata.
@dataclass
class AdapterResult:
"""Result of a single adapter fetch operation."""
source_type: str
ticker: str
# NOTE(review): items (and error, below) are each declared twice in a row
# with old-style and new-style annotations; this looks like the pre- and
# post-change lines of a diff view shown together - confirm which to keep.
items: List[Dict[str, Any]]
items: list[dict[str, Any]]
raw_payload: bytes
content_hash: str
fetched_at: datetime
error: Optional[str] = None
error: str | None = None
# HTTP metadata for observability
http_status: int | None = None
response_time_ms: float | None = None
# Additional metadata the adapter wants to pass downstream
metadata: dict[str, Any] = field(default_factory=dict)
# ok requires BOTH no error and at least one item, so an error-free fetch
# that returned an empty items list is reported as not ok.
@property
def ok(self) -> bool:
"""True if the fetch succeeded without error."""
return self.error is None and len(self.items) > 0
# Number of parsed items carried by this result.
@property
def item_count(self) -> int:
return len(self.items)
# Abstract base for ingestion adapters: subclasses supply fetch() and
# source_type(); bucket_name() and artifact_path() derive storage names from
# source_type().
# NOTE(review): the class docstring and the fetch() signature each appear
# twice below (short form, then long form); this looks like old and new diff
# lines shown together - confirm which to keep.
class BaseAdapter(ABC):
"""Interface for all ingestion adapters."""
"""Interface for all ingestion adapters.
Subclasses implement fetch() for their specific API and source_type()
to identify the adapter class. The ingestion worker orchestrates
persistence and downstream job emission.
"""
@abstractmethod
async def fetch(self, ticker: str, config: Dict[str, Any]) -> AdapterResult:
"""Fetch data for a given ticker using source config."""
async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult:
"""Fetch data for a given ticker using source config.
Args:
ticker: The company ticker symbol.
config: Source-specific configuration from the sources table.
Returns:
AdapterResult with raw payload, parsed items, and metadata.
"""
...
@abstractmethod
def source_type(self) -> str:
"""Return the source type identifier for this adapter (e.g. 'market_api')."""
...
def bucket_name(self) -> str:
"""Return the MinIO bucket name for raw artifact storage.
Override in subclasses if the bucket differs from the default pattern.
"""
# Every "_api" occurrence is removed and remaining underscores become
# hyphens, e.g. source type "market_api" -> "stonks-raw-market".
return f"stonks-raw-{self.source_type().replace('_api', '').replace('_', '-')}"
def artifact_path(self, ticker: str, document_id: str, now: datetime) -> str:
"""Build the MinIO object path for a raw artifact.
Pattern: /{source_type}/{ticker}/{yyyy}/{mm}/{dd}/{document_id}/raw.json
"""
# NOTE: the returned path starts with the source type and has no leading
# "/", unlike the pattern shown in the docstring above.
return (
f"{self.source_type()}/{ticker}/"
f"{now.strftime('%Y/%m/%d')}/{document_id}/raw.json"
)
+558 -61
View File
@@ -1,9 +1,19 @@
"""Broker API adapter - paper/live trading, orders, positions, balances."""
"""Broker API adapter interface for paper trading and order events.
The BrokerDataAdapter is the abstract interface for all broker integrations.
AlpacaBrokerAdapter is the first concrete implementation, targeting the
Alpaca Markets REST API for paper and live trading.
Requirements: 2.4, 2.5, 8.1, 8.3, 8.5
"""
import hashlib
import logging
import time
import uuid
from datetime import datetime
from typing import Any, Dict, Optional
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from enum import Enum
from typing import Any
import httpx
@@ -12,97 +22,584 @@ from .base import AdapterResult, BaseAdapter
logger = logging.getLogger("broker_adapter")
class BrokerAdapter(BaseAdapter):
"""Broker API adapter supporting paper and live modes."""
# --- Broker-specific enums ---
def __init__(self, api_key: str = "", api_secret: str = "", base_url: str = "", mode: str = "paper"):
self.api_key = api_key
self.api_secret = api_secret
self.base_url = base_url
self.mode = mode # paper | live
class OrderSide(str, Enum):
    """Direction of an order.

    Members are ``str`` subclasses, so each compares equal to its lowercase
    string value.
    """

    BUY = "buy"
    SELL = "sell"
class OrderType(str, Enum):
    """Supported order types.

    Members are ``str`` subclasses, so each compares equal to its lowercase
    string value.
    """

    MARKET = "market"
    LIMIT = "limit"
    STOP = "stop"
    STOP_LIMIT = "stop_limit"
class OrderStatus(str, Enum):
    """Internal order lifecycle states.

    Members are ``str`` subclasses, so each compares equal to its lowercase
    string value.
    """

    PENDING = "pending"
    SUBMITTED = "submitted"
    ACCEPTED = "accepted"
    PARTIALLY_FILLED = "partially_filled"
    FILLED = "filled"
    CANCELLED = "cancelled"
    REJECTED = "rejected"
    EXPIRED = "expired"
class TradingMode(str, Enum):
    """Whether an adapter trades against a paper or a live account.

    Members are ``str`` subclasses, so each compares equal to its lowercase
    string value.
    """

    PAPER = "paper"
    LIVE = "live"
class OrderEventType(str, Enum):
    """Kinds of order events.

    Members are ``str`` subclasses, so each compares equal to its lowercase
    string value.
    """

    SUBMITTED = "submitted"
    ACCEPTED = "accepted"
    REJECTED = "rejected"
    FILL = "fill"
    PARTIAL_FILL = "partial_fill"
    CANCELLED = "cancelled"
    EXPIRED = "expired"
# --- Data structures ---
class OrderRequest:
    """An order that is about to be handed to a broker.

    If ``idempotency_key`` is omitted or empty, a random UUID4 string is
    generated, so every request carries a key.
    """

    def __init__(
        self,
        ticker: str,
        side: OrderSide,
        quantity: float,
        order_type: OrderType = OrderType.MARKET,
        limit_price: float | None = None,
        stop_price: float | None = None,
        time_in_force: str = "day",
        idempotency_key: str | None = None,
    ) -> None:
        self.ticker = ticker
        self.side = side
        self.quantity = quantity
        self.order_type = order_type
        self.limit_price = limit_price
        self.stop_price = stop_price
        self.time_in_force = time_in_force
        self.idempotency_key = idempotency_key if idempotency_key else str(uuid.uuid4())

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict form of the request for audit/persistence.

        ``limit_price`` and ``stop_price`` are included only when set.
        """
        payload: dict[str, Any] = {
            "ticker": self.ticker,
            "side": self.side.value,
            "quantity": self.quantity,
            "order_type": self.order_type.value,
            "time_in_force": self.time_in_force,
            "idempotency_key": self.idempotency_key,
        }
        # Optional prices are appended in a fixed order, skipping unset ones.
        for optional_field in ("limit_price", "stop_price"):
            price = getattr(self, optional_field)
            if price is not None:
                payload[optional_field] = price
        return payload
class OrderResponse:
    """A broker's answer to an order submission, cancel or status query.

    ``submitted_at`` defaults to the current UTC time and ``raw_response``
    to an empty dict when not supplied.
    """

    def __init__(
        self,
        broker_order_id: str,
        status: OrderStatus,
        ticker: str,
        side: OrderSide,
        quantity: float,
        filled_quantity: float = 0.0,
        filled_avg_price: float | None = None,
        submitted_at: datetime | None = None,
        raw_response: dict[str, Any] | None = None,
        error: str | None = None,
    ) -> None:
        self.broker_order_id = broker_order_id
        self.status = status
        self.ticker = ticker
        self.side = side
        self.quantity = quantity
        self.filled_quantity = filled_quantity
        self.filled_avg_price = filled_avg_price
        self.submitted_at = submitted_at if submitted_at else datetime.now(timezone.utc)
        self.raw_response = raw_response if raw_response else {}
        self.error = error

    @property
    def ok(self) -> bool:
        """True when there is no error and the status is not rejected, cancelled or expired."""
        if self.error is not None:
            return False
        failed_states = (
            OrderStatus.REJECTED,
            OrderStatus.CANCELLED,
            OrderStatus.EXPIRED,
        )
        return self.status not in failed_states

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict form of the response (``raw_response`` is not included)."""
        summary: dict[str, Any] = {}
        summary["broker_order_id"] = self.broker_order_id
        summary["status"] = self.status.value
        summary["ticker"] = self.ticker
        summary["side"] = self.side.value
        summary["quantity"] = self.quantity
        summary["filled_quantity"] = self.filled_quantity
        summary["filled_avg_price"] = self.filled_avg_price
        summary["submitted_at"] = self.submitted_at.isoformat()
        summary["error"] = self.error
        return summary
class PositionInfo:
    """A single open position as reported by the broker."""

    def __init__(
        self,
        ticker: str,
        quantity: float,
        avg_entry_price: float,
        current_price: float,
        unrealized_pnl: float,
        market_value: float,
        side: str = "long",
    ) -> None:
        self.ticker = ticker
        self.quantity = quantity
        self.avg_entry_price = avg_entry_price
        self.current_price = current_price
        self.unrealized_pnl = unrealized_pnl
        self.market_value = market_value
        self.side = side

    def to_dict(self) -> dict[str, Any]:
        """Return the position as a plain dict, one key per attribute."""
        exported = (
            "ticker",
            "quantity",
            "avg_entry_price",
            "current_price",
            "unrealized_pnl",
            "market_value",
            "side",
        )
        return {attr: getattr(self, attr) for attr in exported}
class AccountInfo:
    """Summary of a broker account: balances, currency and trading mode."""

    def __init__(
        self,
        account_id: str,
        buying_power: float,
        cash: float,
        portfolio_value: float,
        currency: str = "USD",
        mode: TradingMode = TradingMode.PAPER,
    ) -> None:
        self.account_id = account_id
        self.buying_power = buying_power
        self.cash = cash
        self.portfolio_value = portfolio_value
        self.currency = currency
        self.mode = mode

    def to_dict(self) -> dict[str, Any]:
        """Return the account as a plain dict; ``mode`` is emitted as its string value."""
        summary: dict[str, Any] = {
            attr: getattr(self, attr)
            for attr in ("account_id", "buying_power", "cash", "portfolio_value", "currency")
        }
        summary["mode"] = self.mode.value
        return summary
# --- Abstract interface ---
# Abstract broker interface layered on BaseAdapter: it stores the trading
# mode, fixes the source type to "broker", and declares the five broker
# operations that concrete adapters must implement.
class BrokerDataAdapter(BaseAdapter, ABC):
"""Abstract interface for broker API integrations.
Extends BaseAdapter with broker-specific operations:
- submit_order: place an order with idempotency key
- cancel_order: cancel an existing order
- get_order_status: check order state
- get_positions: list current positions
- get_account: retrieve account summary
All concrete adapters must enforce:
- Idempotent order submission via idempotency_key (Req 8.5)
- Paper/live mode separation (Req 8.1)
- Fail-closed on broker unavailability (Req 8.5)
"""
# The mode is set once here and exposed read-only through the mode property.
def __init__(self, mode: TradingMode = TradingMode.PAPER) -> None:
self._mode = mode
@property
def mode(self) -> TradingMode:
return self._mode
# Every broker adapter reports the same source type string.
def source_type(self) -> str:
return "broker"
# NOTE(review): the def line below has no body in this view; it appears to be
# a leftover line from the pre-change file rather than part of this class -
# confirm before relying on it.
def _headers(self) -> Dict[str, str]:
@abstractmethod
async def submit_order(self, order: OrderRequest) -> OrderResponse:
"""Submit an order to the broker.
Must use order.idempotency_key to prevent duplicate submissions.
Must fail closed if the broker is unavailable or returns ambiguous state.
"""
...
@abstractmethod
async def cancel_order(self, broker_order_id: str) -> OrderResponse:
"""Cancel an existing order by broker order ID."""
...
@abstractmethod
async def get_order_status(self, broker_order_id: str) -> OrderResponse:
"""Get the current status of an order."""
...
@abstractmethod
async def get_positions(self) -> list[PositionInfo]:
"""Get all current positions."""
...
@abstractmethod
async def get_account(self) -> AccountInfo:
"""Get account summary (balance, buying power, etc.)."""
...
# --- Concrete Alpaca implementation ---
class AlpacaBrokerAdapter(BrokerDataAdapter):
"""Concrete broker adapter for the Alpaca Markets REST API.
Supports:
- Paper trading via paper-api.alpaca.markets
- Live trading via api.alpaca.markets
- Order submission, cancellation, and status
- Position and account queries
Config options for fetch():
endpoint: One of "positions", "orders", "account" (default "positions")
"""
PAPER_BASE_URL: str = "https://paper-api.alpaca.markets"
LIVE_BASE_URL: str = "https://api.alpaca.markets"
def __init__(
    self,
    api_key: str,
    api_secret: str,
    mode: TradingMode = TradingMode.PAPER,
    base_url: str | None = None,
) -> None:
    """Store credentials and resolve the API base URL.

    Args:
        api_key: Alpaca API key id.
        api_secret: Alpaca API secret.
        mode: Trading mode; selects the default base URL.
        base_url: Optional override. When given (non-empty) it wins over
            the mode-based default and trailing slashes are stripped.
    """
    super().__init__(mode=mode)
    self.api_key = api_key
    self.api_secret = api_secret

    if base_url:
        resolved_url = base_url.rstrip("/")
    else:
        # No override: live mode uses the live host, anything else the paper host.
        resolved_url = self.LIVE_BASE_URL if mode == TradingMode.LIVE else self.PAPER_BASE_URL
    self.base_url = resolved_url
def _headers(self) -> dict[str, str]:
    """Build the HTTP headers sent with every request.

    The API key is sent both as a bearer token and as the key-id header;
    the secret is sent as the secret-key header.
    """
    key_id = self.api_key
    request_headers: dict[str, str] = {}
    request_headers["Authorization"] = f"Bearer {key_id}"
    request_headers["APCA-API-KEY-ID"] = key_id
    request_headers["APCA-API-SECRET-KEY"] = self.api_secret
    request_headers["Content-Type"] = "application/json"
    return request_headers
# Ingestion entry point: performs one GET against the endpoint selected by
# config["endpoint"] (default "positions") and wraps the response in an
# AdapterResult. Exceptions raised inside the try section are logged and
# returned as error results rather than propagated.
# NOTE(review): several statements below appear twice in slightly different
# forms (two def lines, two client.get calls, two items=/fetched_at= keywords,
# two bodies for the generic except). This looks like pre- and post-change
# diff lines shown together - confirm which lines belong to the current file.
async def fetch(self, ticker: str, config: Dict[str, Any]) -> AdapterResult:
"""Fetch positions and recent orders for a ticker."""
async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult:
"""Fetch positions or recent orders for a ticker from Alpaca.
This satisfies the BaseAdapter contract for the ingestion pipeline.
The broker adapter uses fetch() to pull position/order snapshots
that get persisted as raw artifacts.
"""
endpoint = config.get("endpoint", "positions")
url = self._build_fetch_url(ticker, endpoint)
async with httpx.AsyncClient(timeout=30) as client:
# Monotonic start time used for response_time_ms on success and failure.
t0 = time.monotonic()
try:
resp = await client.get(
f"{self.base_url}/v2/positions/{ticker}",
headers=self._headers(),
)
resp = await client.get(url, headers=self._headers())
# Latency is captured as soon as the response arrives, before the status
# check and before the body is parsed.
elapsed_ms = (time.monotonic() - t0) * 1000
resp.raise_for_status()
raw = resp.content
data = resp.json() if resp.status_code == 200 else {}
data = resp.json()
# The hash is taken over the raw response bytes, not the parsed JSON.
content_hash = hashlib.sha256(raw).hexdigest()
# A dict body becomes a one-item list, a list body is used as-is, and any
# other JSON value yields no items.
items = [data] if isinstance(data, dict) else data if isinstance(data, list) else []
return AdapterResult(
source_type="broker",
ticker=ticker,
items=[data] if data else [],
items=items,
raw_payload=raw,
content_hash=content_hash,
fetched_at=datetime.utcnow(),
fetched_at=datetime.now(timezone.utc),
http_status=resp.status_code,
response_time_ms=round(elapsed_ms, 1),
metadata={
"provider": "alpaca",
"mode": self._mode.value,
"endpoint": endpoint,
},
)
# Non-2xx responses: keep the status code and body bytes on the error result.
except httpx.HTTPStatusError as e:
elapsed_ms = (time.monotonic() - t0) * 1000
logger.error("Alpaca HTTP error for %s: %s", ticker, e)
return self._error_result(
ticker, str(e), elapsed_ms,
http_status=e.response.status_code if e.response else None,
raw=e.response.content if e.response else b"",
)
# Anything else raised inside the try section (network errors, invalid JSON, ...).
except Exception as e:
logger.error(f"Broker fetch failed for {ticker}: {e}")
return AdapterResult(
source_type="broker",
ticker=ticker,
items=[],
raw_payload=b"",
content_hash="",
fetched_at=datetime.utcnow(),
error=str(e),
)
elapsed_ms = (time.monotonic() - t0) * 1000
logger.error("Alpaca fetch failed for %s: %s", ticker, e)
return self._error_result(ticker, str(e), elapsed_ms)
async def submit_order(
self,
ticker: str,
side: str,
qty: float,
order_type: str = "market",
limit_price: Optional[float] = None,
idempotency_key: Optional[str] = None,
) -> Dict[str, Any]:
"""Submit an order to the broker. Returns broker response."""
if self.mode == "live":
logger.warning("LIVE order submission")
def _build_fetch_url(self, ticker: str, endpoint: str) -> str:
    """Build the URL for a fetch operation.

    Args:
        ticker: Symbol to query. It is URL-encoded before being placed in
            the path or query string, so symbols containing "/", "&" or
            spaces cannot change the shape of the URL.
        endpoint: "orders" or "account"; any other value selects the
            per-ticker positions endpoint.

    Returns:
        The absolute request URL under ``self.base_url``.
    """
    # Local import keeps this block self-contained; both are standard library.
    from urllib.parse import quote, urlencode

    if endpoint == "orders":
        query = urlencode({"symbols": ticker, "status": "all", "limit": 50})
        return f"{self.base_url}/v2/orders?{query}"
    if endpoint == "account":
        return f"{self.base_url}/v2/account"
    # Default: positions for ticker. safe="" also encodes "/", which would
    # otherwise be read as an extra path segment.
    return f"{self.base_url}/v2/positions/{quote(ticker, safe='')}"
idem_key = idempotency_key or str(uuid.uuid4())
payload = {
"symbol": ticker,
"qty": str(qty),
"side": side,
"type": order_type,
"time_in_force": "day",
async def submit_order(self, order: OrderRequest) -> OrderResponse:
    """Submit an order to Alpaca, keyed by the request's idempotency key.

    The idempotency key is sent as the order's ``client_order_id`` so the
    broker itself can refuse a second submission with the same key; it is
    also still sent as an ``Idempotency-Key`` header, as before.

    Fails closed: an HTTP error status or any other exception is logged and
    returned as an OrderResponse with status REJECTED and an error message,
    rather than raised.

    Args:
        order: The order to place.

    Returns:
        The parsed broker response on success, otherwise a REJECTED
        OrderResponse with an empty broker_order_id.
    """
    if self._mode == TradingMode.LIVE:
        logger.warning("LIVE order submission: %s %s %s", order.side.value, order.quantity, order.ticker)

    payload: dict[str, Any] = {
        "symbol": order.ticker,
        "qty": str(order.quantity),
        "side": order.side.value,
        "type": order.order_type.value,
        "time_in_force": order.time_in_force,
        # Broker-side duplicate protection: the key travels in the order body.
        "client_order_id": order.idempotency_key,
    }
    # Prices are only attached for the order types that use them.
    if order.limit_price is not None and order.order_type in (OrderType.LIMIT, OrderType.STOP_LIMIT):
        payload["limit_price"] = str(order.limit_price)
    if order.stop_price is not None and order.order_type in (OrderType.STOP, OrderType.STOP_LIMIT):
        payload["stop_price"] = str(order.stop_price)

    headers = {**self._headers(), "Idempotency-Key": order.idempotency_key}

    async with httpx.AsyncClient(timeout=30) as client:
        try:
            resp = await client.post(
                f"{self.base_url}/v2/orders",
                headers=headers,
                json=payload,
            )
            resp.raise_for_status()
            data = resp.json()
            return self._parse_order_response(data)
        except httpx.HTTPStatusError as e:
            error_body = e.response.text if e.response else "unknown"
            logger.error("Order rejected by Alpaca: %s", error_body)
            return OrderResponse(
                broker_order_id="",
                status=OrderStatus.REJECTED,
                ticker=order.ticker,
                side=order.side,
                quantity=order.quantity,
                error=f"HTTP {e.response.status_code}: {error_body}" if e.response else str(e),
                raw_response={"error": error_body},
            )
        except Exception as e:
            # Fail closed: treat any unexpected error as rejection
            logger.error("Order submission failed (fail-closed): %s", e)
            return OrderResponse(
                broker_order_id="",
                status=OrderStatus.REJECTED,
                ticker=order.ticker,
                side=order.side,
                quantity=order.quantity,
                error=f"fail-closed: {e}",
            )
async def get_account(self) -> Dict[str, Any]:
# Sends DELETE /v2/orders/{id}. Exceptions raised inside the try section are
# logged and returned as an OrderResponse instead of being propagated.
async def cancel_order(self, broker_order_id: str) -> OrderResponse:
"""Cancel an order on Alpaca."""
async with httpx.AsyncClient(timeout=30) as client:
# NOTE(review): the next two lines (GET /v2/account, return resp.json()) do
# not fit a cancel operation and would return before the DELETE below; they
# look like leftover pre-change lines from a diff view - confirm.
resp = await client.get(f"{self.base_url}/v2/account", headers=self._headers())
return resp.json()
try:
resp = await client.delete(
f"{self.base_url}/v2/orders/{broker_order_id}",
headers=self._headers(),
)
# On 204 the response is synthesized without reading a body: ticker, side
# and quantity are placeholder values, not data about the real order.
if resp.status_code == 204:
return OrderResponse(
broker_order_id=broker_order_id,
status=OrderStatus.CANCELLED,
ticker="",
side=OrderSide.BUY,
quantity=0,
)
# Any other status: raise on error codes, otherwise parse the JSON body as
# an order.
resp.raise_for_status()
data = resp.json()
return self._parse_order_response(data)
except Exception as e:
logger.error("Cancel failed for %s: %s", broker_order_id, e)
# A failed cancel is reported with status REJECTED plus the error text and
# placeholder ticker/side/quantity; the order's actual state at the broker
# is not re-queried here.
return OrderResponse(
broker_order_id=broker_order_id,
status=OrderStatus.REJECTED,
ticker="",
side=OrderSide.BUY,
quantity=0,
error=str(e),
)
async def get_order_status(self, broker_order_id: str) -> OrderResponse:
    """Look up a single order on Alpaca and return its parsed state.

    Any exception (HTTP error status, network failure, unparsable body) is
    logged and returned as an OrderResponse with status REJECTED, an empty
    ticker, BUY side, zero quantity and the error text.
    """
    async with httpx.AsyncClient(timeout=30) as http:
        try:
            reply = await http.get(
                f"{self.base_url}/v2/orders/{broker_order_id}",
                headers=self._headers(),
            )
            reply.raise_for_status()
            return self._parse_order_response(reply.json())
        except Exception as exc:
            logger.error("Get order status failed for %s: %s", broker_order_id, exc)
            failure = OrderResponse(
                broker_order_id=broker_order_id,
                status=OrderStatus.REJECTED,
                ticker="",
                side=OrderSide.BUY,
                quantity=0,
                error=str(exc),
            )
            return failure
async def get_positions(self) -> list[PositionInfo]:
    """Return all current positions from Alpaca.

    Returns an empty list when the response body is not a JSON list or
    when any exception occurs (the exception is logged). Non-dict entries
    in the list are skipped.
    """
    async with httpx.AsyncClient(timeout=30) as http:
        try:
            reply = await http.get(
                f"{self.base_url}/v2/positions",
                headers=self._headers(),
            )
            reply.raise_for_status()
            body = reply.json()
            if not isinstance(body, list):
                return []
            parsed: list[PositionInfo] = []
            for entry in body:
                if isinstance(entry, dict):
                    parsed.append(self._parse_position(entry))
            return parsed
        except Exception as exc:
            logger.error("Get positions failed: %s", exc)
            return []
async def get_account(self) -> AccountInfo:
    """Return the account summary from Alpaca.

    On any exception the error is logged and an AccountInfo with an empty
    account id and zero buying power, cash and portfolio value is returned,
    so callers cannot tell a failed lookup from an empty account by
    exception alone.
    """
    async with httpx.AsyncClient(timeout=30) as http:
        try:
            reply = await http.get(
                f"{self.base_url}/v2/account",
                headers=self._headers(),
            )
            reply.raise_for_status()
            body = reply.json()
            account_id = str(body.get("id", ""))
            buying_power = float(body.get("buying_power", 0))
            cash = float(body.get("cash", 0))
            portfolio_value = float(body.get("portfolio_value", 0))
            currency = str(body.get("currency", "USD"))
            return AccountInfo(
                account_id=account_id,
                buying_power=buying_power,
                cash=cash,
                portfolio_value=portfolio_value,
                currency=currency,
                mode=self._mode,
            )
        except Exception as exc:
            logger.error("Get account failed: %s", exc)
            return AccountInfo(
                account_id="",
                buying_power=0,
                cash=0,
                portfolio_value=0,
                mode=self._mode,
            )
def _parse_order_response(self, data: dict[str, Any]) -> OrderResponse:
    """Convert an Alpaca order JSON object into an OrderResponse.

    Unknown or missing status strings map to PENDING; any side other than
    "sell" maps to BUY; missing or falsy numeric fields become 0 (or None
    for the average fill price).
    """
    alpaca_to_internal: dict[str, OrderStatus] = {
        "new": OrderStatus.SUBMITTED,
        "accepted": OrderStatus.ACCEPTED,
        "partially_filled": OrderStatus.PARTIALLY_FILLED,
        "filled": OrderStatus.FILLED,
        "done_for_day": OrderStatus.FILLED,
        "canceled": OrderStatus.CANCELLED,
        "expired": OrderStatus.EXPIRED,
        "replaced": OrderStatus.SUBMITTED,
        "pending_new": OrderStatus.PENDING,
        "pending_cancel": OrderStatus.PENDING,
        "pending_replace": OrderStatus.PENDING,
        "rejected": OrderStatus.REJECTED,
    }
    reported_status = str(data.get("status", "pending"))
    mapped_status = alpaca_to_internal.get(reported_status, OrderStatus.PENDING)

    if str(data.get("side", "buy")) == "sell":
        order_side = OrderSide.SELL
    else:
        order_side = OrderSide.BUY

    fill_quantity = float(data.get("filled_qty", 0) or 0)
    avg_raw = data.get("filled_avg_price")
    avg_fill_price = float(avg_raw) if avg_raw else None

    return OrderResponse(
        broker_order_id=str(data.get("id", "")),
        status=mapped_status,
        ticker=str(data.get("symbol", "")),
        side=order_side,
        quantity=float(data.get("qty", 0) or 0),
        filled_quantity=fill_quantity,
        filled_avg_price=avg_fill_price,
        raw_response=data,
    )
def _parse_position(self, data: dict[str, Any]) -> PositionInfo:
    """Convert an Alpaca position JSON object into a PositionInfo.

    Missing or falsy numeric fields become 0.0; a missing symbol becomes ""
    and a missing side becomes "long".
    """

    def _number(key: str) -> float:
        # Treat absent, None and empty values as zero.
        return float(data.get(key, 0) or 0)

    symbol = str(data.get("symbol", ""))
    held_quantity = _number("qty")
    entry_price = _number("avg_entry_price")
    last_price = _number("current_price")
    open_pnl = _number("unrealized_pl")
    value = _number("market_value")
    return PositionInfo(
        ticker=symbol,
        quantity=held_quantity,
        avg_entry_price=entry_price,
        current_price=last_price,
        unrealized_pnl=open_pnl,
        market_value=value,
        side=str(data.get("side", "long")),
    )
def _error_result(
    self,
    ticker: str,
    error: str,
    elapsed_ms: float,
    http_status: int | None = None,
    raw: bytes = b"",
) -> AdapterResult:
    """Build the AdapterResult returned when a broker fetch fails.

    The result has no items and an empty content hash; the elapsed time is
    rounded to one decimal place.
    """
    failure_fields: dict[str, Any] = dict(
        source_type="broker",
        ticker=ticker,
        items=[],
        raw_payload=raw,
        content_hash="",
        fetched_at=datetime.now(timezone.utc),
        error=error,
        http_status=http_status,
        response_time_ms=round(elapsed_ms, 1),
        metadata={"provider": "alpaca", "mode": self._mode.value},
    )
    return AdapterResult(**failure_fields)
+832
View File
@@ -0,0 +1,832 @@
"""Broker adapter service - standalone worker for sandbox order execution.
Runs the Alpaca broker adapter in sandbox (paper) mode, processing order
requests from the broker queue, evaluating them through the risk engine,
submitting to Alpaca's paper trading API, and persisting the full audit trail.
Also periodically syncs positions and account state from Alpaca.
Implements idempotent order submission keys and duplicate prevention:
- Deterministic idempotency key generation from job attributes
- Redis-based fast-path duplicate detection before broker submission
- PostgreSQL UNIQUE constraint on idempotency_key as durable fallback
Requirements: 2.4, 8.1, 8.3, 8.5
Design: Section 4.9 - Broker Adapter
"""
from __future__ import annotations
import asyncio
import hashlib
import json
import logging
import uuid
from datetime import datetime, timezone
from typing import Any
import asyncpg
import redis.asyncio as aioredis
from services.adapters.broker_adapter import (
AlpacaBrokerAdapter,
OrderRequest,
OrderResponse,
OrderSide,
OrderStatus,
OrderType,
TradingMode,
)
from services.risk.engine import (
AccountRiskState,
PortfolioRiskConfig,
ProposedOrder,
evaluate_order,
)
from services.risk.approval import (
ApprovalRequest,
ApprovalStatus,
compute_expiry,
create_approval_request,
requires_approval,
)
from services.shared.audit import (
audit_approval_requested,
audit_duplicate_prevented,
audit_order_filled,
audit_order_rejected,
audit_order_submitted,
audit_risk_evaluated,
)
from services.lake_publisher.worker import (
publish_trade_order,
publish_trade_fill,
publish_positions_daily_batch,
LAKEHOUSE_BUCKET,
)
from services.shared.config import load_config
from services.shared.db import get_pg_pool, get_redis
from services.shared.logging import Span, new_trace_id, set_trace_context, setup_logging
from services.shared.metrics import (
ORDERS_DUPLICATES_PREVENTED,
ORDERS_FILLED,
ORDERS_REJECTED,
ORDERS_SUBMITTED,
POSITIONS_SYNCED,
RISK_CHECK_FAILURES,
RISK_EVALUATIONS_TOTAL,
)
from services.shared.redis_keys import QUEUE_BROKER, queue_key
logger = logging.getLogger("broker_service")
POSITION_SYNC_INTERVAL = 60 # seconds
# Redis TTL for idempotency markers (24 hours)
ORDER_IDEMPOTENCY_TTL = 86400
ORDER_IDEMPOTENCY_PREFIX = "stonks:order_idempotency"
# ---------------------------------------------------------------------------
# DB persistence helpers
# ---------------------------------------------------------------------------
# Insert a broker account row as active; if the id already exists, refresh
# its config and mode and mark it active again.
_UPSERT_BROKER_ACCOUNT = """
INSERT INTO broker_accounts (id, provider, account_id, mode, config, active)
VALUES ($1::uuid, $2, $3, $4, $5::jsonb, TRUE)
ON CONFLICT (id) DO UPDATE SET
config = EXCLUDED.config,
mode = EXCLUDED.mode,
active = TRUE
"""
# Insert an order row (17 positional parameters). When a row with the same
# idempotency_key already exists, only its status, broker order id, fill
# columns and updated_at are refreshed; the existing row's id is not changed.
_INSERT_ORDER = """
INSERT INTO orders (
id, recommendation_id, broker_account_id, ticker, side, order_type,
quantity, limit_price, stop_price, status, idempotency_key,
broker_order_id, decision_trace, submitted_at, filled_at,
fill_price, fill_quantity
) VALUES (
$1::uuid, $2, $3::uuid, $4, $5, $6,
$7, $8, $9, $10, $11,
$12, $13::jsonb, $14, $15,
$16, $17
)
ON CONFLICT (idempotency_key) DO UPDATE SET
status = EXCLUDED.status,
broker_order_id = EXCLUDED.broker_order_id,
filled_at = EXCLUDED.filled_at,
fill_price = EXCLUDED.fill_price,
fill_quantity = EXCLUDED.fill_quantity,
updated_at = NOW()
"""
# Append one event row for an order; data is passed as JSON text.
_INSERT_ORDER_EVENT = """
INSERT INTO order_events (order_id, event_type, data, broker_timestamp)
VALUES ($1::uuid, $2, $3::jsonb, $4)
"""
# Record the outcome of a risk evaluation for a recommendation.
_INSERT_RISK_EVALUATION = """
INSERT INTO risk_evaluations (id, recommendation_id, eligible, allowed_mode, rejection_reasons, risk_checks, evaluated_at)
VALUES ($1::uuid, $2::uuid, $3, $4, $5::jsonb, $6::jsonb, $7)
"""
# Insert or refresh the position row for (broker_account_id, ticker).
_UPSERT_POSITION = """
INSERT INTO positions (broker_account_id, ticker, quantity, avg_entry_price, current_price, unrealized_pnl, updated_at)
VALUES ($1::uuid, $2, $3, $4, $5, $6, $7)
ON CONFLICT (broker_account_id, ticker)
DO UPDATE SET
quantity = EXCLUDED.quantity,
avg_entry_price = EXCLUDED.avg_entry_price,
current_price = EXCLUDED.current_price,
unrealized_pnl = EXCLUDED.unrealized_pnl,
updated_at = EXCLUDED.updated_at
"""
# Fetch the most recently updated active risk configuration.
_LOAD_RISK_CONFIG = """
SELECT config FROM risk_configs WHERE active = TRUE ORDER BY updated_at DESC LIMIT 1
"""
# Fetch today's (CURRENT_DATE) risk snapshot for one account.
_LOAD_DAILY_SNAPSHOT = """
SELECT portfolio_value, daily_pnl, daily_trade_count, positions_by_sector
FROM daily_risk_snapshots
WHERE account_id = $1 AND snapshot_date = CURRENT_DATE
LIMIT 1
"""
# Look up an existing order by its idempotency key.
_CHECK_ORDER_BY_IDEMPOTENCY_KEY = """
SELECT id, status, broker_order_id FROM orders
WHERE idempotency_key = $1
LIMIT 1
"""
# ---------------------------------------------------------------------------
# Idempotency helpers (Requirement 8.5)
# ---------------------------------------------------------------------------
def generate_idempotency_key(job: dict[str, Any]) -> str:
    """Return the idempotency key for a broker queue job.

    A truthy ``idempotency_key`` on the job is returned as a string.
    Otherwise the key is the first 40 hex characters of the SHA-256 of the
    job's recommendation_id, ticker, side, quantity, order_type,
    limit_price and stop_price joined with "|", so the same job content
    always yields the same key.

    Values are stringified as-is: ``5`` and ``5.0``, or ``"buy"`` and
    ``"BUY"``, produce different keys.
    """
    supplied = job.get("idempotency_key")
    if supplied:
        return str(supplied)

    # (field name, default used when the field is absent), in hashing order.
    field_defaults = (
        ("recommendation_id", ""),
        ("ticker", ""),
        ("side", "buy"),
        ("quantity", 0),
        ("order_type", "market"),
        ("limit_price", ""),
        ("stop_price", ""),
    )
    fingerprint = "|".join(str(job.get(name, fallback)) for name, fallback in field_defaults)
    digest = hashlib.sha256(fingerprint.encode()).hexdigest()
    return digest[:40]
def _redis_idempotency_key(idempotency_key: str) -> str:
    """Return the Redis key under which an order idempotency marker is stored."""
    return "{}:{}".format(ORDER_IDEMPOTENCY_PREFIX, idempotency_key)
async def check_idempotency_redis(
    rds: aioredis.Redis,
    idempotency_key: str,
) -> str | None:
    """Fast-path: check Redis for a previously processed idempotency key.

    Args:
        rds: Async Redis client.
        idempotency_key: The order idempotency key to look up.

    Returns:
        The order id stored for this key, or None when no marker exists
        (or the stored value is empty).
    """
    cached = await rds.get(_redis_idempotency_key(idempotency_key))
    if not cached:
        return None
    # A client created without decode_responses returns bytes; str() on bytes
    # would produce "b'...'" instead of the stored order id.
    if isinstance(cached, (bytes, bytearray)):
        return bytes(cached).decode("utf-8")
    return str(cached)
async def check_idempotency_db(
    pool: asyncpg.Pool,
    idempotency_key: str,
) -> dict[str, Any] | None:
    """Durable fallback: look up an existing order by idempotency key in PostgreSQL.

    Returns:
        None when no order has this key, otherwise a dict with string
        values for "id", "status" and "broker_order_id" (an absent broker
        order id becomes "").
    """
    existing = await pool.fetchrow(_CHECK_ORDER_BY_IDEMPOTENCY_KEY, idempotency_key)
    if not existing:
        return None
    found: dict[str, Any] = {}
    found["id"] = str(existing["id"])
    found["status"] = str(existing["status"])
    found["broker_order_id"] = str(existing["broker_order_id"] or "")
    return found
async def mark_idempotency_redis(
    rds: aioredis.Redis,
    idempotency_key: str,
    order_id: str,
) -> None:
    """Store the order id under the key's Redis marker, expiring after ORDER_IDEMPOTENCY_TTL seconds."""
    await rds.set(
        _redis_idempotency_key(idempotency_key),
        order_id,
        ex=ORDER_IDEMPOTENCY_TTL,
    )
# ---------------------------------------------------------------------------
# Core service logic
# ---------------------------------------------------------------------------
def build_order_request(job: dict[str, Any]) -> OrderRequest:
    """Build an OrderRequest from a broker queue job payload.

    The job's ``side`` and ``order_type`` are matched case-insensitively
    and with surrounding whitespace ignored, so "SELL" is a sell and
    "Limit" is a limit order. As before, a missing or unrecognised side
    falls back to BUY and a missing or unrecognised order type to MARKET.

    Args:
        job: Queue payload. ``ticker`` is required; ``side``, ``quantity``,
            ``order_type``, ``limit_price``, ``stop_price``,
            ``time_in_force`` and ``idempotency_key`` are optional.

    Returns:
        The populated OrderRequest, with the idempotency key taken from
        generate_idempotency_key(job).

    Raises:
        KeyError: If the job has no ``ticker``.
    """
    # Normalise before comparing: an exact-match test would turn an
    # upper-case "SELL" into a BUY.
    side_text = str(job.get("side") or "buy").strip().lower()
    side = OrderSide.SELL if side_text == "sell" else OrderSide.BUY

    order_type_text = str(job.get("order_type") or "market").strip().lower()
    order_type_map = {
        "market": OrderType.MARKET,
        "limit": OrderType.LIMIT,
        "stop": OrderType.STOP,
        "stop_limit": OrderType.STOP_LIMIT,
    }
    return OrderRequest(
        ticker=job["ticker"],
        side=side,
        quantity=float(job.get("quantity", 0)),
        order_type=order_type_map.get(order_type_text, OrderType.MARKET),
        limit_price=job.get("limit_price"),
        stop_price=job.get("stop_price"),
        time_in_force=job.get("time_in_force", "day"),
        idempotency_key=generate_idempotency_key(job),
    )
def build_proposed_order(job: dict[str, Any]) -> ProposedOrder:
    """Build the ProposedOrder used for risk evaluation from a broker queue job.

    ``ticker`` is required; quantity, estimated_value and confidence
    default to 0, sector to "" and the action to the job's side (default
    "buy").
    """

    def _as_float(key: str) -> float:
        # Missing numeric fields are treated as zero.
        return float(job.get(key, 0))

    rec_id = job.get("recommendation_id")
    symbol = job["ticker"]
    sector = job.get("sector", "")
    action = job.get("side", "buy")
    return ProposedOrder(
        recommendation_id=rec_id,
        ticker=symbol,
        sector=sector,
        action=action,
        quantity=_as_float("quantity"),
        estimated_value=_as_float("estimated_value"),
        confidence=_as_float("confidence"),
    )
async def load_risk_config(pool: asyncpg.Pool) -> PortfolioRiskConfig:
    """Load the active risk configuration from the database.

    Falls back to a default-constructed PortfolioRiskConfig when no active
    row exists or its config is empty. The stored config may arrive as a
    dict or as JSON text; text is parsed first.
    """
    row = await pool.fetchrow(_LOAD_RISK_CONFIG)
    if not row or not row["config"]:
        return PortfolioRiskConfig()
    stored = row["config"]
    parsed = stored if isinstance(stored, dict) else json.loads(stored)
    return PortfolioRiskConfig.from_db_json(parsed)
async def load_account_risk_state(
    pool: asyncpg.Pool,
    adapter: AlpacaBrokerAdapter,
    account_uuid: str,
) -> AccountRiskState:
    """Assemble an AccountRiskState from the broker and today's DB snapshot.

    Balances and positions come from the adapter; daily P&L, trade count
    and sector exposure come from the daily snapshot row, if one exists.
    A failure in either broker call is logged as a warning and leaves the
    corresponding fields at whatever AccountRiskState initialised them to.
    """
    risk_state = AccountRiskState(account_id=account_uuid)

    # Live balances from the broker.
    try:
        account = await adapter.get_account()
        risk_state.portfolio_value = account.portfolio_value
        risk_state.cash = account.cash
        risk_state.buying_power = account.buying_power
    except Exception as e:
        logger.warning("Failed to fetch account from Alpaca: %s", e)

    # Per-symbol exposure (market value) and the open position count.
    try:
        held = await adapter.get_positions()
        for position in held:
            risk_state.positions_by_symbol[position.ticker] = position.market_value
        risk_state.open_position_count = len(held)
    except Exception as e:
        logger.warning("Failed to fetch positions from Alpaca: %s", e)

    # Daily figures are overlaid from the DB snapshot when present.
    snapshot = await pool.fetchrow(_LOAD_DAILY_SNAPSHOT, account_uuid)
    if not snapshot:
        return risk_state

    risk_state.daily_pnl = float(snapshot["daily_pnl"] or 0)
    risk_state.daily_trade_count = int(snapshot["daily_trade_count"] or 0)
    by_sector = snapshot["positions_by_sector"]
    if by_sector:
        if not isinstance(by_sector, dict):
            by_sector = json.loads(by_sector)
        risk_state.positions_by_sector = by_sector
    return risk_state
async def persist_order(
    pool: asyncpg.Pool,
    order_id: str,
    order: OrderRequest,
    resp: OrderResponse,
    account_uuid: str,
    risk_eval: dict[str, Any],
    recommendation_id: str | None = None,
) -> None:
    """Write an order and its events to PostgreSQL in one transaction.

    The order row carries a decision trace (risk evaluation, request and
    broker response). A "submitted" event is always recorded, followed by
    a "fill" event when the broker status is FILLED or a "rejected" event
    when it is REJECTED. ``risk_eval`` is stored only inside the decision
    trace.
    """
    recorded_at = datetime.now(timezone.utc)
    was_filled = resp.status == OrderStatus.FILLED
    trace = {
        "risk_evaluation": risk_eval,
        "order_request": order.to_dict(),
        "broker_response": resp.to_dict(),
    }

    # Optional second event describing the broker outcome.
    if was_filled:
        outcome = (
            "fill",
            {
                "fill_price": resp.filled_avg_price,
                "fill_qty": resp.filled_quantity,
            },
        )
    elif resp.status == OrderStatus.REJECTED:
        outcome = ("rejected", {"error": resp.error})
    else:
        outcome = None

    async with pool.acquire() as conn, conn.transaction():
        await conn.execute(
            _INSERT_ORDER,
            order_id,
            recommendation_id,
            account_uuid,
            order.ticker,
            order.side.value,
            order.order_type.value,
            order.quantity,
            order.limit_price,
            order.stop_price,
            resp.status.value,
            order.idempotency_key,
            resp.broker_order_id,
            json.dumps(trace),
            resp.submitted_at or recorded_at,
            recorded_at if was_filled else None,
            resp.filled_avg_price,
            resp.filled_quantity,
        )

        # Record order events
        await conn.execute(
            _INSERT_ORDER_EVENT,
            order_id,
            "submitted",
            json.dumps({"ticker": order.ticker, "side": order.side.value}),
            recorded_at,
        )
        if outcome is not None:
            event_kind, event_detail = outcome
            await conn.execute(
                _INSERT_ORDER_EVENT,
                order_id,
                event_kind,
                json.dumps(event_detail),
                recorded_at,
            )
async def sync_positions(
    adapter: AlpacaBrokerAdapter,
    pool: asyncpg.Pool,
    account_uuid: str,
    minio_client: Any | None = None,
) -> None:
    """Copy the broker's current positions into PostgreSQL and the lake.

    Each position returned by the adapter is upserted with a shared sync
    timestamp. When a MinIO client is supplied and there is at least one
    position, a snapshot batch is also published to the analytical lake.
    Lake failures are logged as warnings; any other failure is logged as an
    error. Nothing is raised to the caller.
    """
    synced_at = datetime.now(timezone.utc)
    try:
        held = await adapter.get_positions()
        async with pool.acquire() as conn:
            for position in held:
                await conn.execute(
                    _UPSERT_POSITION,
                    account_uuid,
                    position.ticker,
                    position.quantity,
                    position.avg_entry_price,
                    position.current_price,
                    position.unrealized_pnl,
                    synced_at,
                )
        logger.info("Synced %d positions from Alpaca", len(held))
        POSITIONS_SYNCED.inc()

        # Lake publishing is optional and skipped for an empty book.
        if minio_client is None or not held:
            return
        try:
            lake_rows = []
            for position in held:
                lake_rows.append(
                    {
                        "ticker": position.ticker,
                        "quantity": position.quantity,
                        "avg_entry_price": position.avg_entry_price,
                        "close_price": position.current_price,
                        "unrealized_pnl": position.unrealized_pnl,
                    }
                )
            publish_positions_daily_batch(
                minio_client, lake_rows, account_uuid, synced_at,
            )
        except Exception as exc:
            logger.warning("Failed to publish positions to lake: %s", exc)
    except Exception as exc:
        logger.error("Position sync failed: %s", exc)
async def register_broker_account(
    pool: asyncpg.Pool,
    account_uuid: str,
    adapter: AlpacaBrokerAdapter,
) -> None:
    """Upsert the broker account row for this service instance.

    Fetches the account from the adapter and stores its identifier, the
    adapter's trading mode and a JSON snapshot of buying power, cash and
    portfolio value. Any failure is logged as an error and swallowed.
    """
    try:
        account = await adapter.get_account()
        snapshot = {
            "provider": "alpaca",
            "buying_power": account.buying_power,
            "cash": account.cash,
            "portfolio_value": account.portfolio_value,
        }
        config_json = json.dumps(snapshot)
        await pool.execute(
            _UPSERT_BROKER_ACCOUNT,
            account_uuid,
            "alpaca",
            # Fall back to the locally derived UUID when the broker gives no id.
            account.account_id or account_uuid,
            adapter.mode.value,
            config_json,
        )
        logger.info(
            "Registered Alpaca account: id=%s mode=%s portfolio=%.2f",
            account.account_id, adapter.mode.value, account.portfolio_value,
        )
    except Exception as exc:
        logger.error("Failed to register broker account: %s", exc)
async def process_order_job(
    job: dict[str, Any],
    adapter: AlpacaBrokerAdapter,
    pool: asyncpg.Pool,
    account_uuid: str,
    rds: aioredis.Redis | None = None,
    minio_client: Any | None = None,
) -> None:
    """Process a single order job from the broker queue.

    Steps, in the order they run:

    1. Generate the idempotency key for the job.
    2. Check Redis (when available) and then the database for a duplicate
       (Req 8.5); a duplicate is audited and the job is dropped.
    3. Build the proposed order and run the risk evaluation; persist and
       audit the evaluation.
    4. If the evaluation is not eligible, persist and audit a rejected order,
       mark the idempotency key in Redis and stop.
    5. If operator approval is required (Req 8.2), create an approval request
       and stop without marking the idempotency key.
    6. Otherwise submit to the broker, persist the order and its events,
       publish order/fill facts to the lake when a MinIO client is supplied,
       mark the idempotency key in Redis, and record metrics and audits.

    Args:
        job: Decoded order job; read with ``.get`` for keys such as
            ``ticker``, ``side``, ``quantity`` and ``recommendation_id``.
        adapter: Broker adapter used for risk-state lookups and submission.
        pool: PostgreSQL pool used for persistence and audit records.
        account_uuid: Identifier of the broker account the order belongs to.
        rds: Optional Redis client for the idempotency fast path.
        minio_client: Optional MinIO client; lake publishing is skipped when
            it is ``None``.
    """
    ticker = job.get("ticker", "???")
    # A fresh id is generated for every processed job, including rejected ones.
    order_id = str(uuid.uuid4())
    idempotency_key = generate_idempotency_key(job)
    # --- Duplicate prevention (Requirement 8.5) ---
    # Fast path: Redis check
    if rds is not None:
        existing_order_id = await check_idempotency_redis(rds, idempotency_key)
        if existing_order_id:
            logger.info(
                "Duplicate order detected (redis) for %s key=%s existing=%s",
                ticker, idempotency_key[:16], existing_order_id,
            )
            ORDERS_DUPLICATES_PREVENTED.labels(detected_via="redis").inc()
            await audit_duplicate_prevented(
                pool, existing_order_id, ticker, idempotency_key, detected_via="redis",
            )
            return
    # Durable fallback: DB check
    existing = await check_idempotency_db(pool, idempotency_key)
    if existing:
        logger.info(
            "Duplicate order detected (db) for %s key=%s existing=%s status=%s",
            ticker, idempotency_key[:16], existing["id"], existing["status"],
        )
        ORDERS_DUPLICATES_PREVENTED.labels(detected_via="db").inc()
        await audit_duplicate_prevented(
            pool, existing["id"], ticker, idempotency_key, detected_via="db",
        )
        # Warm Redis cache for future fast-path hits
        if rds is not None:
            await mark_idempotency_redis(rds, idempotency_key, existing["id"])
        return
    # Risk evaluation
    risk_config = await load_risk_config(pool)
    risk_state = await load_account_risk_state(pool, adapter, account_uuid)
    proposed = build_proposed_order(job)
    evaluation = evaluate_order(proposed, risk_config, risk_state)
    # Plain-dict form of the evaluation, stored in the order's decision trace.
    risk_eval_dict = {
        "evaluation_id": evaluation.evaluation_id,
        "eligible": evaluation.eligible,
        "allowed_mode": evaluation.allowed_mode.value,
        "rejection_reasons": evaluation.rejection_reasons,
        "checks": [c.model_dump(mode="json") for c in evaluation.checks],
    }
    # Persist risk evaluation (best effort: a failure only logs a warning)
    rec_id = job.get("recommendation_id")
    try:
        await pool.execute(
            _INSERT_RISK_EVALUATION,
            evaluation.evaluation_id,
            rec_id,
            evaluation.eligible,
            evaluation.allowed_mode.value,
            json.dumps(evaluation.rejection_reasons),
            json.dumps(risk_eval_dict["checks"]),
            evaluation.evaluated_at,
        )
    except Exception as e:
        logger.warning("Failed to persist risk evaluation: %s", e)
    # Audit: risk evaluation result
    await audit_risk_evaluated(
        pool,
        evaluation_id=evaluation.evaluation_id,
        recommendation_id=rec_id,
        ticker=ticker,
        eligible=evaluation.eligible,
        allowed_mode=evaluation.allowed_mode.value,
        rejection_reasons=evaluation.rejection_reasons,
        check_count=len(evaluation.checks),
    )
    if not evaluation.eligible:
        RISK_EVALUATIONS_TOTAL.labels(result="rejected").inc()
        # One failure count per individual check that failed.
        for check in evaluation.checks:
            if check.result.value == "fail":
                RISK_CHECK_FAILURES.labels(check_name=check.check_name).inc()
        ORDERS_REJECTED.labels(reason_category="risk_engine").inc()
        logger.info(
            "Order rejected by risk engine for %s: %s",
            ticker, evaluation.rejection_reasons,
        )
        # Persist the rejected order for audit; the response is synthesised
        # locally because nothing was sent to the broker.
        order_req = build_order_request(job)
        rejected_resp = OrderResponse(
            broker_order_id="",
            status=OrderStatus.REJECTED,
            ticker=ticker,
            side=OrderSide.SELL if job.get("side") == "sell" else OrderSide.BUY,
            quantity=float(job.get("quantity", 0)),
            error=f"Risk rejected: {'; '.join(evaluation.rejection_reasons)}",
        )
        await persist_order(
            pool, order_id, order_req, rejected_resp,
            account_uuid, risk_eval_dict, rec_id,
        )
        # Publish rejected order fact to analytical lake
        if minio_client is not None:
            try:
                publish_trade_order(
                    minio_client, order_id, ticker,
                    side=job.get("side", "buy"),
                    order_type=job.get("order_type", "market"),
                    quantity=float(job.get("quantity", 0)),
                    limit_price=job.get("limit_price"),
                    status="rejected",
                    broker_account=account_uuid,
                    submitted_at=datetime.now(timezone.utc),
                )
            except Exception as e:
                logger.warning("Failed to publish rejected order to lake: %s", e)
        # Audit: order rejected by risk engine
        await audit_order_rejected(
            pool, order_id, ticker,
            reason=f"Risk rejected: {'; '.join(evaluation.rejection_reasons)}",
            source="risk_engine",
        )
        # Mark idempotency even for rejected orders to prevent reprocessing
        if rds is not None:
            await mark_idempotency_redis(rds, idempotency_key, order_id)
        return
    # --- Operator approval gate (Requirement 8.2) ---
    if requires_approval(risk_config, evaluation.allowed_mode):
        expiry = compute_expiry(risk_config)
        approval_req = ApprovalRequest(
            order_job=job,
            recommendation_id=rec_id,
            ticker=ticker,
            side=job.get("side", "buy"),
            quantity=float(job.get("quantity", 0)),
            estimated_value=float(job.get("estimated_value", 0)),
            risk_evaluation_id=evaluation.evaluation_id,
            expires_at=expiry,
        )
        try:
            await create_approval_request(pool, approval_req)
            logger.info(
                "Order for %s held for operator approval (id=%s, expires=%s)",
                ticker, approval_req.approval_id, expiry.isoformat(),
            )
            await audit_approval_requested(
                pool,
                approval_id=approval_req.approval_id,
                ticker=ticker,
                side=approval_req.side,
                quantity=approval_req.quantity,
                estimated_value=approval_req.estimated_value,
                recommendation_id=rec_id,
                expires_at=expiry.isoformat(),
            )
        except Exception as e:
            # NOTE(review): a failure here is only logged; the job is not
            # re-queued by this function, so confirm the caller handles it.
            logger.error("Failed to create approval request for %s: %s", ticker, e)
        # Do NOT mark idempotency: the job will be re-submitted after approval
        return
    # Submit to Alpaca
    order_req = build_order_request(job)
    # The "passed" counter is only reached on this direct-submission path.
    RISK_EVALUATIONS_TOTAL.labels(result="passed").inc()
    # Audit: order submitted to broker (written before the broker call is made)
    await audit_order_submitted(
        pool,
        order_id=order_id,
        ticker=ticker,
        side=order_req.side.value,
        quantity=order_req.quantity,
        order_type=order_req.order_type.value,
        idempotency_key=order_req.idempotency_key,
        recommendation_id=rec_id,
        evaluation_id=evaluation.evaluation_id,
    )
    resp = await adapter.submit_order(order_req)
    await persist_order(
        pool, order_id, order_req, resp,
        account_uuid, risk_eval_dict, rec_id,
    )
    # Publish order fact to analytical lake
    if minio_client is not None:
        try:
            publish_trade_order(
                minio_client, order_id, ticker,
                side=order_req.side.value,
                order_type=order_req.order_type.value,
                quantity=order_req.quantity,
                limit_price=order_req.limit_price,
                status=resp.status.value,
                broker_account=account_uuid,
                submitted_at=resp.submitted_at or datetime.now(timezone.utc),
            )
        except Exception as e:
            logger.warning("Failed to publish order to lake: %s", e)
        # Publish fill fact if the order was filled
        if resp.status == OrderStatus.FILLED and resp.filled_avg_price is not None:
            try:
                fill_id = str(uuid.uuid4())
                publish_trade_fill(
                    minio_client, fill_id, order_id, ticker,
                    side=order_req.side.value,
                    fill_price=resp.filled_avg_price,
                    fill_quantity=resp.filled_quantity,
                    broker_account=account_uuid,
                    filled_at=datetime.now(timezone.utc),
                )
            except Exception as e:
                logger.warning("Failed to publish fill to lake: %s", e)
    # Mark idempotency after successful persistence
    if rds is not None:
        await mark_idempotency_redis(rds, idempotency_key, order_id)
    if resp.ok:
        mode = "paper" if adapter.mode == TradingMode.PAPER else "live"
        ORDERS_SUBMITTED.labels(
            side=order_req.side.value,
            order_type=order_req.order_type.value,
            mode=mode,
        ).inc()
        logger.info(
            "Order submitted to Alpaca: %s %s %.0f %s @ %s | broker_id=%s",
            resp.status.value, order_req.side.value, order_req.quantity,
            ticker, resp.filled_avg_price, resp.broker_order_id,
        )
        # Audit: order filled
        if resp.status == OrderStatus.FILLED:
            ORDERS_FILLED.labels(side=order_req.side.value).inc()
            await audit_order_filled(
                pool, order_id, ticker,
                side=order_req.side.value,
                fill_quantity=resp.filled_quantity,
                fill_price=resp.filled_avg_price,
                broker_order_id=resp.broker_order_id,
            )
    else:
        ORDERS_REJECTED.labels(reason_category="broker").inc()
        logger.warning(
            "Order failed for %s: %s (status=%s)",
            ticker, resp.error, resp.status.value,
        )
        # Audit: order rejected by broker
        await audit_order_rejected(
            pool, order_id, ticker,
            reason=resp.error or f"Broker status: {resp.status.value}",
            source="broker",
        )
async def position_sync_loop(
    adapter: AlpacaBrokerAdapter,
    pool: asyncpg.Pool,
    account_uuid: str,
    minio_client: Any | None = None,
) -> None:
    """Run ``sync_positions`` forever, pausing between rounds.

    The pause length is ``POSITION_SYNC_INTERVAL`` seconds. The loop has no
    exit condition of its own; it ends only when the task is cancelled.
    """
    while True:
        await sync_positions(
            adapter=adapter,
            pool=pool,
            account_uuid=account_uuid,
            minio_client=minio_client,
        )
        await asyncio.sleep(POSITION_SYNC_INTERVAL)
async def main() -> None:
    """Run the broker service until it is stopped.

    Start-up: load configuration and logging, open the PostgreSQL pool and
    Redis client, create the MinIO client (creating the lakehouse bucket if
    it is missing), build the Alpaca adapter in paper mode unless the
    configuration explicitly says ``"live"`` (Req 8.1), register the broker
    account and start the background position-sync task.

    Main loop: pop jobs from the broker queue with LPOP, decode each as JSON
    and hand it to ``process_order_job``. An empty queue causes a two-second
    sleep. A failure while decoding or processing one job is logged with its
    traceback and does not stop the loop.

    Shutdown: the sync task is cancelled and the pool and Redis client are
    closed, whatever the reason the loop ended.
    """
    config = load_config()
    setup_logging("broker_service", level=config.log_level, json_output=config.json_logs)
    pool = await get_pg_pool(config)
    rds = get_redis(config)

    # Initialize MinIO client for lake publishing
    from minio import Minio

    minio_client = Minio(
        config.minio.endpoint,
        access_key=config.minio.access_key,
        secret_key=config.minio.secret_key,
        secure=config.minio.secure,
    )
    # Ensure lakehouse bucket exists
    if not minio_client.bucket_exists(LAKEHOUSE_BUCKET):
        minio_client.make_bucket(LAKEHOUSE_BUCKET)

    # Determine mode: default to paper for safety (Req 8.1)
    mode = TradingMode.LIVE if config.broker.mode == "live" else TradingMode.PAPER
    if mode == TradingMode.LIVE:
        logger.warning("LIVE trading mode enabled — orders will be submitted to real broker")

    adapter = AlpacaBrokerAdapter(
        api_key=config.broker.api_key or "",
        api_secret=config.broker.api_secret or "",
        mode=mode,
        base_url=config.broker.base_url,
    )

    # Generate a stable account UUID from the API key
    account_uuid = str(uuid.uuid5(uuid.NAMESPACE_DNS, f"alpaca-{config.broker.api_key or 'default'}"))

    # Register broker account on startup
    await register_broker_account(pool, account_uuid, adapter)

    # Start position sync in background
    sync_task = asyncio.create_task(
        position_sync_loop(adapter, pool, account_uuid, minio_client)
    )

    queue = queue_key(QUEUE_BROKER)
    logger.info("Broker service started (mode=%s)", mode.value)
    try:
        while True:
            payload = await rds.lpop(queue)
            if not payload:
                await asyncio.sleep(2)
                continue
            try:
                # Hand the payload to json.loads untouched: it accepts both
                # str and bytes. Wrapping a bytes payload in str() would
                # produce the text "b'...'", which is not valid JSON, so
                # every job would fail when the Redis client is not
                # configured to decode responses.
                job = json.loads(payload)
                await process_order_job(job, adapter, pool, account_uuid, rds, minio_client)
            except Exception:
                logger.exception("Error processing broker job")
    finally:
        sync_task.cancel()
        await pool.close()
        await rds.close()
# Script entry point: start the broker service when the module is run directly.
if __name__ == "__main__":
    asyncio.run(main())
+170 -27
View File
@@ -1,8 +1,17 @@
"""Filings / Regulatory API adapter - fetches SEC-style submissions."""
"""Filings / Regulatory API adapter interface and concrete SEC EDGAR provider.
The FilingsDataAdapter is the abstract interface for all filings data providers.
SECEdgarAdapter is the first concrete implementation, targeting the SEC EDGAR
full-text search system (EFTS) for company filings discovery.
Requirements: 2.3, 2.5, 3.1, 3.2, 3.3
"""
import hashlib
import logging
from datetime import datetime
from typing import Any, Dict
import time
from abc import ABC
from datetime import datetime, timezone
from typing import Any
import httpx
@@ -11,48 +20,182 @@ from .base import AdapterResult, BaseAdapter
logger = logging.getLogger("filings_adapter")
class FilingsDataAdapter(BaseAdapter, ABC):
    """Abstract interface for filings / regulatory data providers.

    Subclasses implement fetch() for their specific filings API.
    source_type() is concrete here since all filings adapters share the same type.
    """

    def source_type(self) -> str:
        """Return the source-type tag shared by every filings adapter."""
        return "filings_api"
class SECEdgarAdapter(FilingsDataAdapter):
    """Concrete adapter for the SEC EDGAR full-text search system (EFTS).

    Supports:
    - Full-text search (/LATEST/search-index) for 8-K, 10-Q, 10-K, and other forms
    - Filtering by date range, form type, and entity

    The SEC EDGAR EFTS API is public and does not require an API key,
    but requires a descriptive User-Agent header per SEC fair-access policy.

    Config options:
        cik: Company CIK number (optional, narrows search)
        forms: Comma-separated form types to search (default "8-K,10-Q,10-K")
        start_date: Only filings on or after this date, YYYY-MM-DD (optional)
        end_date: Only filings on or before this date, YYYY-MM-DD (optional)
        query: Custom search query override (optional, replaces ticker-based query)
    """

    SEARCH_ENDPOINT: str = "/LATEST/search-index"

    def __init__(
        self,
        base_url: str = "https://efts.sec.gov",
        user_agent: str = "StonksOracle/1.0 ([email])",
    ) -> None:
        # Strip a trailing slash so joining with SEARCH_ENDPOINT never doubles it.
        self.base_url: str = base_url.rstrip("/")
        self.user_agent: str = user_agent

    async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult:
        """Fetch filings from SEC EDGAR EFTS for a given ticker.

        Exactly one GET request is issued. Failures raised by the request or
        by response handling are reported through the returned result's
        ``error`` field instead of being raised.

        Args:
            ticker: The company ticker symbol.
            config: Source-specific configuration from the sources table.

        Returns:
            AdapterResult with raw payload, parsed filing items, and metadata.
        """
        url, params, headers = self._build_request(ticker, config)
        async with httpx.AsyncClient(timeout=30) as client:
            t0 = time.monotonic()
            try:
                resp = await client.get(url, params=params, headers=headers)
                elapsed_ms = (time.monotonic() - t0) * 1000
                resp.raise_for_status()
                raw = resp.content
                data = resp.json()
                content_hash = hashlib.sha256(raw).hexdigest()
                items = self._extract_items(data)
                return AdapterResult(
                    source_type="filings_api",
                    ticker=ticker,
                    items=items,
                    raw_payload=raw,
                    content_hash=content_hash,
                    fetched_at=datetime.now(timezone.utc),
                    http_status=resp.status_code,
                    response_time_ms=round(elapsed_ms, 1),
                    metadata={
                        "provider": "sec_edgar",
                        "results_count": len(items),
                        "total_hits": self._total_hits(data),
                        "query": params.get("q", ""),
                        "forms": params.get("forms", ""),
                    },
                )
            except httpx.HTTPStatusError as e:
                elapsed_ms = (time.monotonic() - t0) * 1000
                logger.error("SEC EDGAR HTTP error for %s: %s", ticker, e)
                return self._error_result(
                    ticker, str(e), elapsed_ms,
                    http_status=e.response.status_code if e.response else None,
                    raw=e.response.content if e.response else b"",
                )
            except httpx.TimeoutException as e:
                elapsed_ms = (time.monotonic() - t0) * 1000
                logger.error("SEC EDGAR timeout for %s: %s", ticker, e)
                return self._error_result(ticker, f"timeout: {e}", elapsed_ms)
            except Exception as e:
                # Boundary catch: malformed JSON, unexpected payload shapes and
                # transport errors all become an error result.
                elapsed_ms = (time.monotonic() - t0) * 1000
                logger.error("SEC EDGAR fetch failed for %s: %s", ticker, e)
                return self._error_result(ticker, str(e), elapsed_ms)

    def _build_request(
        self, ticker: str, config: dict[str, Any]
    ) -> tuple[str, dict[str, str], dict[str, str]]:
        """Build the URL, query params, and headers for an EDGAR EFTS request."""
        params: dict[str, str] = {}
        headers: dict[str, str] = {"User-Agent": self.user_agent}

        # Query: use custom override or default to ticker-based search
        query = config.get("query")
        if query:
            params["q"] = str(query)
        else:
            params["q"] = f'"{ticker}"'

        # Form types filter
        forms = config.get("forms", "8-K,10-Q,10-K")
        params["forms"] = str(forms)

        # Date range: either bound switches the request to a custom range
        if config.get("start_date"):
            params["dateRange"] = "custom"
            params["startdt"] = str(config["start_date"])
        if config.get("end_date"):
            params["dateRange"] = "custom"
            params["enddt"] = str(config["end_date"])

        # CIK filter (entity-level narrowing), appended to the query text
        cik = config.get("cik")
        if cik:
            params["q"] = f'{params["q"]} AND cik:{cik}'

        url = f"{self.base_url}{self.SEARCH_ENDPOINT}"
        return url, params, headers

    def _extract_items(self, data: dict[str, Any]) -> list[dict[str, Any]]:
        """Extract the filing hits from an EDGAR EFTS response.

        EFTS returns results under hits.hits as a list of objects,
        each containing _source with fields like file_date, form_type,
        entity_name, file_num, and period_of_report.

        Returns an empty list when either level is missing or has an
        unexpected type.
        """
        hits_wrapper = data.get("hits", {})
        if not isinstance(hits_wrapper, dict):
            return []
        hits = hits_wrapper.get("hits", [])
        if isinstance(hits, list):
            return hits
        return []

    def _total_hits(self, data: dict[str, Any]) -> int:
        """Extract total hit count from EFTS response.

        ``hits.total`` is accepted either as an object with a ``value`` field
        or as a bare integer; anything else yields 0.
        """
        hits_wrapper = data.get("hits", {})
        if not isinstance(hits_wrapper, dict):
            return 0
        total = hits_wrapper.get("total", {})
        if isinstance(total, dict):
            return int(total.get("value", 0))
        if isinstance(total, int):
            return total
        return 0

    def _error_result(
        self,
        ticker: str,
        error: str,
        elapsed_ms: float,
        http_status: int | None = None,
        raw: bytes = b"",
    ) -> AdapterResult:
        """Build an error AdapterResult for filings fetches."""
        return AdapterResult(
            source_type="filings_api",
            ticker=ticker,
            items=[],
            raw_payload=raw,
            content_hash="",
            fetched_at=datetime.now(timezone.utc),
            error=error,
            http_status=http_status,
            response_time_ms=round(elapsed_ms, 1),
            metadata={"provider": "sec_edgar"},
        )
+145 -27
View File
@@ -1,8 +1,16 @@
"""Market data API adapter - fetches quotes, bars, and reference data."""
"""Market data API adapter interface and concrete Polygon.io provider.
The MarketDataAdapter is the abstract interface for all market data providers.
PolygonMarketAdapter is the first concrete implementation, targeting the
Polygon.io REST API for previous-day bars, quotes, and ticker details.
Requirements: 2.1, 2.5, 3.1, 3.2, 3.3
"""
import hashlib
import logging
from datetime import datetime
from typing import Any, Dict
import time
from datetime import datetime, timezone
from typing import Any
import httpx
@@ -12,48 +20,158 @@ logger = logging.getLogger("market_adapter")
class MarketDataAdapter(BaseAdapter):
    """Abstract interface for market data providers.

    Subclasses implement fetch() for their specific market data API.
    """

    def source_type(self) -> str:
        """Return the source-type tag shared by every market data adapter."""
        return "market_api"
class PolygonMarketAdapter(MarketDataAdapter):
    """Concrete adapter for the Polygon.io REST API.

    Supports:
    - Previous-day aggregate bars (/v2/aggs/ticker/{ticker}/prev)
    - Range aggregate bars
      (/v2/aggs/ticker/{ticker}/range/{multiplier}/{timespan}/{from_date}/{to_date})
    - Ticker details (/v3/reference/tickers/{ticker})

    The endpoint is selected via the source config's "endpoint" field,
    defaulting to previous-day bars.
    """

    PREV_BARS = "/v2/aggs/ticker/{ticker}/prev"
    RANGE_BARS = "/v2/aggs/ticker/{ticker}/range/{multiplier}/{timespan}/{from_date}/{to_date}"
    TICKER_DETAILS = "/v3/reference/tickers/{ticker}"

    def __init__(self, api_key: str, base_url: str = "https://api.polygon.io") -> None:
        self.api_key: str = api_key
        # Strip a trailing slash so joining with an endpoint path never doubles it.
        self.base_url: str = base_url.rstrip("/")

    async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult:
        """Fetch market data from Polygon.io for a given ticker.

        Failures raised by the request or by response handling are reported
        through the returned result's ``error`` field instead of being
        raised. Error text is passed through ``_redact`` first because the
        API key travels in the query string.

        Config options:
            endpoint: One of "prev_bars" (default), "range_bars", "ticker_details"
            multiplier: Bar multiplier for range queries (default 1)
            timespan: Bar timespan for range queries (default "day")
            from_date: Start date for range queries (YYYY-MM-DD)
            to_date: End date for range queries (YYYY-MM-DD)
            adjusted: Whether bars are adjusted for splits; sent only when set
        """
        endpoint_key = config.get("endpoint", "prev_bars")
        url, params = self._build_request(ticker, endpoint_key, config)
        async with httpx.AsyncClient(timeout=30) as client:
            t0 = time.monotonic()
            try:
                resp = await client.get(url, params=params)
                elapsed_ms = (time.monotonic() - t0) * 1000
                resp.raise_for_status()
                raw = resp.content
                data = resp.json()
                content_hash = hashlib.sha256(raw).hexdigest()
                items = self._extract_items(data, endpoint_key)
                return AdapterResult(
                    source_type="market_api",
                    ticker=ticker,
                    items=items,
                    raw_payload=raw,
                    content_hash=content_hash,
                    fetched_at=datetime.now(timezone.utc),
                    http_status=resp.status_code,
                    response_time_ms=round(elapsed_ms, 1),
                    metadata={
                        "provider": "polygon",
                        "endpoint": endpoint_key,
                        "results_count": data.get("resultsCount", len(items)),
                        "request_id": data.get("request_id", ""),
                    },
                )
            except httpx.HTTPStatusError as e:
                elapsed_ms = (time.monotonic() - t0) * 1000
                # The exception text contains the request URL, query string included.
                message = self._redact(str(e))
                logger.error("Polygon HTTP error for %s: %s", ticker, message)
                return self._error_result(
                    ticker, message, elapsed_ms,
                    http_status=e.response.status_code if e.response else None,
                    raw=e.response.content if e.response else b"",
                )
            except httpx.TimeoutException as e:
                elapsed_ms = (time.monotonic() - t0) * 1000
                message = self._redact(str(e))
                logger.error("Polygon timeout for %s: %s", ticker, message)
                return self._error_result(ticker, f"timeout: {message}", elapsed_ms)
            except Exception as e:
                # Boundary catch: malformed JSON, unexpected payload shapes and
                # transport errors all become an error result.
                elapsed_ms = (time.monotonic() - t0) * 1000
                message = self._redact(str(e))
                logger.error("Polygon fetch failed for %s: %s", ticker, message)
                return self._error_result(ticker, message, elapsed_ms)

    def _redact(self, text: str) -> str:
        """Return *text* with every occurrence of the API key masked."""
        # An empty key must be skipped: replacing "" would insert the mask
        # between every character of the text.
        return text.replace(self.api_key, "***") if self.api_key else text

    def _build_request(
        self, ticker: str, endpoint_key: str, config: dict[str, Any]
    ) -> tuple[str, dict[str, str]]:
        """Build the URL and query params for a Polygon request."""
        params: dict[str, str] = {"apiKey": self.api_key}
        if endpoint_key == "range_bars":
            multiplier = str(config.get("multiplier", 1))
            timespan = config.get("timespan", "day")
            from_date = config.get("from_date", "")
            to_date = config.get("to_date", "")
            path = self.RANGE_BARS.format(
                ticker=ticker,
                multiplier=multiplier,
                timespan=timespan,
                from_date=from_date,
                to_date=to_date,
            )
            if config.get("adjusted") is not None:
                params["adjusted"] = str(config["adjusted"]).lower()
            if config.get("sort"):
                params["sort"] = config["sort"]
            if config.get("limit"):
                params["limit"] = str(config["limit"])
        elif endpoint_key == "ticker_details":
            path = self.TICKER_DETAILS.format(ticker=ticker)
        else:
            # Default: previous-day bars (also used for unrecognised keys)
            path = self.PREV_BARS.format(ticker=ticker)
            if config.get("adjusted") is not None:
                params["adjusted"] = str(config["adjusted"]).lower()
        return f"{self.base_url}{path}", params

    def _extract_items(self, data: dict[str, Any], endpoint_key: str) -> list[dict[str, Any]]:
        """Extract the relevant items list from a Polygon response."""
        if endpoint_key == "ticker_details":
            # Ticker details carry a single object under "results".
            results = data.get("results", {})
            return [results] if isinstance(results, dict) and results else []
        # Aggregate endpoints return results as a list
        results = data.get("results", [])
        if isinstance(results, list):
            return results
        return [results] if results else []

    def _error_result(
        self,
        ticker: str,
        error: str,
        elapsed_ms: float,
        http_status: int | None = None,
        raw: bytes = b"",
    ) -> AdapterResult:
        """Build an error AdapterResult."""
        return AdapterResult(
            source_type="market_api",
            ticker=ticker,
            items=[],
            raw_payload=raw,
            content_hash="",
            fetched_at=datetime.now(timezone.utc),
            error=error,
            http_status=http_status,
            response_time_ms=round(elapsed_ms, 1),
            metadata={"provider": "polygon"},
        )
+135 -30
View File
@@ -1,8 +1,17 @@
"""News API adapter - fetches company-linked headlines and article metadata."""
"""News API adapter interface and concrete Polygon.io news provider.
The NewsDataAdapter is the abstract interface for all news data providers.
PolygonNewsAdapter is the first concrete implementation, targeting the
Polygon.io REST API for company-linked news articles and headlines.
Requirements: 2.2, 2.5, 3.1, 3.2, 3.3
"""
import hashlib
import logging
from datetime import datetime
from typing import Any, Dict
import time
from abc import ABC
from datetime import datetime, timezone
from typing import Any
import httpx
@@ -11,51 +20,147 @@ from .base import AdapterResult, BaseAdapter
logger = logging.getLogger("news_adapter")
class NewsDataAdapter(BaseAdapter, ABC):
    """Abstract interface for news data providers.

    Subclasses implement fetch() for their specific news API.
    source_type() is concrete here since all news adapters share the same type.
    """

    def source_type(self) -> str:
        """Return the source-type tag shared by every news adapter."""
        return "news_api"
class PolygonNewsAdapter(NewsDataAdapter):
    """Concrete adapter for the Polygon.io ticker news endpoint.

    Supports:
    - Ticker news (/v2/reference/news?ticker={ticker})

    Config options:
        limit: Max articles to return per request (default 20, max 1000)
        published_utc_gte: Only articles published on or after this date (YYYY-MM-DD)
        published_utc_lte: Only articles published on or before this date (YYYY-MM-DD)
        order: Sort order for results, "asc" or "desc"; sent only when set
    """

    NEWS_ENDPOINT = "/v2/reference/news"

    def __init__(self, api_key: str, base_url: str = "https://api.polygon.io") -> None:
        self.api_key: str = api_key
        # Strip a trailing slash so joining with NEWS_ENDPOINT never doubles it.
        self.base_url: str = base_url.rstrip("/")

    async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult:
        """Fetch news articles from Polygon.io for a given ticker.

        Failures raised by the request or by response handling are reported
        through the returned result's ``error`` field instead of being
        raised. Error text is passed through ``_redact`` first because the
        API key travels in the query string.

        Args:
            ticker: The company ticker symbol.
            config: Source-specific configuration from the sources table.

        Returns:
            AdapterResult with raw payload, parsed article items, and metadata.
        """
        url, params = self._build_request(ticker, config)
        async with httpx.AsyncClient(timeout=30) as client:
            t0 = time.monotonic()
            try:
                resp = await client.get(url, params=params)
                elapsed_ms = (time.monotonic() - t0) * 1000
                resp.raise_for_status()
                raw = resp.content
                data = resp.json()
                content_hash = hashlib.sha256(raw).hexdigest()
                items = self._extract_items(data)
                return AdapterResult(
                    source_type="news_api",
                    ticker=ticker,
                    items=items,
                    raw_payload=raw,
                    content_hash=content_hash,
                    fetched_at=datetime.now(timezone.utc),
                    http_status=resp.status_code,
                    response_time_ms=round(elapsed_ms, 1),
                    metadata={
                        "provider": "polygon",
                        "results_count": data.get("count", len(items)),
                        "next_url": data.get("next_url", ""),
                        "request_id": data.get("request_id", ""),
                    },
                )
            except httpx.HTTPStatusError as e:
                elapsed_ms = (time.monotonic() - t0) * 1000
                # The exception text contains the request URL, query string included.
                message = self._redact(str(e))
                logger.error("Polygon news HTTP error for %s: %s", ticker, message)
                return self._error_result(
                    ticker, message, elapsed_ms,
                    http_status=e.response.status_code if e.response else None,
                    raw=e.response.content if e.response else b"",
                )
            except httpx.TimeoutException as e:
                elapsed_ms = (time.monotonic() - t0) * 1000
                message = self._redact(str(e))
                logger.error("Polygon news timeout for %s: %s", ticker, message)
                return self._error_result(ticker, f"timeout: {message}", elapsed_ms)
            except Exception as e:
                # Boundary catch: malformed JSON, unexpected payload shapes and
                # transport errors all become an error result.
                elapsed_ms = (time.monotonic() - t0) * 1000
                message = self._redact(str(e))
                logger.error("Polygon news fetch failed for %s: %s", ticker, message)
                return self._error_result(ticker, message, elapsed_ms)

    def _redact(self, text: str) -> str:
        """Return *text* with every occurrence of the API key masked."""
        # An empty key must be skipped: replacing "" would insert the mask
        # between every character of the text.
        return text.replace(self.api_key, "***") if self.api_key else text

    def _build_request(
        self, ticker: str, config: dict[str, Any]
    ) -> tuple[str, dict[str, str]]:
        """Build the URL and query params for a Polygon news request."""
        params: dict[str, str] = {
            "apiKey": self.api_key,
            "ticker": ticker,
        }
        # Cap the page size at 1000.
        limit = config.get("limit", 20)
        params["limit"] = str(min(int(limit), 1000))
        if config.get("order"):
            params["order"] = config["order"]
        # Config keys use underscores; the query parameters use a dotted suffix.
        if config.get("published_utc_gte"):
            params["published_utc.gte"] = config["published_utc_gte"]
        if config.get("published_utc_lte"):
            params["published_utc.lte"] = config["published_utc_lte"]
        url = f"{self.base_url}{self.NEWS_ENDPOINT}"
        return url, params

    def _extract_items(self, data: dict[str, Any]) -> list[dict[str, Any]]:
        """Extract the article list from a Polygon news response.

        Polygon returns articles under the "results" key as a list of objects,
        each containing fields like id, publisher, title, article_url, tickers,
        published_utc, description, and keywords.

        Returns an empty list when "results" is missing or is not a list.
        """
        results = data.get("results", [])
        if isinstance(results, list):
            return results
        return []

    def _error_result(
        self,
        ticker: str,
        error: str,
        elapsed_ms: float,
        http_status: int | None = None,
        raw: bytes = b"",
    ) -> AdapterResult:
        """Build an error AdapterResult for news fetches."""
        return AdapterResult(
            source_type="news_api",
            ticker=ticker,
            items=[],
            raw_payload=raw,
            content_hash="",
            fetched_at=datetime.now(timezone.utc),
            error=error,
            http_status=http_status,
            response_time_ms=round(elapsed_ms, 1),
            metadata={"provider": "polygon"},
        )
+603
View File
@@ -0,0 +1,603 @@
"""Paper trading adapter - local order simulation and state sync.
Implements a fully local paper trading engine that simulates order
execution without requiring a real broker API. Tracks positions,
account balance, fills, and order events in-memory with PostgreSQL
persistence for state sync and audit trail.
Requirements: 8.1, 8.3, 8.5, 2.4
Design: Section 4.9 - Broker Adapter (paper mode)
"""
from __future__ import annotations
import json
import logging
import uuid
from datetime import datetime, timezone
from typing import Any
import asyncpg
from services.adapters.broker_adapter import (
AccountInfo,
BrokerDataAdapter,
OrderEventType,
OrderRequest,
OrderResponse,
OrderSide,
OrderStatus,
OrderType,
PositionInfo,
TradingMode,
)
from services.adapters.base import AdapterResult
logger = logging.getLogger("paper_trading")
# ---------------------------------------------------------------------------
# In-memory paper trading state
# ---------------------------------------------------------------------------
class PaperPosition:
    """Long-only holding in one ticker for the paper account.

    Attributes:
        ticker: Symbol of the held instrument.
        quantity: Shares currently held (never left negative by fills).
        avg_entry_price: Average cost per share of the open quantity.
        realized_pnl: Cumulative profit/loss realized by sells.
    """

    def __init__(
        self,
        ticker: str,
        quantity: float = 0.0,
        avg_entry_price: float = 0.0,
        realized_pnl: float = 0.0,
    ) -> None:
        self.ticker = ticker
        self.quantity = quantity
        self.avg_entry_price = avg_entry_price
        self.realized_pnl = realized_pnl

    def apply_fill(self, side: OrderSide, fill_qty: float, fill_price: float) -> float:
        """Fold one fill into the position.

        Buys re-average the entry price across the enlarged quantity.
        Sells close at most the currently held quantity and book the
        difference to ``avg_entry_price`` as realized PnL; a sell against
        an empty position changes nothing.

        Returns:
            The PnL realized by this fill (``0.0`` for buys).
        """
        if side == OrderSide.BUY:
            cost_basis = self.avg_entry_price * self.quantity + fill_price * fill_qty
            self.quantity += fill_qty
            if self.quantity > 0:
                self.avg_entry_price = cost_basis / self.quantity
            return 0.0

        # Sell side: nothing to close when flat.
        if not self.quantity > 0:
            return 0.0

        closed_qty = min(fill_qty, self.quantity)
        pnl = closed_qty * (fill_price - self.avg_entry_price)
        self.quantity -= closed_qty
        self.realized_pnl += pnl
        if self.quantity <= 0:
            # Fully closed: reset so the next buy starts a fresh average.
            self.quantity = 0.0
            self.avg_entry_price = 0.0
        return pnl

    @property
    def is_open(self) -> bool:
        """Whether any quantity is currently held."""
        return self.quantity > 0

    def to_position_info(self, current_price: float | None = None) -> PositionInfo:
        """Express the position as a ``PositionInfo``.

        When ``current_price`` is omitted the position is marked at its own
        average entry price, which makes the unrealized PnL zero.
        """
        mark = self.avg_entry_price if current_price is None else current_price
        holding = self.quantity > 0
        unrealized = (mark - self.avg_entry_price) * self.quantity if holding else 0.0
        value = mark * self.quantity
        return PositionInfo(
            ticker=self.ticker,
            quantity=self.quantity,
            avg_entry_price=self.avg_entry_price,
            current_price=mark,
            unrealized_pnl=round(unrealized, 4),
            market_value=round(value, 4),
            side="long" if holding else "flat",
        )
class PaperAccount:
    """Mutable in-memory ledger backing the paper trading adapter.

    Holds the cash balance, per-ticker positions, every order response,
    the pending audit events and the idempotency-key index.
    """

    def __init__(
        self,
        account_id: str = "paper-default",
        initial_cash: float = 100_000.0,
    ) -> None:
        self.account_id = account_id
        self.initial_cash = initial_cash
        self.cash = initial_cash
        self.positions: dict[str, PaperPosition] = {}
        self.orders: dict[str, OrderResponse] = {}
        self.order_events: list[dict[str, Any]] = []
        # Maps idempotency key -> order id of the first submission.
        self._seen_idempotency_keys: dict[str, str] = {}

    @property
    def portfolio_value(self) -> float:
        """Cash plus open positions valued at their average entry price."""
        holdings = 0
        for position in self.positions.values():
            if position.is_open:
                holdings += position.quantity * position.avg_entry_price
        return self.cash + holdings

    @property
    def buying_power(self) -> float:
        """Funds available for new buys; equal to the cash balance."""
        return self.cash

    def get_position(self, ticker: str) -> PaperPosition:
        """Return the position for ``ticker``, creating an empty one on first use."""
        try:
            return self.positions[ticker]
        except KeyError:
            created = PaperPosition(ticker=ticker)
            self.positions[ticker] = created
            return created

    def to_account_info(self) -> AccountInfo:
        """Summarise the account as an ``AccountInfo`` with 2-decimal rounding."""
        return AccountInfo(
            account_id=self.account_id,
            buying_power=round(self.buying_power, 2),
            cash=round(self.cash, 2),
            portfolio_value=round(self.portfolio_value, 2),
            currency="USD",
            mode=TradingMode.PAPER,
        )
# ---------------------------------------------------------------------------
# Paper trading adapter
# ---------------------------------------------------------------------------
class PaperTradingAdapter(BrokerDataAdapter):
    """Local paper trading adapter that simulates order execution.

    Every accepted order is filled immediately: limit/stop orders at their
    limit/stop price, market orders at an estimated price with slippage.
    No real broker API is called.

    Features:
    - Idempotent order submission via idempotency_key (Req 8.5)
    - Full order event trail for audit (Req 8.3)
    - Position tracking with average entry price
    - Cash balance management
    - State sync to/from PostgreSQL

    The adapter is constructed in PAPER mode.
    """

    def __init__(
        self,
        account_id: str = "paper-default",
        initial_cash: float = 100_000.0,
        simulated_slippage_pct: float = 0.001,
    ) -> None:
        """Create an adapter with a fresh in-memory paper account.

        Args:
            account_id: Identifier of the paper account.
            initial_cash: Starting cash balance.
            simulated_slippage_pct: Fractional slippage applied to market
                orders (0.001 == 0.1%).
        """
        super().__init__(mode=TradingMode.PAPER)
        self.account = PaperAccount(account_id=account_id, initial_cash=initial_cash)
        self.slippage_pct = simulated_slippage_pct

    def source_type(self) -> str:
        """Return the source type identifier for this adapter."""
        return "broker"

    async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult:
        """Snapshot paper state as an ``AdapterResult``.

        ``config["endpoint"]`` selects what is returned: ``"account"`` for
        the account summary, ``"orders"`` for orders of ``ticker`` (or all
        orders when ``ticker`` is ``"*"``), anything else for the position
        in ``ticker`` (empty when the position is not open).
        """
        endpoint = config.get("endpoint", "positions")
        now = datetime.now(timezone.utc)

        if endpoint == "account":
            items = [self.account.to_account_info().to_dict()]
        elif endpoint == "orders":
            items = [
                resp.to_dict()
                for resp in self.account.orders.values()
                if resp.ticker == ticker or ticker == "*"
            ]
        else:
            pos = self.account.get_position(ticker)
            data = pos.to_position_info().to_dict()
            items = [data] if pos.is_open else []

        raw = json.dumps(items).encode()
        return AdapterResult(
            source_type="broker",
            ticker=ticker,
            items=items,
            raw_payload=raw,
            content_hash="",
            fetched_at=now,
            metadata={"provider": "paper", "mode": "paper", "endpoint": endpoint},
        )

    async def submit_order(self, order: OrderRequest) -> OrderResponse:
        """Simulate order submission and immediate fill.

        Idempotency: if the same idempotency_key was already used, the
        original response is returned unchanged (Req 8.5).

        Orders are rejected (and recorded with a REJECTED event) when the
        quantity is not positive, when a buy exceeds available cash, or
        when a sell exceeds the held quantity.
        """
        # Only deduplicate on a real key: an empty/missing key must not make
        # unrelated keyless orders look like replays of each other.
        if order.idempotency_key:
            existing_id = self.account._seen_idempotency_keys.get(order.idempotency_key)
            if existing_id and existing_id in self.account.orders:
                logger.info("Duplicate order key %s — returning cached response", order.idempotency_key)
                return self.account.orders[existing_id]

        now = datetime.now(timezone.utc)
        order_id = str(uuid.uuid4())

        # A non-positive quantity would slip through the cash/share checks
        # below (negative cost is always affordable) and corrupt the ledger.
        if order.quantity <= 0:
            return self._reject(order, order_id, now, f"Invalid quantity: {order.quantity}")

        fill_price = self._compute_fill_price(order)

        if order.side == OrderSide.BUY:
            required_cash = fill_price * order.quantity
            if required_cash > self.account.cash:
                return self._reject(
                    order, order_id, now,
                    f"Insufficient cash: need {required_cash:.2f}, have {self.account.cash:.2f}",
                )

        if order.side == OrderSide.SELL:
            pos = self.account.get_position(order.ticker)
            if pos.quantity < order.quantity:
                return self._reject(
                    order, order_id, now,
                    f"Insufficient shares: need {order.quantity}, have {pos.quantity}",
                )

        # Simulate immediate fill
        position = self.account.get_position(order.ticker)
        realized_pnl = position.apply_fill(order.side, order.quantity, fill_price)

        # Update cash
        if order.side == OrderSide.BUY:
            self.account.cash -= fill_price * order.quantity
        else:
            self.account.cash += fill_price * order.quantity

        resp = OrderResponse(
            broker_order_id=order_id,
            status=OrderStatus.FILLED,
            ticker=order.ticker,
            side=order.side,
            quantity=order.quantity,
            filled_quantity=order.quantity,
            filled_avg_price=fill_price,
            submitted_at=now,
            raw_response={
                "realized_pnl": round(realized_pnl, 4),
                "cash_after": round(self.account.cash, 2),
                "position_qty_after": position.quantity,
                "simulated": True,
            },
        )

        # Record the full lifecycle; all three share the fill timestamp
        # because the simulated fill is instantaneous.
        self._record_event(order_id, OrderEventType.SUBMITTED, {"ticker": order.ticker}, now)
        self._record_event(order_id, OrderEventType.ACCEPTED, {"ticker": order.ticker}, now)
        self._record_event(order_id, OrderEventType.FILL, {
            "fill_price": fill_price,
            "fill_qty": order.quantity,
            "realized_pnl": round(realized_pnl, 4),
        }, now)

        self._remember(order, order_id, resp)

        logger.info(
            "Paper fill: %s %s %.0f %s @ %.2f | cash=%.2f pnl=%.4f",
            order_id[:8], order.side.value, order.quantity,
            order.ticker, fill_price, self.account.cash, realized_pnl,
        )
        return resp

    async def cancel_order(self, broker_order_id: str) -> OrderResponse:
        """Cancel a paper order. Only non-terminal orders can be cancelled.

        Filled, rejected and already-cancelled orders are left untouched
        and a REJECTED response describing the reason is returned, as is
        the case for an unknown order id.
        """
        existing = self.account.orders.get(broker_order_id)
        if existing is None:
            return OrderResponse(
                broker_order_id=broker_order_id,
                status=OrderStatus.REJECTED,
                ticker="",
                side=OrderSide.BUY,
                quantity=0,
                error=f"Order {broker_order_id} not found",
            )

        # Terminal states must not be rewritten to CANCELLED: doing so would
        # falsify the order history and emit a bogus CANCELLED audit event.
        terminal_reasons = {
            OrderStatus.FILLED: "Cannot cancel a filled order",
            OrderStatus.REJECTED: "Cannot cancel a rejected order",
            OrderStatus.CANCELLED: "Cannot cancel an already cancelled order",
        }
        reason = terminal_reasons.get(existing.status)
        if reason is not None:
            return OrderResponse(
                broker_order_id=broker_order_id,
                status=OrderStatus.REJECTED,
                ticker=existing.ticker,
                side=existing.side,
                quantity=existing.quantity,
                error=reason,
            )

        now = datetime.now(timezone.utc)
        cancelled = OrderResponse(
            broker_order_id=broker_order_id,
            status=OrderStatus.CANCELLED,
            ticker=existing.ticker,
            side=existing.side,
            quantity=existing.quantity,
            submitted_at=existing.submitted_at,
        )
        self.account.orders[broker_order_id] = cancelled
        self._record_event(broker_order_id, OrderEventType.CANCELLED, {}, now)
        return cancelled

    async def get_order_status(self, broker_order_id: str) -> OrderResponse:
        """Return the stored response for an order, or a REJECTED "not found" response."""
        existing = self.account.orders.get(broker_order_id)
        if existing is None:
            return OrderResponse(
                broker_order_id=broker_order_id,
                status=OrderStatus.REJECTED,
                ticker="",
                side=OrderSide.BUY,
                quantity=0,
                error=f"Order {broker_order_id} not found",
            )
        return existing

    async def get_positions(self) -> list[PositionInfo]:
        """Get all open paper positions."""
        return [
            p.to_position_info()
            for p in self.account.positions.values()
            if p.is_open
        ]

    async def get_account(self) -> AccountInfo:
        """Get paper account summary."""
        return self.account.to_account_info()

    # -----------------------------------------------------------------------
    # Internal helpers
    # -----------------------------------------------------------------------

    def _compute_fill_price(self, order: OrderRequest) -> float:
        """Determine the simulated fill price for an order.

        - LIMIT and STOP_LIMIT orders with a limit price fill at that price.
        - STOP orders with a stop price fill at the stop price.
        - Everything else is treated as a market order: ``limit_price`` is
          used as the price estimate, falling back to 100.0 when it is not
          set, and slippage is applied against the trader (buys fill
          higher, sells lower). The result is rounded to 4 decimals.
        """
        if order.order_type == OrderType.LIMIT and order.limit_price is not None:
            return order.limit_price
        if order.order_type == OrderType.STOP and order.stop_price is not None:
            return order.stop_price
        if order.order_type == OrderType.STOP_LIMIT and order.limit_price is not None:
            return order.limit_price

        # Market order: use limit_price as estimate, or a default
        base_price = order.limit_price if order.limit_price is not None else 100.0
        if order.side == OrderSide.BUY:
            return round(base_price * (1 + self.slippage_pct), 4)
        return round(base_price * (1 - self.slippage_pct), 4)

    def _reject(
        self,
        order: OrderRequest,
        order_id: str,
        now: datetime,
        reason: str,
    ) -> OrderResponse:
        """Build, audit and store a REJECTED response for ``order``."""
        resp = OrderResponse(
            broker_order_id=order_id,
            status=OrderStatus.REJECTED,
            ticker=order.ticker,
            side=order.side,
            quantity=order.quantity,
            submitted_at=now,
            error=reason,
        )
        self._record_event(order_id, OrderEventType.REJECTED, resp.to_dict(), now)
        self._remember(order, order_id, resp)
        return resp

    def _remember(self, order: OrderRequest, order_id: str, resp: OrderResponse) -> None:
        """Store the response and index it by idempotency key when one is set."""
        self.account.orders[order_id] = resp
        if order.idempotency_key:
            self.account._seen_idempotency_keys[order.idempotency_key] = order_id

    def _record_event(
        self,
        order_id: str,
        event_type: OrderEventType,
        data: dict[str, Any],
        timestamp: datetime,
    ) -> None:
        """Append an order event to the in-memory audit trail."""
        self.account.order_events.append({
            "order_id": order_id,
            "event_type": event_type.value,
            "data": data,
            "timestamp": timestamp.isoformat(),
        })
# ---------------------------------------------------------------------------
# State sync: persist and restore paper trading state to/from PostgreSQL
# ---------------------------------------------------------------------------
# SQL for persisting paper orders to the orders table.
# Takes 17 positional parameters in column order. A row whose
# idempotency_key already exists is skipped, so re-running the insert does
# not update an order that was stored earlier.
_INSERT_PAPER_ORDER = """
INSERT INTO orders (
id, recommendation_id, broker_account_id, ticker, side, order_type,
quantity, limit_price, stop_price, status, idempotency_key,
broker_order_id, decision_trace, submitted_at, filled_at,
fill_price, fill_quantity
) VALUES (
$1::uuid, $2, $3, $4, $5, $6,
$7, $8, $9, $10, $11,
$12, $13::jsonb, $14, $15,
$16, $17
)
ON CONFLICT (idempotency_key) DO NOTHING
"""
# Appends one audit event: (order_id, event_type, JSON data, timestamp).
# There is no conflict clause, so executing it twice stores the event twice.
_INSERT_PAPER_ORDER_EVENT = """
INSERT INTO order_events (order_id, event_type, data, broker_timestamp)
VALUES ($1::uuid, $2, $3::jsonb, $4)
"""
# Inserts or overwrites the position identified by (broker_account_id, ticker).
_UPSERT_PAPER_POSITION = """
INSERT INTO positions (broker_account_id, ticker, quantity, avg_entry_price, realized_pnl, updated_at)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (broker_account_id, ticker)
DO UPDATE SET
quantity = EXCLUDED.quantity,
avg_entry_price = EXCLUDED.avg_entry_price,
realized_pnl = EXCLUDED.realized_pnl,
updated_at = EXCLUDED.updated_at
"""
# Inserts the paper broker account keyed by its UUID primary key ($1), or
# refreshes its config JSON and re-activates it when the row already exists.
_UPSERT_PAPER_ACCOUNT = """
INSERT INTO broker_accounts (id, provider, account_id, mode, config, active)
VALUES ($1::uuid, 'paper', $2, 'paper', $3::jsonb, TRUE)
ON CONFLICT (id) DO UPDATE SET
config = EXCLUDED.config,
active = TRUE
"""
# Loads positions with a positive quantity for one broker account.
# NOTE(review): $1 is compared against positions.broker_account_id, the same
# column _UPSERT_PAPER_POSITION writes — callers must pass the identifier
# that was used when the rows were written.
_LOAD_PAPER_POSITIONS = """
SELECT ticker, quantity, avg_entry_price, COALESCE(realized_pnl, 0) AS realized_pnl
FROM positions
WHERE broker_account_id = $1 AND quantity > 0
"""
# Fetches the config JSON of the active paper account with account_id = $1.
_LOAD_PAPER_ACCOUNT_CONFIG = """
SELECT config FROM broker_accounts
WHERE account_id = $1 AND mode = 'paper' AND active = TRUE
LIMIT 1
"""
# Fetches up to 500 orders, newest first, for the paper account with
# account_id = $1; the subquery resolves account_id to broker_accounts.id.
_LOAD_PAPER_ORDERS = """
SELECT
id, ticker, side, order_type, quantity, status,
idempotency_key, broker_order_id, fill_price, fill_quantity,
submitted_at
FROM orders
WHERE broker_account_id = (
SELECT id FROM broker_accounts WHERE account_id = $1 AND mode = 'paper' LIMIT 1
)
ORDER BY submitted_at DESC
LIMIT 500
"""
async def sync_state_to_db(
    adapter: PaperTradingAdapter,
    pool: asyncpg.Pool,
    broker_account_uuid: str | None = None,
) -> None:
    """Persist the current paper trading state to PostgreSQL.

    Writes, inside a single transaction:
    - the broker_accounts row for the paper account
    - positions rows for every tracked position
    - orders rows for all orders (idempotent via ON CONFLICT)
    - order_events rows for the pending audit trail

    This enables state recovery after restarts and provides the
    full execution audit trail (Requirement 8.3).

    Args:
        adapter: Adapter whose in-memory account is persisted.
        pool: asyncpg connection pool.
        broker_account_uuid: Primary key to use for the broker account row.
            When omitted, a stable UUIDv5 is derived from the account id.

    Events are removed from the in-memory queue only after the transaction
    has committed; if it fails, the exception propagates and the events stay
    queued for the next sync.
    """
    acct = adapter.account
    now = datetime.now(timezone.utc)
    acct_uuid = broker_account_uuid or str(uuid.uuid5(uuid.NAMESPACE_DNS, acct.account_id))

    # Snapshot before the first await: orders submitted while this coroutine
    # is suspended must neither break iteration ("dict changed size") nor
    # have their events cleared without having been written.
    positions = list(acct.positions.items())
    orders = list(acct.orders.items())
    events = list(acct.order_events)

    async with pool.acquire() as conn:
        async with conn.transaction():
            # 1. Upsert broker account
            config_json = json.dumps({
                "initial_cash": acct.initial_cash,
                "current_cash": round(acct.cash, 2),
                "portfolio_value": round(acct.portfolio_value, 2),
                "slippage_pct": adapter.slippage_pct,
            })
            await conn.execute(_UPSERT_PAPER_ACCOUNT, acct_uuid, acct.account_id, config_json)

            # 2. Upsert positions
            for ticker, pos in positions:
                await conn.execute(
                    _UPSERT_PAPER_POSITION,
                    acct_uuid, ticker,
                    pos.quantity, pos.avg_entry_price, pos.realized_pnl,
                    now,
                )

            # 3. Insert orders (idempotent)
            for order_id, resp in orders:
                # Paper fills happen at submission time, so the fill
                # timestamp is the submission timestamp, not the sync time.
                if resp.status == OrderStatus.FILLED:
                    filled_at = resp.submitted_at or now
                else:
                    filled_at = None
                await conn.execute(
                    _INSERT_PAPER_ORDER,
                    order_id,
                    None,  # recommendation_id
                    acct_uuid,
                    resp.ticker,
                    resp.side.value,
                    "market",  # paper orders are always market-simulated
                    resp.quantity,
                    resp.filled_avg_price,  # limit_price
                    None,  # stop_price
                    resp.status.value,
                    order_id,  # use order_id as idempotency_key fallback
                    order_id,
                    json.dumps(resp.raw_response),
                    resp.submitted_at,
                    filled_at,
                    resp.filled_avg_price,
                    resp.filled_quantity,
                )

            # 4. Insert order events
            for event in events:
                await conn.execute(
                    _INSERT_PAPER_ORDER_EVENT,
                    event["order_id"],
                    event["event_type"],
                    json.dumps(event["data"]),
                    datetime.fromisoformat(event["timestamp"]),
                )

    logger.info(
        "Synced paper state to DB: account=%s positions=%d orders=%d events=%d",
        acct.account_id, len(positions), len(orders), len(events),
    )

    # Drop only the events that were written; anything appended during the
    # sync sits after this prefix and is kept for the next run.
    del acct.order_events[: len(events)]
async def load_state_from_db(
    adapter: PaperTradingAdapter,
    pool: asyncpg.Pool,
) -> bool:
    """Restore paper trading state from PostgreSQL.

    Loads the cash balance from the saved account config and the open
    positions of the account so the adapter can resume after a restart.
    Orders and idempotency keys are not restored.

    Args:
        adapter: Adapter whose in-memory account is populated.
        pool: asyncpg connection pool.

    Returns:
        True if a saved, active paper account was found; False otherwise
        (in which case the adapter is left untouched).
    """
    acct = adapter.account

    async with pool.acquire() as conn:
        # Load account config
        row = await conn.fetchrow(_LOAD_PAPER_ACCOUNT_CONFIG, acct.account_id)
        if row is None:
            logger.info("No saved paper account state for %s", acct.account_id)
            return False

        config = json.loads(row["config"]) if isinstance(row["config"], str) else row["config"]
        acct.cash = float(config.get("current_cash", acct.initial_cash))

        # Positions are stored under the broker account's primary key (the
        # UUID written by the sync), not under the human-readable account_id,
        # so resolve that key first and query positions with it.
        account_pk = await conn.fetchval(
            "SELECT id FROM broker_accounts "
            "WHERE account_id = $1 AND mode = 'paper' AND active = TRUE "
            "LIMIT 1",
            acct.account_id,
        )

        # Passed as text, matching how the sync supplies the same value.
        pos_rows = await conn.fetch(_LOAD_PAPER_POSITIONS, str(account_pk))
        for pr in pos_rows:
            ticker = pr["ticker"]
            acct.positions[ticker] = PaperPosition(
                ticker=ticker,
                quantity=float(pr["quantity"]),
                avg_entry_price=float(pr["avg_entry_price"] or 0),
                realized_pnl=float(pr["realized_pnl"]),
            )

    logger.info(
        "Loaded paper state from DB: account=%s cash=%.2f positions=%d",
        acct.account_id, acct.cash, len(acct.positions),
    )
    return True
+241
View File
@@ -0,0 +1,241 @@
"""Resilient adapter wrapper with rate-limit coordination, retries, and backoff.
Wraps any BaseAdapter with:
- Per-source-type rate limiting via Redis (distributed across workers)
- Exponential backoff with jitter on retryable failures
- Configurable retry counts and retryable HTTP status codes
- Graceful degradation when Redis is unavailable
Requirements: 2.5, 3.4
"""
import asyncio
import logging
import random
import time
from dataclasses import dataclass
from typing import Any
import redis.asyncio as aioredis
from services.shared.redis_keys import rate_limit_key
from .base import AdapterResult, BaseAdapter
logger = logging.getLogger("resilient_adapter")
# Responses with these HTTP status codes are considered transient and may
# be retried.
RETRYABLE_STATUS_CODES: frozenset[int] = frozenset((429, 500, 502, 503, 504))


@dataclass
class RetryConfig:
    """Tunable knobs for retrying and rate limiting one source type."""

    # Number of retries after the first attempt.
    max_retries: int = 3
    # Backoff starts at base_delay seconds and is capped at max_delay.
    base_delay: float = 1.0
    max_delay: float = 60.0
    # Upper bound of the random jitter, as a fraction of the capped delay.
    jitter_factor: float = 0.5
    retryable_status_codes: frozenset[int] = RETRYABLE_STATUS_CODES
    # Rate limit: at most rate_limit_max requests per window per source type.
    rate_limit_max: int = 30
    rate_limit_window_seconds: int = 60


# Per-source-type defaults used when no explicit RetryConfig is supplied.
DEFAULT_RETRY_CONFIGS: dict[str, RetryConfig] = dict(
    market_api=RetryConfig(max_retries=3, rate_limit_max=30),
    news_api=RetryConfig(max_retries=3, rate_limit_max=20),
    filings_api=RetryConfig(max_retries=2, rate_limit_max=10, base_delay=2.0),
    web_scrape=RetryConfig(max_retries=2, rate_limit_max=10, base_delay=2.0),
    broker=RetryConfig(max_retries=2, rate_limit_max=60, base_delay=0.5),
)
def compute_delay(attempt: int, config: RetryConfig) -> float:
    """Return the sleep time, in seconds, before retrying after ``attempt``.

    The base delay doubles with every attempt (``attempt`` is zero-based)
    and is capped at ``config.max_delay``. A random jitter of up to
    ``jitter_factor`` times the capped delay is then added on top, so the
    result can exceed ``max_delay`` by that fraction.
    """
    backoff = min(config.base_delay * (2 ** attempt), config.max_delay)
    return backoff + backoff * config.jitter_factor * random.random()
@dataclass
class RetryStats:
    """Tracks retry statistics for observability."""

    # Number of fetch attempts made so far (1-based once a fetch has run).
    attempts: int = 0
    # Seconds spent sleeping for rate limits and backoff combined.
    total_delay: float = 0.0
    # How many attempts had to wait for the rate limiter first.
    rate_limited_waits: int = 0
    last_error: str | None = None
    # Whether the most recent failure was classified as retryable.
    retryable: bool = False


class ResilientAdapter:
    """Wraps a BaseAdapter with rate-limit coordination, retries, and backoff.

    Usage:
        adapter = PolygonMarketAdapter(api_key="...")
        resilient = ResilientAdapter(adapter, redis=rds)
        result = await resilient.fetch(ticker, config)

    If redis is None, rate limiting is skipped (local dev / testing).
    """

    def __init__(
        self,
        adapter: BaseAdapter,
        redis: aioredis.Redis | None = None,
        retry_config: RetryConfig | None = None,
    ) -> None:
        """Wrap ``adapter``.

        Args:
            adapter: The adapter whose ``fetch`` is made resilient.
            redis: Redis client used for the shared rate-limit counters,
                or None to disable rate limiting.
            retry_config: Explicit retry settings. When omitted, the
                default for the adapter's source type is used, falling
                back to a plain ``RetryConfig()``.
        """
        self._adapter = adapter
        self._redis = redis
        source_type = adapter.source_type()
        self._config = retry_config or DEFAULT_RETRY_CONFIGS.get(
            source_type, RetryConfig()
        )

    @property
    def adapter(self) -> BaseAdapter:
        """Access the underlying adapter."""
        return self._adapter

    @property
    def config(self) -> RetryConfig:
        """The retry/rate-limit settings in effect."""
        return self._config

    def source_type(self) -> str:
        """Return the wrapped adapter's source type."""
        return self._adapter.source_type()

    async def _check_rate_limit(self) -> float:
        """Check the distributed rate limit via Redis.

        Increments a counter keyed by source type and time bucket.

        Returns:
            0.0 if the request is allowed, otherwise the number of seconds
            (at least 0.5) until the current window rolls over. Also
            returns 0.0 when Redis is not configured, the window length is
            not positive, or the Redis call fails.
        """
        if self._redis is None:
            return 0.0

        window_sec = self._config.rate_limit_window_seconds
        if window_sec <= 0:
            # A non-positive window cannot be bucketed (division by zero);
            # treat it as "rate limiting disabled".
            return 0.0

        source_type = self._adapter.source_type()

        # Read the clock once so the bucket and the remaining-time
        # computation refer to the same window even if the Redis round trip
        # crosses a window boundary.
        now = time.time()
        # Use a time-bucketed key so counters auto-expire
        bucket = int(now) // window_sec
        key = rate_limit_key(source_type, str(bucket))

        try:
            count = await self._redis.incr(key)
            if count == 1:
                await self._redis.expire(key, window_sec * 2)

            if count > self._config.rate_limit_max:
                # Over limit — compute how long until the window rolls over
                wait = window_sec - (now % window_sec)
                return max(wait, 0.5)
        except Exception:
            # Redis unavailable — degrade gracefully, allow the request
            logger.warning("Redis rate-limit check failed, allowing request")

        return 0.0

    def _is_retryable(self, result: AdapterResult) -> bool:
        """Determine if a failed result is worth retrying.

        A failure is retryable when its HTTP status is in the configured
        retryable set, or when its error text mentions a timeout or a
        connection problem.
        """
        if result.ok:
            return False

        # Retry on known retryable HTTP status codes
        if result.http_status and result.http_status in self._config.retryable_status_codes:
            return True

        if not result.error:
            return False

        message = result.error.lower()
        # Retry on timeouts and connection errors
        keywords = ("timeout", "connection", "connect", "reset", "refused")
        return any(kw in message for kw in keywords)

    def _extract_retry_after(self, result: AdapterResult) -> float | None:
        """Extract a Retry-After hint (seconds) from result metadata, if usable."""
        retry_after = result.metadata.get("retry_after")
        if retry_after is None:
            return None
        try:
            return float(retry_after)
        except (ValueError, TypeError):
            return None

    async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult:
        """Fetch with rate-limit coordination, retries, and exponential backoff.

        Returns the AdapterResult from the underlying adapter. On retryable
        failures, retries up to max_retries times with exponential backoff
        and jitter. Rate-limit waits are applied before each attempt. At
        least one attempt is always made.

        The returned result's metadata includes retry stats under the
        "retry_stats" key. On failure the stats additionally carry
        ``exhausted`` (always True), ``retryable`` (whether the final
        failure was of a retryable kind, i.e. the retry budget ran out
        rather than the error being permanent) and ``last_error``.
        """
        stats = RetryStats()
        last_result: AdapterResult | None = None
        # Guard against a negative max_retries: without it the loop would
        # not run at all and there would be no result to return.
        max_retries = max(self._config.max_retries, 0)

        for attempt in range(max_retries + 1):
            stats.attempts = attempt + 1

            # Rate limit check
            wait = await self._check_rate_limit()
            if wait > 0:
                stats.rate_limited_waits += 1
                logger.info(
                    "Rate limited for %s/%s, waiting %.1fs",
                    self.source_type(), ticker, wait,
                )
                stats.total_delay += wait
                await asyncio.sleep(wait)

            # Execute the fetch
            result = await self._adapter.fetch(ticker, config)
            last_result = result

            # Success — attach stats and return
            if result.ok:
                result.metadata["retry_stats"] = {
                    "attempts": stats.attempts,
                    "total_delay": round(stats.total_delay, 2),
                    "rate_limited_waits": stats.rate_limited_waits,
                }
                return result

            stats.last_error = result.error
            stats.retryable = self._is_retryable(result)
            if not stats.retryable:
                break

            # Don't sleep after the last attempt
            if attempt < max_retries:
                # Respect Retry-After header for 429s
                retry_after = self._extract_retry_after(result)
                if result.http_status == 429 and retry_after is not None:
                    delay = min(retry_after, self._config.max_delay)
                else:
                    delay = compute_delay(attempt, self._config)

                logger.info(
                    "Retrying %s/%s (attempt %d/%d) after %.1fs: %s",
                    self.source_type(), ticker, attempt + 1,
                    max_retries + 1, delay, result.error,
                )
                stats.total_delay += delay
                await asyncio.sleep(delay)

        # Failure path: either the error was permanent or retries ran out.
        # The loop always executes at least once, so last_result is set.
        last_result.metadata["retry_stats"] = {
            "attempts": stats.attempts,
            "total_delay": round(stats.total_delay, 2),
            "rate_limited_waits": stats.rate_limited_waits,
            "exhausted": True,
            "retryable": stats.retryable,
            "last_error": stats.last_error,
        }
        return last_result
+321
View File
@@ -0,0 +1,321 @@
"""Web scrape adapter for curated URLs and article pages.
Fetches full article HTML from curated URLs (investor relations pages,
press releases, earnings transcripts, etc.) using httpx + BeautifulSoup,
with metadata extraction, content hashing, and boilerplate stripping.
Retries and quality scoring are not implemented in this module.
Inspired by Noctipede crawler patterns: BeautifulSoup + requests with retry
adapters, content hashing, boilerplate stripping, quality scoring.
Requirements: 1.2, 2.5, 3.1, 3.2, 3.3, 3.4
"""
import json
import logging
import time
from datetime import datetime, timezone
from urllib.parse import urlparse
from typing import Any
import httpx
from bs4 import BeautifulSoup
from services.shared.content import content_hash, normalize_url
from .base import AdapterResult, BaseAdapter
logger = logging.getLogger("web_scrape_adapter")
# Default request settings
# Request timeout in seconds, used when the source config sets no "timeout".
DEFAULT_TIMEOUT = 30
# User-Agent header sent when the source config sets no "user_agent".
DEFAULT_USER_AGENT = "StonksOracle/1.0 (+https://stonks-oracle.celestium.life)"
# Largest response body, in bytes, that is accepted for a page (10 MiB);
# bigger responses are recorded as errors instead of being parsed.
MAX_CONTENT_LENGTH = 10 * 1024 * 1024 # 10MB cap
def _json_ld_date_published(soup: BeautifulSoup) -> str | None:
    """Return the first top-level JSON-LD ``datePublished`` value found, if any.

    Scans ``<script type="application/ld+json">`` blocks in document order.
    Each block may hold a single object or a list of objects; blocks that
    fail to parse are skipped. The scan stops at the first match so a later
    script cannot override an earlier one.
    """
    for script in soup.find_all("script", type="application/ld+json"):
        raw = script.string
        if not raw or "datePublished" not in raw:
            continue
        try:
            ld = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            continue
        candidates = ld if isinstance(ld, list) else [ld]
        for entry in candidates:
            if isinstance(entry, dict) and "datePublished" in entry:
                return str(entry["datePublished"])
    return None


def extract_metadata_from_html(html: str, url: str) -> dict[str, str | None]:
    """Extract page-level metadata from an HTML document.

    Args:
        html: The page markup.
        url: The URL the page was fetched from; used for the publisher and
            canonical URL fallbacks.

    Returns:
        A dict with the keys ``title``, ``author``, ``publisher``,
        ``published_at``, ``canonical_url``, ``language`` and
        ``description``. Missing text fields are empty strings,
        ``published_at`` is None when no date is found, and ``language``
        defaults to ``"en"``.
    """
    soup = BeautifulSoup(html, "html.parser")
    meta: dict[str, str | None] = {}

    # Title: prefer og:title, then <title>
    og_title = soup.find("meta", property="og:title")
    if og_title and og_title.get("content"):
        content = og_title["content"]
        meta["title"] = content.strip() if isinstance(content, str) else ""
    elif soup.title and soup.title.string:
        meta["title"] = soup.title.string.strip()
    else:
        meta["title"] = ""

    # Author
    author_tag = soup.find("meta", attrs={"name": "author"})
    if author_tag and author_tag.get("content"):
        content = author_tag["content"]
        meta["author"] = content.strip() if isinstance(content, str) else ""
    else:
        meta["author"] = ""

    # Publisher: og:site_name, falling back to the URL's hostname
    site_name = soup.find("meta", property="og:site_name")
    if site_name and site_name.get("content"):
        content = site_name["content"]
        meta["publisher"] = content.strip() if isinstance(content, str) else ""
    else:
        meta["publisher"] = urlparse(url).hostname or ""

    # Published date: article:published_time, then JSON-LD datePublished
    pub_time = soup.find("meta", property="article:published_time")
    if pub_time and pub_time.get("content"):
        content = pub_time["content"]
        meta["published_at"] = content.strip() if isinstance(content, str) else None
    else:
        meta["published_at"] = _json_ld_date_published(soup)

    # Canonical URL: <link rel="canonical">, then og:url, then the
    # normalized request URL
    canonical = soup.find("link", rel="canonical")
    if canonical and canonical.get("href"):
        href = canonical["href"]
        meta["canonical_url"] = str(href) if href else normalize_url(url)
    else:
        og_url = soup.find("meta", property="og:url")
        if og_url and og_url.get("content"):
            content = og_url["content"]
            meta["canonical_url"] = str(content) if content else normalize_url(url)
        else:
            meta["canonical_url"] = normalize_url(url)

    # Language: first five characters of <html lang="...">
    html_tag = soup.find("html")
    if html_tag and html_tag.get("lang"):
        lang = html_tag["lang"]
        meta["language"] = str(lang)[:5] if lang else "en"
    else:
        meta["language"] = "en"

    # Description for summary
    desc = soup.find("meta", property="og:description") or soup.find(
        "meta", attrs={"name": "description"}
    )
    if desc and desc.get("content"):
        content = desc["content"]
        meta["description"] = content.strip() if isinstance(content, str) else ""
    else:
        meta["description"] = ""

    return meta
def extract_body_text(html: str) -> str:
    """Return the readable body text of an HTML page.

    Script, style and page-chrome elements are removed first. The text is
    then taken from the ``<article>`` element, or from the first ``<div>``
    whose class names a known article container, or from ``<body>``, or
    from the whole document as a last resort. Blank lines are dropped and
    each remaining line is stripped.
    """
    soup = BeautifulSoup(html, "html.parser")

    # Remove non-content elements
    noise_tags = ["script", "style", "nav", "footer", "header", "aside", "iframe", "noscript"]
    for element in soup.find_all(noise_tags):
        element.decompose()

    # Try to find article body
    container = soup.find("article")
    if not container:
        markers = ["article-body", "post-content", "entry-content", "story-body"]
        for div in soup.find_all("div"):
            classes = div.get("class", [])
            if isinstance(classes, list):
                joined = " ".join(classes)
            elif classes:
                joined = str(classes)
            else:
                joined = ""
            if any(marker in joined for marker in markers):
                container = div
                break

    # Fallback chain: article container -> <body> -> whole document
    source = container or soup.find("body") or soup
    text = source.get_text(separator="\n", strip=True)

    # Collapse whitespace
    return "\n".join(
        stripped for stripped in map(str.strip, text.splitlines()) if stripped
    )
class WebScrapeAdapter(BaseAdapter):
"""Adapter for fetching curated web pages and article URLs.
Config options (from source config):
urls: List of URLs to scrape for this company
url: Single URL to scrape (alternative to urls)
timeout: Request timeout in seconds (default 30)
user_agent: Custom user agent string
follow_links: Whether to follow article links from index pages (default False)
max_pages: Max pages to fetch per cycle (default 5)
"""
def __init__(self) -> None:
pass
def source_type(self) -> str:
return "web_scrape"
def bucket_name(self) -> str:
"""Web scrape artifacts go to the news raw bucket."""
return "stonks-raw-news"
async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult:
"""Fetch HTML from curated URLs for a given ticker.
Supports both single URL and multi-URL configs. Each URL is fetched,
HTML is preserved as raw payload, and metadata is extracted.
"""
urls = config.get("urls", [])
if not urls and config.get("url"):
urls = [config["url"]]
if not urls:
return self._error_result(ticker, "No URLs configured for web_scrape source", 0)
timeout = config.get("timeout", DEFAULT_TIMEOUT)
user_agent = config.get("user_agent", DEFAULT_USER_AGENT)
max_pages = min(config.get("max_pages", 5), 20)
items: list[dict[str, Any]] = []
all_raw: list[bytes] = []
total_elapsed = 0.0
errors: list[str] = []
async with httpx.AsyncClient(
timeout=timeout,
follow_redirects=True,
headers={"User-Agent": user_agent},
) as client:
for url in urls[:max_pages]:
t0 = time.monotonic()
try:
resp = await client.get(url)
elapsed_ms = (time.monotonic() - t0) * 1000
total_elapsed += elapsed_ms
resp.raise_for_status()
# Content length guard
if len(resp.content) > MAX_CONTENT_LENGTH:
errors.append(f"Content too large for {url}: {len(resp.content)} bytes")
continue
html = resp.text
raw_bytes = resp.content
all_raw.append(raw_bytes)
item_content_hash = content_hash(raw_bytes)
meta = extract_metadata_from_html(html, url)
body_text = extract_body_text(html)
item: dict[str, Any] = {
"url": url,
"canonical_url": meta.get("canonical_url", normalize_url(url)),
"title": meta.get("title", ""),
"author": meta.get("author", ""),
"publisher": meta.get("publisher", ""),
"published_at": meta.get("published_at"),
"language": meta.get("language", "en"),
"description": meta.get("description", ""),
"content_hash": item_content_hash,
"body_text": body_text,
"body_length": len(body_text),
"html_length": len(html),
"http_status": resp.status_code,
"response_time_ms": round(elapsed_ms, 1),
}
items.append(item)
except httpx.HTTPStatusError as e:
elapsed_ms = (time.monotonic() - t0) * 1000
total_elapsed += elapsed_ms
status = e.response.status_code if e.response else None
errors.append(f"HTTP {status} for {url}: {e}")
logger.warning("Scrape HTTP error for %s/%s: %s", ticker, url, e)
except httpx.TimeoutException as e:
elapsed_ms = (time.monotonic() - t0) * 1000
total_elapsed += elapsed_ms
errors.append(f"Timeout for {url}: {e}")
logger.warning("Scrape timeout for %s/%s: %s", ticker, url, e)
except Exception as e:
elapsed_ms = (time.monotonic() - t0) * 1000
total_elapsed += elapsed_ms
errors.append(f"Error for {url}: {e}")
logger.warning("Scrape error for %s/%s: %s", ticker, url, e)
if not items:
error_msg = "; ".join(errors) if errors else "No pages fetched"
return self._error_result(ticker, error_msg, total_elapsed)
# Combine all raw payloads into a single artifact
combined_raw = json.dumps({
"ticker": ticker,
"fetched_at": datetime.now(timezone.utc).isoformat(),
"pages": [
{
"url": item["url"],
"content_hash": item["content_hash"],
"html_length": item["html_length"],
"body_length": item["body_length"],
}
for item in items
],
"errors": errors,
}).encode("utf-8")
combined_hash = content_hash(
b"".join(item["content_hash"].encode() for item in items)
)
return AdapterResult(
source_type="web_scrape",
ticker=ticker,
items=items,
raw_payload=combined_raw,
content_hash=combined_hash,
fetched_at=datetime.now(timezone.utc),
http_status=200,
response_time_ms=round(total_elapsed, 1),
metadata={
"provider": "web_scrape",
"pages_fetched": len(items),
"pages_failed": len(errors),
"errors": errors,
},
)
def _error_result(
    self,
    ticker: str,
    error: str,
    elapsed_ms: float,
) -> AdapterResult:
    """Return a failed ``web_scrape`` AdapterResult that carries *error*.

    The result has no items, an empty payload and content hash, no HTTP
    status, and the elapsed time rounded to one decimal place.
    """
    failure_fields = {
        "source_type": "web_scrape",
        "ticker": ticker,
        "items": [],
        "raw_payload": b"",
        "content_hash": "",
        "fetched_at": datetime.now(timezone.utc),
        "error": error,
        "http_status": None,
        "response_time_ms": round(elapsed_ms, 1),
        "metadata": {"provider": "web_scrape"},
    }
    return AdapterResult(**failure_fields)
+169
View File
@@ -0,0 +1,169 @@
"""Contradiction detection and disagreement representation.
Analyses weighted signals to detect and represent disagreement explicitly,
rather than collapsing contradictory evidence into a single unsupported
conclusion.
Requirements: 6.4, 6.5
"""
from __future__ import annotations
from dataclasses import dataclass
from services.aggregation.scoring import WeightedSignal
from services.shared.schemas import DisagreementDetail
@dataclass
class CatalystEntry:
    """Lightweight carrier for per-document catalyst info needed by
    contradiction detection.  Avoids importing ImpactRow and creating
    a circular dependency with worker.py."""

    # Identifier matched against WeightedSignal.document_id during detection.
    document_id: str
    # Catalyst label used as the grouping key for catalyst-level disagreement.
    catalyst_type: str
@dataclass
class ContradictionResult:
    """Full contradiction analysis output."""

    # Minority/majority weight ratio as rounded by _compute_overall_score.
    score: float  # 0-1, same semantics as existing compute_contradiction_score
    # One entry per dimension (sentiment, catalyst:<type>) where both sides exist.
    details: list[DisagreementDetail]
def detect_contradictions(
    signals: list[WeightedSignal],
    catalyst_entries: list[CatalystEntry] | None = None,
) -> ContradictionResult:
    """Detect disagreement among *signals* along several dimensions.

    Two analyses are run:
    1. Sentiment disagreement — positive and negative signals both present.
    2. Catalyst disagreement — one catalyst type carrying both polarities
       (only when *catalyst_entries* is non-empty).

    Returns:
        A ContradictionResult holding the overall score and the list of
        per-dimension disagreement details found.
    """
    findings: list[DisagreementDetail] = []

    sentiment_split = _detect_sentiment_disagreement(signals)
    if sentiment_split is not None:
        findings.append(sentiment_split)

    if catalyst_entries:
        findings += _detect_catalyst_disagreement(signals, catalyst_entries)

    return ContradictionResult(
        score=_compute_overall_score(signals),
        details=findings,
    )
def _compute_overall_score(signals: list[WeightedSignal]) -> float:
"""Minority/majority weight ratio — backward-compatible formula."""
if not signals:
return 0.0
pos_weight = 0.0
neg_weight = 0.0
for sig in signals:
w = sig.weight.combined * sig.impact_score
if sig.sentiment_value > 0:
pos_weight += w
elif sig.sentiment_value < 0:
neg_weight += w
total = pos_weight + neg_weight
if total == 0.0:
return 0.0
minority = min(pos_weight, neg_weight)
return round(minority / total, 4)
def _detect_sentiment_disagreement(
signals: list[WeightedSignal],
) -> DisagreementDetail | None:
"""Detect when both positive and negative sentiment signals exist."""
pos_ids: list[str] = []
neg_ids: list[str] = []
pos_weight = 0.0
neg_weight = 0.0
for sig in signals:
w = sig.weight.combined * sig.impact_score
if w <= 0:
continue
if sig.sentiment_value > 0:
pos_ids.append(sig.document_id)
pos_weight += w
elif sig.sentiment_value < 0:
neg_ids.append(sig.document_id)
neg_weight += w
if not pos_ids or not neg_ids:
return None
total = pos_weight + neg_weight
minority_pct = min(pos_weight, neg_weight) / total if total > 0 else 0.0
return DisagreementDetail(
dimension="sentiment",
positive_doc_ids=pos_ids,
negative_doc_ids=neg_ids,
positive_weight=round(pos_weight, 4),
negative_weight=round(neg_weight, 4),
description=(
f"Sentiment split: {len(pos_ids)} positive vs {len(neg_ids)} negative signals "
f"(minority weight ratio {minority_pct:.0%})"
),
)
def _detect_catalyst_disagreement(
signals: list[WeightedSignal],
catalyst_entries: list[CatalystEntry],
) -> list[DisagreementDetail]:
"""Detect when the same catalyst type has both positive and negative signals."""
# Build lookup: document_id → (sentiment_value, combined_weight)
sig_lookup: dict[str, tuple[float, float]] = {}
for sig in signals:
w = sig.weight.combined * sig.impact_score
if w > 0:
sig_lookup[sig.document_id] = (sig.sentiment_value, w)
# Group by catalyst type
from collections import defaultdict
catalyst_groups: dict[str, list[tuple[str, float, float]]] = defaultdict(list)
for entry in catalyst_entries:
if entry.document_id in sig_lookup:
sent_val, weight = sig_lookup[entry.document_id]
if sent_val != 0.0:
catalyst_groups[entry.catalyst_type].append(
(entry.document_id, sent_val, weight)
)
details: list[DisagreementDetail] = []
for catalyst, entries in catalyst_groups.items():
pos_ids = [doc_id for doc_id, sv, _ in entries if sv > 0]
neg_ids = [doc_id for doc_id, sv, _ in entries if sv < 0]
if not pos_ids or not neg_ids:
continue
pos_w = sum(w for _, sv, w in entries if sv > 0)
neg_w = sum(w for _, sv, w in entries if sv < 0)
details.append(DisagreementDetail(
dimension=f"catalyst:{catalyst}",
positive_doc_ids=pos_ids,
negative_doc_ids=neg_ids,
positive_weight=round(pos_w, 4),
negative_weight=round(neg_w, 4),
description=(
f"Catalyst '{catalyst}' has {len(pos_ids)} positive and "
f"{len(neg_ids)} negative signals"
),
))
return details
+141
View File
@@ -0,0 +1,141 @@
"""Evidence ranking for supporting and opposing documents.
Ranks document signals by a composite score that considers multiple
factors beyond raw weight, producing explainable evidence lists for
trend summaries.
Requirements: 6.5
"""
from __future__ import annotations
from dataclasses import dataclass
from services.aggregation.scoring import WeightedSignal
@dataclass(frozen=True)
class EvidenceRankConfig:
    """Weights for the composite evidence ranking score.

    The four ``*_factor`` defaults sum to 1.0; each multiplies one input
    of ``compute_evidence_rank``.
    """

    # How much the combined signal weight matters (recency * credibility * novelty * market)
    weight_factor: float = 0.40
    # How much the document's impact score matters
    impact_factor: float = 0.30
    # How much recency alone matters (favours fresh evidence in the ranking)
    recency_factor: float = 0.20
    # How much extraction confidence matters
    # NOTE(review): compute_evidence_rank multiplies this by weight.credibility,
    # not by an extraction-confidence value — confirm which is intended.
    confidence_factor: float = 0.10
    # Maximum evidence refs per side (supporting / opposing)
    max_refs: int = 10


# Shared default used as the default ``config`` argument of the ranking functions.
DEFAULT_RANK_CONFIG = EvidenceRankConfig()
@dataclass
class RankedEvidence:
    """A document with its composite ranking score and breakdown."""

    document_id: str
    # Sum of the four components below (rounded by compute_evidence_rank).
    rank_score: float
    # weight.combined * config.weight_factor
    weight_component: float
    # impact_score * config.impact_factor
    impact_component: float
    # weight.recency * config.recency_factor
    recency_component: float
    # weight.credibility * config.confidence_factor
    confidence_component: float
    sentiment_value: float  # +1 / -1 / 0
def compute_evidence_rank(
    signal: WeightedSignal,
    config: EvidenceRankConfig = DEFAULT_RANK_CONFIG,
) -> RankedEvidence:
    """Score one signal for evidence ranking and keep the breakdown.

    The rank score is the sum of four products:
    - ``signal.weight.combined``    * ``config.weight_factor``
    - ``signal.impact_score``       * ``config.impact_factor``
    - ``signal.weight.recency``     * ``config.recency_factor``
    - ``signal.weight.credibility`` * ``config.confidence_factor``

    The total is computed from the unrounded products; every stored value
    is then rounded to 6 decimal places.
    """
    weight = signal.weight
    from_weight = weight.combined * config.weight_factor
    from_impact = signal.impact_score * config.impact_factor
    from_recency = weight.recency * config.recency_factor
    from_confidence = weight.credibility * config.confidence_factor

    total = from_weight + from_impact + from_recency + from_confidence

    return RankedEvidence(
        document_id=signal.document_id,
        rank_score=round(total, 6),
        weight_component=round(from_weight, 6),
        impact_component=round(from_impact, 6),
        recency_component=round(from_recency, 6),
        confidence_component=round(from_confidence, 6),
        sentiment_value=signal.sentiment_value,
    )
def rank_evidence(
    signals: list[WeightedSignal],
    config: EvidenceRankConfig = DEFAULT_RANK_CONFIG,
) -> tuple[list[str], list[str]]:
    """Rank signals into top supporting and opposing document ID lists.

    Supporting = positive sentiment, Opposing = negative sentiment.
    Signals with a sentiment value of exactly 0 are excluded.

    Args:
        signals: Weighted signals to rank.
        config: Ranking weights and the per-side cap (``max_refs``).

    Returns:
        ``(supporting_ids, opposing_ids)``, each ordered by descending rank
        score and capped at ``config.max_refs``.
    """
    # Delegate to the detailed variant so the partition/sort/cap rules live
    # in exactly one place; only the projection to document IDs is local.
    supporting, opposing = rank_evidence_detailed(signals, config)
    return (
        [ranked.document_id for ranked in supporting],
        [ranked.document_id for ranked in opposing],
    )
def rank_evidence_detailed(
    signals: list[WeightedSignal],
    config: EvidenceRankConfig = DEFAULT_RANK_CONFIG,
) -> tuple[list[RankedEvidence], list[RankedEvidence]]:
    """Rank signals and return the full RankedEvidence objects per side.

    Signals with a sentiment value of exactly 0 are dropped; positive
    sentiment is supporting and everything else is opposing.  Each side is
    ordered by descending rank score and capped at ``config.max_refs``.
    """
    def _by_score(ranked: RankedEvidence) -> float:
        return ranked.rank_score

    scored = [
        (sig.sentiment_value > 0, compute_evidence_rank(sig, config))
        for sig in signals
        if not sig.sentiment_value == 0.0
    ]
    supporting = sorted(
        (ranked for is_positive, ranked in scored if is_positive),
        key=_by_score,
        reverse=True,
    )
    opposing = sorted(
        (ranked for is_positive, ranked in scored if not is_positive),
        key=_by_score,
        reverse=True,
    )
    limit = config.max_refs
    return supporting[:limit], opposing[:limit]
+57
View File
@@ -0,0 +1,57 @@
"""Aggregation worker entrypoint - polls Redis for aggregation jobs."""
from __future__ import annotations
import asyncio
import json
import logging
import asyncpg
from services.aggregation.worker import aggregate_company
from services.shared.config import load_config
from services.shared.logging import setup_logging
from services.shared.redis_keys import QUEUE_AGGREGATION, queue_key
logger = logging.getLogger("aggregation_main")
async def main() -> None:
    """Run the aggregation worker loop until cancelled or a fatal error.

    Pops JSON jobs from the aggregation Redis queue (sleeping one second
    when the queue is empty) and runs ``aggregate_company`` for the job's
    ticker.  A bad job — malformed JSON, a non-object payload, or a missing
    ticker — is logged and discarded instead of terminating the worker.
    The PostgreSQL pool and Redis client are closed on exit.
    """
    # Imported before any resource is opened so an import failure cannot
    # leave a connection pool behind.
    import redis.asyncio as aioredis

    config = load_config()
    setup_logging("aggregation", level=config.log_level, json_output=config.json_logs)

    pool = await asyncpg.create_pool(dsn=config.postgres.dsn, min_size=2, max_size=8)
    redis_client = aioredis.from_url(config.redis.url)

    queue = queue_key(QUEUE_AGGREGATION)
    logger.info("Aggregation worker started, polling %s", queue)

    try:
        while True:
            raw = await redis_client.lpop(queue)
            if raw is None:
                await asyncio.sleep(1)
                continue

            # Decoding happens under its own guard: one poison message must
            # not take the whole worker down.
            try:
                job = json.loads(raw)
            except (TypeError, ValueError):
                logger.exception("Discarding malformed aggregation job: %r", raw)
                continue

            ticker = job.get("ticker", "") if isinstance(job, dict) else ""
            if not ticker:
                logger.warning("Discarding aggregation job without ticker: %r", raw)
                continue

            logger.info("Processing aggregation job for %s", ticker)
            try:
                summaries = await aggregate_company(pool, ticker)
                logger.info(
                    "Aggregation complete for %s: %d windows",
                    ticker, len(summaries),
                )
            except Exception:
                # Top-level job boundary: log with traceback and keep polling.
                logger.exception("Aggregation failed for %s", ticker)
    finally:
        # Close Redis even if closing the pool raises.
        try:
            await pool.close()
        finally:
            await redis_client.close()


if __name__ == "__main__":
    asyncio.run(main())
+150
View File
@@ -0,0 +1,150 @@
"""Market context feature computation for aggregation windows.
Fetches recent market snapshots from PostgreSQL and computes context
features (price change, volume trend, volatility) that enrich trend
summaries and modulate signal weighting.
Requirements: 6.1, 6.2
"""
from __future__ import annotations
import math
from datetime import datetime, timedelta, timezone
from typing import Any
import asyncpg
from services.shared.schemas import MarketContext, TrendWindow
# Map TrendWindow values to lookback durations in days.
# fetch_market_context subtracts this many days from the reference time to
# get the start of the snapshot query; unknown windows fall back to 8 there.
WINDOW_LOOKBACK_DAYS: dict[str, int] = {
    TrendWindow.INTRADAY.value: 1,
    TrendWindow.ONE_DAY.value: 2,
    TrendWindow.SEVEN_DAY.value: 8,
    TrendWindow.THIRTY_DAY.value: 35,
    TrendWindow.NINETY_DAY.value: 95,
}
async def fetch_market_context(
    pool: asyncpg.Pool,
    ticker: str,
    window: str,
    reference_time: datetime | None = None,
) -> MarketContext:
    """Build a MarketContext for *ticker* over the given trend *window*.

    Reads rows from ``market_snapshots`` captured between the window's
    lookback start and *reference_time* (default: now, UTC), oldest first,
    normalises them into bars and derives the context features from them.

    Returns:
        The computed MarketContext, or one constructed with only
        ``ticker`` when there are no rows or no usable bars.
    """
    anchor = datetime.now(timezone.utc) if reference_time is None else reference_time
    window_start = anchor - timedelta(days=WINDOW_LOOKBACK_DAYS.get(window, 8))

    snapshot_rows = await pool.fetch(
        """
        SELECT data, captured_at
        FROM market_snapshots
        WHERE ticker = $1
        AND captured_at >= $2
        AND captured_at <= $3
        ORDER BY captured_at ASC
        """,
        ticker,
        window_start,
        anchor,
    )

    bars = _extract_bars(snapshot_rows) if snapshot_rows else []
    if not bars:
        return MarketContext(ticker=ticker)
    return _compute_context(ticker, bars)
def _extract_bars(rows: list[Any]) -> list[dict[str, Any]]:
"""Extract OHLCV bar dicts from market_snapshot rows.
The ``data`` column is JSONB. Polygon prev-day bars store fields like
``o``, ``h``, ``l``, ``c``, ``v``, ``t``. We normalise to a common
dict with ``close``, ``volume``, ``captured_at``.
"""
bars: list[dict[str, Any]] = []
for row in rows:
data = row["data"]
if isinstance(data, str):
import json
data = json.loads(data)
# Polygon-style single bar or list of bars
items = data if isinstance(data, list) else [data]
for item in items:
close = item.get("c") or item.get("close")
volume = item.get("v") or item.get("volume")
if close is not None:
bars.append({
"close": float(close),
"volume": float(volume) if volume is not None else 0.0,
"captured_at": row["captured_at"],
})
return bars
def _compute_context(ticker: str, bars: list[dict[str, Any]]) -> MarketContext:
    """Turn an ordered, non-empty list of bars into a MarketContext.

    Computes the percentage price change from first to last close, the
    mean volume, the percentage change of second-half versus first-half
    mean volume, and the population standard deviation of closes.
    """
    close_series = [bar["close"] for bar in bars]
    volume_series = [bar["volume"] for bar in bars]

    opening = close_series[0]
    latest = close_series[-1]
    if opening != 0:
        price_move = (latest - opening) / opening * 100.0
    else:
        price_move = 0.0

    if volume_series:
        mean_volume = sum(volume_series) / len(volume_series)
    else:
        mean_volume = 0.0

    # Volume trend: later half of the window against the earlier half.
    volume_move = 0.0
    half = len(volume_series) // 2
    if half > 0:
        early = volume_series[:half]
        late = volume_series[half:]
        early_avg = sum(early) / half
        late_avg = sum(late) / len(late)
        if early_avg > 0:
            volume_move = (late_avg - early_avg) / early_avg * 100.0

    # Volatility: population standard deviation of the closes.
    spread = 0.0
    if len(close_series) > 1:
        centre = sum(close_series) / len(close_series)
        squared_error = sum((c - centre) ** 2 for c in close_series)
        spread = math.sqrt(squared_error / len(close_series))

    return MarketContext(
        ticker=ticker,
        price_change_pct=round(price_move, 4),
        avg_volume=round(mean_volume, 2),
        volume_change_pct=round(volume_move, 4),
        volatility=round(spread, 6),
        latest_close=latest,
        latest_bar_at=bars[-1]["captured_at"],
        bars_available=len(bars),
    )
+439
View File
@@ -0,0 +1,439 @@
"""Sector and market-level rollup aggregation.
Aggregates company-level trend summaries into sector and market-level
summaries, enabling top-down views of sentiment and risk across the
portfolio.
Requirements: 6.3, 6.4, 6.5
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
import asyncpg
from services.shared.schemas import (
DisagreementDetail,
TrendDirection,
TrendSummary,
TrendWindow,
)
logger = logging.getLogger(__name__)
@dataclass
class CompanyTrendRow:
    """A company-level trend summary fetched from the DB for rollup."""

    entity_id: str  # ticker
    # Company sector; _parse_company_trend_row substitutes "Unknown" when empty.
    sector: str
    window: str
    # Compared against TrendDirection.*.value strings during rollup.
    trend_direction: str
    trend_strength: float
    # Used as the per-company weight in rollup_trends.
    confidence: float
    contradiction_score: float
    dominant_catalysts: list[str]
    material_risks: list[str]
    top_supporting_evidence: list[str]
    top_opposing_evidence: list[str]
# ---------------------------------------------------------------------------
# Fetch latest company trends for a given window
# ---------------------------------------------------------------------------
# Selects one row per company (DISTINCT ON entity_id, newest generated_at
# first) for active companies.  $1 = window, $2 = earliest generated_at.
_LATEST_COMPANY_TRENDS_QUERY = """
SELECT DISTINCT ON (tw.entity_id)
tw.entity_id,
c.sector,
tw.window,
tw.trend_direction,
tw.trend_strength,
tw.confidence,
tw.contradiction_score,
tw.dominant_catalysts,
tw.material_risks,
tw.top_supporting_evidence,
tw.top_opposing_evidence
FROM trend_windows tw
JOIN companies c ON c.ticker = tw.entity_id AND c.active = TRUE
WHERE tw.entity_type = 'company'
AND tw.window = $1
AND tw.generated_at >= $2
ORDER BY tw.entity_id, tw.generated_at DESC
"""
def _parse_jsonb_list(val: object) -> list[str]:
"""Safely parse a JSONB column that should be a list of strings."""
if isinstance(val, list):
return [str(v) for v in val]
if isinstance(val, str):
parsed = json.loads(val)
if isinstance(parsed, list):
return [str(v) for v in parsed]
return []
def _parse_company_trend_row(row: object) -> CompanyTrendRow:
    """Build a CompanyTrendRow from a row that supports ``row[key]`` access.

    NULL text columns become ``""`` (``"Unknown"`` for sector), NULL
    numeric columns become 0.0, and list columns go through
    ``_parse_jsonb_list``.

    Raises:
        TypeError: If *row* has no ``__getitem__``.
    """
    lookup = getattr(row, "__getitem__", None)
    if lookup is None:
        raise TypeError(f"Expected a mapping-like row, got {type(row)}")

    def text(column: str, fallback: str = "") -> str:
        raw = lookup(column)
        return fallback if raw is None else str(raw)

    def number(column: str) -> float:
        raw = lookup(column)
        return 0.0 if raw is None else float(raw)

    def string_list(column: str) -> list[str]:
        return _parse_jsonb_list(lookup(column))

    return CompanyTrendRow(
        entity_id=text("entity_id"),
        # An empty-string sector also falls back to "Unknown".
        sector=text("sector", "Unknown") or "Unknown",
        window=text("window"),
        trend_direction=text("trend_direction"),
        trend_strength=number("trend_strength"),
        confidence=number("confidence"),
        contradiction_score=number("contradiction_score"),
        dominant_catalysts=string_list("dominant_catalysts"),
        material_risks=string_list("material_risks"),
        top_supporting_evidence=string_list("top_supporting_evidence"),
        top_opposing_evidence=string_list("top_opposing_evidence"),
    )
async def fetch_latest_company_trends(
    pool: asyncpg.Pool,
    window: str,
    since: datetime,
) -> list[CompanyTrendRow]:
    """Load the newest company-level trend per ticker for *window*.

    Only trends generated at or after *since* are considered.
    """
    records = await pool.fetch(_LATEST_COMPANY_TRENDS_QUERY, window, since)
    return list(map(_parse_company_trend_row, records))
# ---------------------------------------------------------------------------
# Pure rollup logic
# ---------------------------------------------------------------------------
# Direction mapping for numeric aggregation
# (bullish = +1, bearish = -1; mixed and neutral both contribute 0).
_DIRECTION_VALUES = {
    TrendDirection.BULLISH.value: 1.0,
    TrendDirection.BEARISH.value: -1.0,
    TrendDirection.MIXED.value: 0.0,
    TrendDirection.NEUTRAL.value: 0.0,
}
# Cut-offs applied to the confidence-weighted average direction in
# _derive_rollup_direction.
BULLISH_THRESHOLD = 0.15
BEARISH_THRESHOLD = -0.15
def rollup_trends(
    trends: list[CompanyTrendRow],
    entity_type: str,
    entity_id: str,
    window: str,
    reference_time: datetime,
) -> TrendSummary:
    """Combine company-level trends into one rollup TrendSummary.

    Direction, strength and contradiction are averaged with each company's
    confidence as its weight; the rollup confidence is the plain mean of
    the company confidences.  Catalysts are ranked by accumulated
    confidence, risks (normalised to stripped lower case) by the highest
    confidence that reported them, and evidence lists are de-duplicated in
    first-seen order.  With no trends, or zero total confidence, a neutral
    zero-strength summary is returned.
    """
    def _neutral_summary() -> TrendSummary:
        return TrendSummary(
            entity_type=entity_type,
            entity_id=entity_id,
            window=TrendWindow(window),
            trend_direction=TrendDirection.NEUTRAL,
            trend_strength=0.0,
            confidence=0.0,
            generated_at=reference_time,
        )

    if not trends:
        return _neutral_summary()

    confidence_sum = 0.0
    direction_sum = 0.0
    strength_sum = 0.0
    contradiction_sum = 0.0
    catalyst_scores: dict[str, float] = {}
    risk_scores: dict[str, float] = {}
    supporting: list[str] = []
    opposing: list[str] = []

    for trend in trends:
        conf = trend.confidence
        confidence_sum += conf
        direction_sum += conf * _DIRECTION_VALUES.get(trend.trend_direction, 0.0)
        strength_sum += conf * trend.trend_strength
        contradiction_sum += conf * trend.contradiction_score

        for catalyst in trend.dominant_catalysts:
            catalyst_scores[catalyst] = catalyst_scores.get(catalyst, 0.0) + conf

        for risk in trend.material_risks:
            key = risk.strip().lower()
            if key in risk_scores:
                risk_scores[key] = max(risk_scores[key], conf)
            else:
                risk_scores[key] = conf

        supporting += trend.top_supporting_evidence
        opposing += trend.top_opposing_evidence

    if confidence_sum == 0.0:
        return _neutral_summary()

    mean_direction = direction_sum / confidence_sum
    mean_strength = strength_sum / confidence_sum
    mean_contradiction = contradiction_sum / confidence_sum
    mean_confidence = confidence_sum / len(trends)

    def _score(pair: tuple[str, float]) -> float:
        return pair[1]

    ranked_catalysts = sorted(catalyst_scores.items(), key=_score, reverse=True)
    ranked_risks = sorted(risk_scores.items(), key=_score, reverse=True)

    return TrendSummary(
        entity_type=entity_type,
        entity_id=entity_id,
        window=TrendWindow(window),
        trend_direction=_derive_rollup_direction(mean_direction, mean_contradiction),
        trend_strength=round(min(abs(mean_strength), 1.0), 4),
        confidence=round(max(0.0, min(mean_confidence, 1.0)), 4),
        # dict.fromkeys de-duplicates while keeping first-seen order.
        top_supporting_evidence=list(dict.fromkeys(supporting))[:10],
        top_opposing_evidence=list(dict.fromkeys(opposing))[:10],
        dominant_catalysts=[name for name, _ in ranked_catalysts[:5]],
        material_risks=[name for name, _ in ranked_risks[:5]],
        contradiction_score=round(max(0.0, min(mean_contradiction, 1.0)), 4),
        disagreement_details=_build_rollup_disagreement(trends, entity_id),
        generated_at=reference_time,
    )
def _derive_rollup_direction(
    avg_direction: float,
    avg_contradiction: float,
) -> TrendDirection:
    """Classify an averaged direction value as a TrendDirection.

    MIXED wins when contradiction exceeds 0.10 and the direction magnitude
    is below 0.3; otherwise the bullish/bearish thresholds decide, with
    NEUTRAL as the fallback.
    """
    weak_lean = abs(avg_direction) < 0.3
    if avg_contradiction > 0.10 and weak_lean:
        return TrendDirection.MIXED

    if avg_direction >= BULLISH_THRESHOLD:
        return TrendDirection.BULLISH
    return (
        TrendDirection.BEARISH
        if avg_direction <= BEARISH_THRESHOLD
        else TrendDirection.NEUTRAL
    )
def _build_rollup_disagreement(
    trends: list[CompanyTrendRow],
    entity_id: str,
) -> list[DisagreementDetail]:
    """Describe a bullish-versus-bearish split among the rolled-up companies.

    Returns a single ``company_direction`` detail when at least one company
    is bullish and at least one is bearish, otherwise an empty list.  Each
    side's weight is the sum of its companies' confidences.
    """
    def _confidence_total(group: list[CompanyTrendRow]) -> float:
        total = 0.0
        for member in group:
            total += member.confidence
        return total

    bulls = [t for t in trends if t.trend_direction == TrendDirection.BULLISH.value]
    bears = [
        t for t in trends
        if t.trend_direction != TrendDirection.BULLISH.value
        and t.trend_direction == TrendDirection.BEARISH.value
    ]
    if not bulls or not bears:
        return []

    bull_ids = [t.entity_id for t in bulls]
    bear_ids = [t.entity_id for t in bears]
    detail = DisagreementDetail(
        dimension="company_direction",
        positive_doc_ids=bull_ids,
        negative_doc_ids=bear_ids,
        positive_weight=round(_confidence_total(bulls), 4),
        negative_weight=round(_confidence_total(bears), 4),
        description=(
            f"{entity_id}: {len(bull_ids)} bullish vs "
            f"{len(bear_ids)} bearish companies"
        ),
    )
    return [detail]
# ---------------------------------------------------------------------------
# Persist rollup (reuses the same trend_windows table)
# ---------------------------------------------------------------------------
# NOTE: despite the name, this statement is a plain INSERT ... RETURNING id
# with no ON CONFLICT clause, so every execution adds a new row.
# Parameters $7-$10, $12 and $13 are cast to jsonb.
_UPSERT_TREND = """
INSERT INTO trend_windows (
entity_type, entity_id, window, trend_direction, trend_strength,
confidence, top_supporting_evidence, top_opposing_evidence,
dominant_catalysts, material_risks, contradiction_score,
disagreement_details, market_context, generated_at
) VALUES (
$1, $2, $3, $4, $5,
$6, $7::jsonb, $8::jsonb,
$9::jsonb, $10::jsonb, $11,
$12::jsonb, $13::jsonb, $14
)
RETURNING id
"""
async def persist_rollup(
    pool: asyncpg.Pool,
    summary: TrendSummary,
) -> str:
    """Insert *summary* into ``trend_windows`` and return the new row id as text.

    List fields and disagreement details are JSON-encoded; the market
    context column is always written as an empty JSON object.
    """
    encode = json.dumps
    params = (
        summary.entity_type,
        summary.entity_id,
        summary.window.value,
        summary.trend_direction.value,
        summary.trend_strength,
        summary.confidence,
        encode(summary.top_supporting_evidence),
        encode(summary.top_opposing_evidence),
        encode(summary.dominant_catalysts),
        encode(summary.material_risks),
        summary.contradiction_score,
        encode([detail.model_dump() for detail in summary.disagreement_details]),
        encode({}),
        summary.generated_at,
    )
    record = await pool.fetchrow(_UPSERT_TREND, *params)
    return str(record["id"])  # type: ignore[index]
# ---------------------------------------------------------------------------
# High-level rollup entry points
# ---------------------------------------------------------------------------
async def aggregate_sector(
    pool: asyncpg.Pool,
    sector: str,
    window: str,
    reference_time: datetime | None = None,
    since: datetime | None = None,
) -> TrendSummary:
    """Roll up one sector's latest company trends for *window*.

    *reference_time* defaults to now (UTC) and *since* to the window's
    lookback before it.  The summary is persisted only when the sector has
    at least one company trend; it is returned either way.
    """
    anchor = datetime.now(timezone.utc) if reference_time is None else reference_time
    cutoff = anchor - _window_lookback(window) if since is None else since

    latest = await fetch_latest_company_trends(pool, window, cutoff)
    members = [trend for trend in latest if trend.sector == sector]
    summary = rollup_trends(members, "sector", sector, window, anchor)

    if not members:
        return summary

    rollup_id = await persist_rollup(pool, summary)
    logger.info(
        "Persisted sector rollup %s for %s/%s: direction=%s strength=%.3f companies=%d",
        rollup_id,
        sector,
        window,
        summary.trend_direction.value,
        summary.trend_strength,
        len(members),
    )
    return summary
async def aggregate_market(
    pool: asyncpg.Pool,
    window: str,
    reference_time: datetime | None = None,
    since: datetime | None = None,
) -> TrendSummary:
    """Roll up every company's latest trend into one market-wide summary.

    *reference_time* defaults to now (UTC) and *since* to the window's
    lookback before it.  The summary (entity ``market``/``all``) is
    persisted only when at least one company trend was found.
    """
    anchor = datetime.now(timezone.utc) if reference_time is None else reference_time
    cutoff = anchor - _window_lookback(window) if since is None else since

    company_trends = await fetch_latest_company_trends(pool, window, cutoff)
    summary = rollup_trends(company_trends, "market", "all", window, anchor)

    if not company_trends:
        return summary

    rollup_id = await persist_rollup(pool, summary)
    logger.info(
        "Persisted market rollup %s for %s: direction=%s strength=%.3f companies=%d",
        rollup_id,
        window,
        summary.trend_direction.value,
        summary.trend_strength,
        len(company_trends),
    )
    return summary
async def aggregate_all_sectors(
    pool: asyncpg.Pool,
    window: str,
    reference_time: datetime | None = None,
    since: datetime | None = None,
) -> list[TrendSummary]:
    """Compute and persist sector rollups for every sector that has company trends.

    Args:
        pool: Connection pool used for both the fetch and the inserts.
        window: Trend window identifier.
        reference_time: Anchor time for the rollups; defaults to now (UTC).
        since: Earliest ``generated_at`` to consider; defaults to the
            window's lookback before *reference_time*.

    Returns:
        One persisted TrendSummary per sector, in first-seen sector order.
    """
    if reference_time is None:
        reference_time = datetime.now(timezone.utc)
    if since is None:
        since = reference_time - _window_lookback(window)

    all_trends = await fetch_latest_company_trends(pool, window, since)

    # Group by sector.  Every group holds at least one trend by
    # construction, so each rollup is persisted unconditionally.
    by_sector: dict[str, list[CompanyTrendRow]] = {}
    for trend in all_trends:
        by_sector.setdefault(trend.sector, []).append(trend)

    summaries: list[TrendSummary] = []
    for sector, trends in by_sector.items():
        summary = rollup_trends(trends, "sector", sector, window, reference_time)
        rollup_id = await persist_rollup(pool, summary)
        # Same log line as aggregate_sector so batch runs are traceable too.
        logger.info(
            "Persisted sector rollup %s for %s/%s: direction=%s strength=%.3f companies=%d",
            rollup_id, sector, window, summary.trend_direction.value,
            summary.trend_strength, len(trends),
        )
        summaries.append(summary)
    return summaries
def _window_lookback(window: str) -> timedelta:
    """Return how far back to look for company trends for *window*.

    Unknown windows fall back to 8 days.
    """
    lookback_days = {
        TrendWindow.INTRADAY.value: 1,  # 24 hours
        TrendWindow.ONE_DAY.value: 2,
        TrendWindow.SEVEN_DAY.value: 8,
        TrendWindow.THIRTY_DAY.value: 35,
        TrendWindow.NINETY_DAY.value: 95,
    }
    return timedelta(days=lookback_days.get(window, 8))
+285
View File
@@ -0,0 +1,285 @@
"""Recency decay, source credibility weighting, and market context
integration for aggregation.
Provides scoring functions used by the aggregation engine to weight
document intelligence signals when computing trend summaries.
Requirements: 6.1, 6.2, 6.5
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field
from datetime import datetime, timezone
from services.shared.schemas import MarketContext
@dataclass(frozen=True)
class ScoringConfig:
    """Tunable parameters for signal scoring.

    The instance is frozen, so attributes cannot be rebound; note that the
    ``half_life_hours`` dict itself is still a mutable object.
    """

    # Recency decay: exponential half-life in hours per window.
    # After one half-life, a document's recency weight drops to 0.5.
    # Windows missing from this mapping use 72.0 hours in recency_weight.
    half_life_hours: dict[str, float] = field(default_factory=lambda: {
        "intraday": 2.0,
        "1d": 12.0,
        "7d": 72.0,
        "30d": 240.0,
        "90d": 720.0,
    })
    # Minimum recency weight — prevents very old docs from being zeroed out
    # entirely so they can still contribute trace-level signal.
    min_recency_weight: float = 0.01
    # Source credibility bounds — credibility scores outside this range
    # are clamped before weighting.
    credibility_floor: float = 0.1
    credibility_ceiling: float = 1.0
    # Exponent applied to credibility score.  >1 penalises low-credibility
    # sources more aggressively; <1 flattens the curve.
    credibility_exponent: float = 1.0
    # Novelty bonus: multiplier range applied on top of base weight.
    # A novelty_score of 1.0 gets the full bonus; 0.0 gets none.
    novelty_bonus_max: float = 0.25
    # Confidence floor — documents below this extraction confidence
    # receive zero weight (they are too unreliable to aggregate).
    confidence_floor: float = 0.2
    # --- Market context modulation ---
    # When volatility exceeds this threshold (in price units), recency
    # signals are amplified because fast-moving markets make fresh data
    # more important.
    volatility_recency_boost_threshold: float = 1.0
    volatility_recency_boost_max: float = 0.30  # max extra multiplier
    # When volume surges above this % change, signals get a small boost
    # because high-volume moves carry more conviction.
    volume_surge_threshold_pct: float = 50.0
    volume_surge_boost: float = 0.15


# Shared default instance, used as the default ``config`` argument of the
# scoring functions in this module.
DEFAULT_CONFIG = ScoringConfig()
# ---------------------------------------------------------------------------
# Recency decay
# ---------------------------------------------------------------------------
def recency_weight(
    published_at: datetime,
    reference_time: datetime,
    window: str,
    config: ScoringConfig = DEFAULT_CONFIG,
) -> float:
    """Return the exponential recency decay weight of a document.

    The weight is ``2 ** (-age_hours / half_life)`` where the half-life is
    looked up per window (72 hours when the window is unknown), floored at
    ``config.min_recency_weight``.

    Args:
        published_at: Publication time; a naive value is treated as UTC.
        reference_time: Anchor time; a naive value is treated as UTC.
        window: Trend window identifier (e.g. "7d").
        config: Scoring parameters.

    Returns:
        1.0 for documents published at or after the anchor, otherwise the
        decayed weight, never below ``config.min_recency_weight``.
    """
    def _as_utc(moment: datetime) -> datetime:
        if moment.tzinfo is None:
            return moment.replace(tzinfo=timezone.utc)
        return moment

    elapsed_seconds = (_as_utc(reference_time) - _as_utc(published_at)).total_seconds()
    if elapsed_seconds <= 0:
        return 1.0

    elapsed_hours = elapsed_seconds / 3600.0
    half_life = config.half_life_hours.get(window, 72.0)
    decayed = math.pow(2.0, -elapsed_hours / half_life)
    return max(decayed, config.min_recency_weight)
# ---------------------------------------------------------------------------
# Source credibility weighting
# ---------------------------------------------------------------------------
def credibility_weight(
    source_credibility: float,
    config: ScoringConfig = DEFAULT_CONFIG,
) -> float:
    """Turn a source credibility score into an aggregation weight.

    The score is clamped to ``[credibility_floor, credibility_ceiling]``
    and then raised to ``credibility_exponent``.

    Args:
        source_credibility: Raw credibility score.
        config: Scoring parameters.

    Returns:
        The clamped score raised to the configured exponent.
    """
    capped = min(source_credibility, config.credibility_ceiling)
    bounded = max(config.credibility_floor, capped)
    return math.pow(bounded, config.credibility_exponent)
# ---------------------------------------------------------------------------
# Market context adjustment
# ---------------------------------------------------------------------------
def market_context_multiplier(
    market_ctx: MarketContext | None,
    config: ScoringConfig = DEFAULT_CONFIG,
) -> float:
    """Return ``1.0 + boost`` derived from market context features.

    The boost has two additive parts: a log-scaled, capped term for
    volatility above its threshold, and a fixed term when the volume
    change exceeds the surge threshold.  Without usable market context
    the multiplier is exactly 1.0.
    """
    if market_ctx is None or not market_ctx.has_data:
        return 1.0

    extra = 0.0

    volatility = market_ctx.volatility
    threshold = config.volatility_recency_boost_threshold
    if volatility is not None and volatility > threshold:
        # log1p keeps extreme volatility from inflating the weight.
        scaled = math.log1p(volatility - threshold) * 0.15
        extra += min(scaled, config.volatility_recency_boost_max)

    volume_change = market_ctx.volume_change_pct
    if volume_change is not None and volume_change > config.volume_surge_threshold_pct:
        extra += config.volume_surge_boost

    return 1.0 + extra
# ---------------------------------------------------------------------------
# Combined document signal weight
# ---------------------------------------------------------------------------
@dataclass
class SignalWeight:
    """Per-component breakdown of one document's aggregation weight.

    ``combined`` is the value used for aggregation; the other fields record
    the factors that produced it so the result can be inspected.
    """

    # Time-decay factor for the document within the window.
    recency: float
    # Weight derived from the source credibility score.
    credibility: float
    # Additive bonus for novelty (novelty score scaled by the configured max).
    novelty_bonus: float
    # 0.0 when the extraction confidence is below the floor, otherwise 1.0.
    confidence_gate: float
    # Market-condition multiplier (>= 1.0).
    market_ctx_multiplier: float
    # Product of the factors above; see compute_signal_weight.
    combined: float
def compute_signal_weight(
    published_at: datetime,
    reference_time: datetime,
    window: str,
    source_credibility: float,
    novelty_score: float = 0.5,
    extraction_confidence: float = 0.5,
    market_ctx: MarketContext | None = None,
    config: ScoringConfig = DEFAULT_CONFIG,
) -> SignalWeight:
    """Score how much a single document signal counts during aggregation.

    The combined weight is the product::

        confidence_gate * recency * credibility
            * (1 + novelty_bonus) * market_ctx_multiplier

    with ``novelty_bonus = novelty_score * config.novelty_bonus_max``.
    The gate is 0.0 when ``extraction_confidence`` is below
    ``config.confidence_floor`` and 1.0 otherwise, so low-confidence
    documents are multiplied by zero.

    Args:
        published_at: Document publication time.
        reference_time: Aggregation anchor time.
        window: Trend window identifier.
        source_credibility: Source credibility score (0-1).
        novelty_score: Document novelty score (0-1).
        extraction_confidence: Extraction confidence from the model (0-1).
        market_ctx: Optional market context features for the symbol.
        config: Scoring parameters.

    Returns:
        A ``SignalWeight`` holding every factor plus the combined score.
    """
    # Low-confidence extractions are switched off rather than down-weighted.
    if extraction_confidence >= config.confidence_floor:
        gate = 1.0
    else:
        gate = 0.0

    recency = recency_weight(published_at, reference_time, window, config)
    credibility = credibility_weight(source_credibility, config)
    novelty_bonus = novelty_score * config.novelty_bonus_max
    market_multiplier = market_context_multiplier(market_ctx, config)

    return SignalWeight(
        recency=recency,
        credibility=credibility,
        novelty_bonus=novelty_bonus,
        confidence_gate=gate,
        market_ctx_multiplier=market_multiplier,
        combined=gate * recency * credibility * (1.0 + novelty_bonus) * market_multiplier,
    )
# ---------------------------------------------------------------------------
# Batch helpers
# ---------------------------------------------------------------------------
@dataclass
class WeightedSignal:
    """One document's sentiment signal together with its aggregation weight."""

    # Identifier of the document the signal was extracted from.
    document_id: str
    # Weight breakdown produced by the scoring module.
    weight: SignalWeight
    # Numeric sentiment: +1 positive, -1 negative, 0 neutral/mixed.
    sentiment_value: float
    # Impact score of the document for the entity being aggregated.
    impact_score: float
def sentiment_to_numeric(sentiment: str) -> float:
    """Convert a sentiment label into a signed number.

    The comparison is case-insensitive.

    Args:
        sentiment: Sentiment label such as ``"positive"`` or ``"negative"``.

    Returns:
        ``1.0`` for positive, ``-1.0`` for negative, and ``0.0`` for
        neutral, mixed or any unrecognised label.
    """
    label = sentiment.lower()
    if label == "positive":
        return 1.0
    if label == "negative":
        return -1.0
    # "neutral", "mixed" and unknown labels carry no direction.
    return 0.0
def weighted_sentiment_average(signals: list[WeightedSignal]) -> float:
    """Average the signals' sentiment, weighting each by its strength.

    A signal's strength is its combined weight multiplied by its impact
    score.

    Args:
        signals: Weighted signals to average.

    Returns:
        The strength-weighted mean sentiment (within [-1, 1] for
        non-negative strengths), or ``0.0`` when the total strength is zero.
    """
    numerator = 0.0
    denominator = 0.0
    for signal in signals:
        strength = signal.weight.combined * signal.impact_score
        numerator += strength * signal.sentiment_value
        denominator += strength
    return numerator / denominator if denominator != 0.0 else 0.0
+650 -1
View File
@@ -1 +1,650 @@
"""Aggregation worker - rolling trend summaries, contradiction detection, evidence ranking."""
"""Aggregation worker - company-level rolling window trend summaries.
Queries document intelligence and market context for a given ticker,
computes weighted signal scores, and produces TrendSummary objects
persisted to the trend_windows table.
Requirements: 6.1, 6.2, 6.5
"""
from __future__ import annotations
import json
import logging
import time
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any
import asyncpg
from services.aggregation.contradiction import CatalystEntry, detect_contradictions
from services.aggregation.evidence import (
EvidenceRankConfig,
RankedEvidence,
rank_evidence as _rank_evidence_composite,
rank_evidence_detailed,
)
from services.aggregation.market_context import fetch_market_context
from services.aggregation.scoring import (
ScoringConfig,
WeightedSignal,
compute_signal_weight,
sentiment_to_numeric,
weighted_sentiment_average,
)
from services.shared.schemas import TrendDirection, TrendSummary, TrendWindow
from services.shared.metrics import (
AGGREGATION_CONTRADICTION_SCORE,
AGGREGATION_DURATION,
AGGREGATION_SIGNALS_PROCESSED,
AGGREGATION_WINDOWS_COMPUTED,
)
logger = logging.getLogger(__name__)
# Map TrendWindow values to lookback durations.
# aggregate_company_window subtracts the duration from the reference time to
# obtain the start of the document query range; window values missing from
# this map fall back to 7 days there.
WINDOW_DURATIONS: dict[str, timedelta] = {
    TrendWindow.INTRADAY.value: timedelta(hours=12),
    TrendWindow.ONE_DAY.value: timedelta(days=1),
    TrendWindow.SEVEN_DAY.value: timedelta(days=7),
    TrendWindow.THIRTY_DAY.value: timedelta(days=30),
    TrendWindow.NINETY_DAY.value: timedelta(days=90),
}
# How many evidence document IDs to keep in supporting/opposing lists.
# Used as the default for every ``max_evidence`` / ``max_refs`` parameter in
# this module.
MAX_EVIDENCE_REFS = 10
@dataclass
class AggregationConfig:
    """Selects the windows to aggregate and the scoring parameters to use."""

    # Window identifiers to compute; None (or an empty list) means all windows.
    windows: list[str] | None = None
    # Scoring overrides; None means a default-constructed ScoringConfig.
    scoring: ScoringConfig | None = None
    # Maximum number of supporting/opposing evidence references to keep.
    max_evidence: int = MAX_EVIDENCE_REFS

    def effective_windows(self) -> list[str]:
        """Return the configured windows, or every ``TrendWindow`` value if unset."""
        if not self.windows:
            return [member.value for member in TrendWindow]
        return self.windows

    def effective_scoring(self) -> ScoringConfig:
        """Return the configured scoring parameters, or fresh defaults if unset."""
        if self.scoring:
            return self.scoring
        return ScoringConfig()
# ---------------------------------------------------------------------------
# Fetch impact records for a ticker within a time window
# ---------------------------------------------------------------------------
# Selects one row per document impact record for a ticker ($1) whose document
# was published within [$2, $3], newest first. Only intelligence rows marked
# 'valid' and documents not marked 'rejected' are returned. The selected
# columns are consumed by _parse_impact_row.
_IMPACT_QUERY = """
SELECT
di.document_id,
di.confidence,
di.novelty_score,
di.source_credibility,
dir.sentiment,
dir.impact_score,
dir.catalyst_type,
dir.key_facts,
dir.risks,
d.published_at
FROM document_impact_records dir
JOIN document_intelligence di ON di.id = dir.intelligence_id
JOIN documents d ON d.id = di.document_id
WHERE dir.ticker = $1
AND d.published_at >= $2
AND d.published_at <= $3
AND di.validation_status = 'valid'
AND d.status != 'rejected'
ORDER BY d.published_at DESC
"""
@dataclass
class ImpactRow:
    """Parsed row from the impact query."""

    # Document the impact record belongs to (stringified id).
    document_id: str
    # Extraction confidence of the document intelligence record.
    confidence: float
    # Novelty score of the document.
    novelty_score: float
    # Credibility score of the document's source.
    source_credibility: float
    # Sentiment label for the ticker (e.g. "positive", "negative", "neutral").
    sentiment: str
    # Impact score of the document for the ticker.
    impact_score: float
    # Catalyst category assigned to the document.
    catalyst_type: str
    # Key facts extracted for the ticker.
    key_facts: list[str]
    # Risks extracted for the ticker.
    risks: list[str]
    # Publication time of the document.
    published_at: datetime


def _float_or_default(value: Any, default: float) -> float:
    """Return ``float(value)``, or ``default`` only when ``value`` is ``None``."""
    return default if value is None else float(value)


def _json_list(value: Any) -> list[str]:
    """Decode a JSON-encoded string if needed and return it only if it is a list."""
    if isinstance(value, str):
        value = json.loads(value)
    return value if isinstance(value, list) else []


def _parse_impact_row(row: Any) -> ImpactRow:
    """Convert an asyncpg Record to an ImpactRow.

    Missing (``NULL``) numeric columns get a default: 0.5 for confidence,
    novelty and source credibility, 0.0 for impact score. A stored value of
    ``0.0`` is preserved as-is; the defaults apply to ``None`` only, so a
    zero-confidence document is not silently promoted to 0.5 and let through
    the confidence gate.

    ``key_facts`` and ``risks`` may arrive as JSON-encoded strings or as
    already-decoded lists; anything that is not a list after decoding becomes
    an empty list.

    Args:
        row: Mapping-like record exposing the columns selected by the impact
            query.

    Returns:
        The populated ``ImpactRow``.

    Raises:
        json.JSONDecodeError: If ``key_facts`` or ``risks`` is a string that
            is not valid JSON.
    """
    return ImpactRow(
        document_id=str(row["document_id"]),
        confidence=_float_or_default(row["confidence"], 0.5),
        novelty_score=_float_or_default(row["novelty_score"], 0.5),
        source_credibility=_float_or_default(row["source_credibility"], 0.5),
        sentiment=row["sentiment"] or "neutral",
        impact_score=_float_or_default(row["impact_score"], 0.0),
        catalyst_type=row["catalyst_type"] or "other",
        key_facts=_json_list(row["key_facts"]),
        risks=_json_list(row["risks"]),
        published_at=row["published_at"],
    )
async def fetch_impact_records(
    pool: asyncpg.Pool,
    ticker: str,
    window_start: datetime,
    window_end: datetime,
) -> list[ImpactRow]:
    """Load the validated impact records for a ticker in a time range.

    Args:
        pool: Connection pool used to run the impact query.
        ticker: Ticker symbol to filter on.
        window_start: Inclusive lower bound on the publication time.
        window_end: Inclusive upper bound on the publication time.

    Returns:
        The query rows parsed into ``ImpactRow`` objects, in query order.
    """
    records = await pool.fetch(_IMPACT_QUERY, ticker, window_start, window_end)
    return list(map(_parse_impact_row, records))
# ---------------------------------------------------------------------------
# Build weighted signals from impact records
# ---------------------------------------------------------------------------
def build_weighted_signals(
    impacts: list[ImpactRow],
    reference_time: datetime,
    window: str,
    market_ctx: Any | None = None,
    config: ScoringConfig | None = None,
) -> list[WeightedSignal]:
    """Score every impact record and wrap it as a ``WeightedSignal``.

    Args:
        impacts: Impact records to score.
        reference_time: Aggregation anchor time passed to the scoring module.
        window: Trend window identifier.
        market_ctx: Optional market context applied to every record.
        config: Scoring parameters; a default ``ScoringConfig`` when omitted.

    Returns:
        One ``WeightedSignal`` per impact record, in input order.
    """
    scoring = config or ScoringConfig()

    def _to_signal(impact: ImpactRow) -> WeightedSignal:
        # Weight the record, then pair the weight with its numeric sentiment.
        weight = compute_signal_weight(
            published_at=impact.published_at,
            reference_time=reference_time,
            window=window,
            source_credibility=impact.source_credibility,
            novelty_score=impact.novelty_score,
            extraction_confidence=impact.confidence,
            market_ctx=market_ctx,
            config=scoring,
        )
        return WeightedSignal(
            document_id=impact.document_id,
            weight=weight,
            sentiment_value=sentiment_to_numeric(impact.sentiment),
            impact_score=impact.impact_score,
        )

    return [_to_signal(impact) for impact in impacts]
# ---------------------------------------------------------------------------
# Derive trend direction from weighted sentiment
# ---------------------------------------------------------------------------
# Thresholds for mapping numeric sentiment to direction.
BULLISH_THRESHOLD = 0.15  # average sentiment at or above this → bullish
BEARISH_THRESHOLD = -0.15  # average sentiment at or below this → bearish
MIXED_THRESHOLD = 0.10  # contradiction score above this → mixed (only while |average sentiment| < 0.3)
def derive_trend_direction(
    avg_sentiment: float,
    contradiction_score: float = 0.0,
) -> TrendDirection:
    """Classify a weighted average sentiment as a ``TrendDirection``.

    The direction is MIXED only when the signals disagree
    (``contradiction_score`` above ``MIXED_THRESHOLD``) *and* the average
    sentiment is weak (absolute value below 0.3). A strong average keeps
    its bullish/bearish label even when contradiction is high.

    Args:
        avg_sentiment: Weighted average sentiment.
        contradiction_score: Disagreement score among the signals.

    Returns:
        MIXED, BULLISH, BEARISH or NEUTRAL, checked in that order.
    """
    signals_conflict = contradiction_score > MIXED_THRESHOLD
    weak_lean = abs(avg_sentiment) < 0.3
    if signals_conflict and weak_lean:
        return TrendDirection.MIXED

    if avg_sentiment >= BULLISH_THRESHOLD:
        return TrendDirection.BULLISH
    if avg_sentiment <= BEARISH_THRESHOLD:
        return TrendDirection.BEARISH
    return TrendDirection.NEUTRAL
# ---------------------------------------------------------------------------
# Compute contradiction score
# ---------------------------------------------------------------------------
def compute_contradiction_score(signals: list[WeightedSignal]) -> float:
    """Measure how much disagreement exists among weighted signals.

    Each signal contributes ``weight.combined * impact_score`` to the
    positive or the negative side according to the sign of its sentiment;
    neutral signals (sentiment 0) are ignored. The score is the ratio of the
    minority side's total to the majority side's total.

    The previous implementation divided by the *sum* of both sides, which
    capped the score at 0.5 and contradicted the documented [0, 1] range.

    Args:
        signals: Weighted signals to compare.

    Returns:
        A value in [0, 1], rounded to 4 decimals: 0 means full agreement
        (or no directional weight at all) and 1 means equal-weight positive
        and negative signals.
    """
    pos_weight = 0.0
    neg_weight = 0.0
    for sig in signals:
        strength = sig.weight.combined * sig.impact_score
        if sig.sentiment_value > 0:
            pos_weight += strength
        elif sig.sentiment_value < 0:
            neg_weight += strength

    majority = max(pos_weight, neg_weight)
    if majority <= 0.0:
        # No directional weight on either side: nothing to disagree about.
        return 0.0
    minority = min(pos_weight, neg_weight)
    return round(minority / majority, 4)
# ---------------------------------------------------------------------------
# Rank evidence (supporting vs opposing)
# ---------------------------------------------------------------------------
def rank_evidence(
    signals: list[WeightedSignal],
    max_refs: int = MAX_EVIDENCE_REFS,
) -> tuple[list[str], list[str]]:
    """Pick the top supporting and opposing document IDs for a set of signals.

    This is a thin wrapper that hands the signals to the evidence ranking
    module's composite ranker with ``max_refs`` as the list-size limit.

    Args:
        signals: Weighted signals to rank.
        max_refs: Maximum number of references requested per list.

    Returns:
        A ``(supporting_ids, opposing_ids)`` tuple as produced by the
        evidence ranking module.
    """
    return _rank_evidence_composite(signals, EvidenceRankConfig(max_refs=max_refs))
# ---------------------------------------------------------------------------
# Extract dominant catalysts and material risks
# ---------------------------------------------------------------------------
def extract_catalysts_and_risks(
    impacts: list[ImpactRow],
    signals: list[WeightedSignal],
) -> tuple[list[str], list[str]]:
    """Summarise the strongest catalyst types and risks behind a trend.

    Every impact record is matched to its signal by document id and weighted
    by ``weight.combined * impact_score``. Records without a signal, or with
    a non-positive weight, are ignored.

    Args:
        impacts: Impact records carrying catalyst types and risk texts.
        signals: Weighted signals for the same documents.

    Returns:
        A ``(catalysts, risks)`` tuple: up to five catalyst types ordered by
        cumulative weight, and up to five risk texts ordered by the weight of
        the record that surfaced them, deduplicated case-insensitively after
        trimming whitespace.
    """
    # Look up each document's effective weight once.
    strength_by_doc = {}
    for signal in signals:
        strength_by_doc[signal.document_id] = signal.weight.combined * signal.impact_score

    catalyst_totals: dict[str, float] = {}
    weighted_risks: list[tuple[float, str]] = []
    for impact in impacts:
        strength = strength_by_doc.get(impact.document_id, 0.0)
        if strength <= 0.0:
            continue
        catalyst_totals[impact.catalyst_type] = (
            catalyst_totals.get(impact.catalyst_type, 0.0) + strength
        )
        weighted_risks.extend((strength, risk) for risk in impact.risks)

    # Heaviest catalysts first; ties keep first-seen order (stable sort).
    top_catalysts = sorted(catalyst_totals, key=catalyst_totals.get, reverse=True)[:5]

    # Heaviest risks first, keeping only the first spelling of each risk.
    seen_keys: set[str] = set()
    top_risks: list[str] = []
    for _, risk_text in sorted(weighted_risks, key=lambda entry: entry[0], reverse=True):
        cleaned = risk_text.strip()
        dedupe_key = cleaned.lower()
        if dedupe_key in seen_keys:
            continue
        seen_keys.add(dedupe_key)
        top_risks.append(cleaned)
        if len(top_risks) >= 5:
            break

    return top_catalysts, top_risks
# ---------------------------------------------------------------------------
# Compute trend confidence
# ---------------------------------------------------------------------------
def compute_trend_confidence(
    signals: list[WeightedSignal],
    contradiction_score: float,
) -> float:
    """Estimate how trustworthy a trend summary is.

    Only signals with a positive combined weight contribute. The estimate
    blends three ingredients:

    - a count factor that grows linearly with the number of contributing
      signals and saturates at 20 signals (40% of the base),
    - the mean *credibility weight* of the contributing signals (60% of the
      base),
    - a penalty of ``0.4 * contradiction_score``.

    NOTE(review): the quality term averages ``weight.credibility``; the raw
    extraction confidence is not carried on ``SignalWeight`` (only the 0/1
    gate is), so it cannot be averaged here.

    Args:
        signals: Weighted signals behind the trend.
        contradiction_score: Disagreement score among the signals.

    Returns:
        A value clamped to [0, 1] and rounded to 4 decimals; ``0.0`` when no
        signal contributes.
    """
    contributing = [s for s in signals if s.weight.combined > 0]
    if not contributing:
        return 0.0

    # More signals raise the base, with no further gain beyond 20.
    count_factor = min(len(contributing) / 20.0, 1.0)
    mean_credibility = sum(s.weight.credibility for s in contributing) / len(contributing)
    penalty = contradiction_score * 0.4

    raw_confidence = (0.4 * count_factor + 0.6 * mean_credibility) - penalty
    return round(max(0.0, min(1.0, raw_confidence)), 4)
# ---------------------------------------------------------------------------
# Assemble a TrendSummary from components
# ---------------------------------------------------------------------------
@dataclass
class AssembledTrend:
    """Bundle of a trend summary and the ranked evidence behind it."""

    # The trend summary itself (its evidence lists hold document ids only).
    summary: TrendSummary
    # Detailed ranking entries for the supporting evidence.
    supporting_evidence: list[RankedEvidence]
    # Detailed ranking entries for the opposing evidence.
    opposing_evidence: list[RankedEvidence]
def assemble_trend_summary(
    ticker: str,
    window: str,
    signals: list[WeightedSignal],
    impacts: list[ImpactRow],
    market_ctx: Any | None = None,
    max_evidence: int = MAX_EVIDENCE_REFS,
    reference_time: datetime | None = None,
) -> TrendSummary:
    """Build a ``TrendSummary`` without the detailed evidence rankings.

    Convenience wrapper around ``assemble_trend_with_evidence`` that forwards
    every argument unchanged and returns only the summary part of the result.
    """
    return assemble_trend_with_evidence(
        ticker=ticker,
        window=window,
        signals=signals,
        impacts=impacts,
        market_ctx=market_ctx,
        max_evidence=max_evidence,
        reference_time=reference_time,
    ).summary
def assemble_trend_with_evidence(
    ticker: str,
    window: str,
    signals: list[WeightedSignal],
    impacts: list[ImpactRow],
    market_ctx: Any | None = None,
    max_evidence: int = MAX_EVIDENCE_REFS,
    reference_time: datetime | None = None,
) -> AssembledTrend:
    """Assemble a company ``TrendSummary`` plus its detailed evidence rankings.

    Args:
        ticker: Ticker symbol; stored as the summary's ``entity_id``.
        window: Trend window identifier; must be a valid ``TrendWindow`` value.
        signals: Weighted signals for the ticker and window.
        impacts: Impact records the signals were built from.
        market_ctx: Optional market context attached to the summary.
        max_evidence: Maximum number of evidence references per side.
        reference_time: Timestamp recorded as ``generated_at``; defaults to
            the current UTC time.

    Returns:
        An ``AssembledTrend`` holding the summary together with the ranked
        supporting and opposing evidence, ready for persistence.
    """
    generated_at = reference_time if reference_time is not None else datetime.now(timezone.utc)

    avg_sentiment = weighted_sentiment_average(signals)

    # Full contradiction detection (Requirement 6.4) needs each document's catalyst.
    contradictions = detect_contradictions(
        signals,
        [
            CatalystEntry(document_id=impact.document_id, catalyst_type=impact.catalyst_type)
            for impact in impacts
        ],
    )
    contradiction_score = contradictions.score

    direction = derive_trend_direction(avg_sentiment, contradiction_score)
    confidence = compute_trend_confidence(signals, contradiction_score)

    # Detailed rankings are kept for persistence; the summary stores ids only.
    supporting_ranked, opposing_ranked = rank_evidence_detailed(
        signals, EvidenceRankConfig(max_refs=max_evidence)
    )

    catalysts, risks = extract_catalysts_and_risks(impacts, signals)

    summary = TrendSummary(
        entity_type="company",
        entity_id=ticker,
        window=TrendWindow(window),
        trend_direction=direction,
        # Strength is the magnitude of the average sentiment, capped at 1.
        trend_strength=round(min(abs(avg_sentiment), 1.0), 4),
        confidence=confidence,
        top_supporting_evidence=[entry.document_id for entry in supporting_ranked],
        top_opposing_evidence=[entry.document_id for entry in opposing_ranked],
        dominant_catalysts=catalysts,
        material_risks=risks,
        contradiction_score=contradiction_score,
        disagreement_details=contradictions.details,
        market_context=market_ctx,
        generated_at=generated_at,
    )
    return AssembledTrend(
        summary=summary,
        supporting_evidence=supporting_ranked,
        opposing_evidence=opposing_ranked,
    )
# ---------------------------------------------------------------------------
# Persist trend summary to PostgreSQL
# ---------------------------------------------------------------------------
# Inserts one trend summary row and returns its generated id.
# NOTE: despite the name this is a plain INSERT (there is no ON CONFLICT
# clause), so every call adds a new row. Parameters $7-$10, $12 and $13 are
# cast to jsonb and are passed as JSON strings by persist_trend_summary.
_UPSERT_TREND = """
INSERT INTO trend_windows (
entity_type, entity_id, window, trend_direction, trend_strength,
confidence, top_supporting_evidence, top_opposing_evidence,
dominant_catalysts, material_risks, contradiction_score,
disagreement_details, market_context, generated_at
) VALUES (
$1, $2, $3, $4, $5,
$6, $7::jsonb, $8::jsonb,
$9::jsonb, $10::jsonb, $11,
$12::jsonb, $13::jsonb, $14
)
RETURNING id
"""
async def persist_trend_summary(
    pool: asyncpg.Pool,
    summary: TrendSummary,
) -> str:
    """Store a trend summary in ``trend_windows`` and return the new row's id.

    List-valued fields, the disagreement details and the market context are
    serialized to JSON strings for the statement's ``::jsonb`` parameters. A
    missing market context is stored as an empty JSON object.

    Args:
        pool: Connection pool used to run the insert.
        summary: The trend summary to store.

    Returns:
        The inserted row's ``id`` as a string.
    """
    disagreements = [detail.model_dump() for detail in summary.disagreement_details]
    market_context = summary.market_context.model_dump() if summary.market_context else {}

    # Positional parameters, in the column order of the INSERT statement.
    params = (
        summary.entity_type,
        summary.entity_id,
        summary.window.value,
        summary.trend_direction.value,
        summary.trend_strength,
        summary.confidence,
        json.dumps(summary.top_supporting_evidence),
        json.dumps(summary.top_opposing_evidence),
        json.dumps(summary.dominant_catalysts),
        json.dumps(summary.material_risks),
        summary.contradiction_score,
        json.dumps(disagreements),
        json.dumps(market_context),
        summary.generated_at,
    )
    inserted = await pool.fetchrow(_UPSERT_TREND, *params)
    return str(inserted["id"])
# ---------------------------------------------------------------------------
# Persist evidence mappings to trend_evidence table
# ---------------------------------------------------------------------------
# Inserts one evidence mapping row linking a trend window ($1) to a document
# ($2, cast to uuid). $3 is the evidence type ("supporting" or "opposing" as
# written by persist_trend_evidence); $4-$9 are the ranking score, its
# components and the sentiment value.
_INSERT_EVIDENCE = """
INSERT INTO trend_evidence (
trend_window_id, document_id, evidence_type,
rank_score, weight_component, impact_component,
recency_component, confidence_component, sentiment_value
) VALUES (
$1, $2::uuid, $3,
$4, $5, $6,
$7, $8, $9
)
"""
async def persist_trend_evidence(
    pool: asyncpg.Pool,
    trend_window_id: str,
    supporting: list[RankedEvidence],
    opposing: list[RankedEvidence],
) -> int:
    """Store the evidence rows that back a trend window.

    Supporting rows are written before opposing rows, all in a single
    ``executemany`` call. Nothing is sent to the database when both lists
    are empty.

    Args:
        pool: Connection pool used to run the inserts.
        trend_window_id: Id of the trend window the evidence belongs to.
        supporting: Ranked evidence tagged ``"supporting"``.
        opposing: Ranked evidence tagged ``"opposing"``.

    Returns:
        The number of evidence rows submitted.
    """

    def _as_rows(
        evidence: list[RankedEvidence],
        evidence_type: str,
    ) -> list[tuple[str, str, str, float, float, float, float, float, float]]:
        # One parameter tuple per evidence entry, in statement order.
        return [
            (
                trend_window_id, entry.document_id, evidence_type,
                entry.rank_score, entry.weight_component, entry.impact_component,
                entry.recency_component, entry.confidence_component, entry.sentiment_value,
            )
            for entry in evidence
        ]

    rows = _as_rows(supporting, "supporting") + _as_rows(opposing, "opposing")
    if rows:
        await pool.executemany(_INSERT_EVIDENCE, rows)
    return len(rows)
# ---------------------------------------------------------------------------
# Main aggregation entry point for a single ticker + window
# ---------------------------------------------------------------------------
async def aggregate_company_window(
    pool: asyncpg.Pool,
    ticker: str,
    window: str,
    reference_time: datetime | None = None,
    config: AggregationConfig | None = None,
) -> TrendSummary:
    """Compute and persist a trend summary for one ticker and one window.

    The query range ends at ``reference_time`` (current UTC time when
    omitted) and starts one window duration earlier.

    Steps:
    1. Fetch document impact records from PostgreSQL for that range.
    2. Fetch market context for the ticker.
    3. Build weighted signals using the scoring module.
    4. Assemble the TrendSummary together with its ranked evidence.
    5. Persist the summary to the trend_windows table.
    6. Persist the evidence mappings for the new trend window.

    Afterwards the run is logged and the Prometheus aggregation metrics are
    updated.

    Args:
        pool: Connection pool used for all reads and writes.
        ticker: Ticker symbol to aggregate.
        window: Trend window identifier.
        reference_time: Aggregation anchor time; defaults to now (UTC).
        config: Aggregation settings; defaults to ``AggregationConfig()``.

    Returns:
        The assembled TrendSummary.
    """
    cfg = config or AggregationConfig()
    scoring_cfg = cfg.effective_scoring()
    if reference_time is None:
        reference_time = datetime.now(timezone.utc)
    # Monotonic start time for the duration metric recorded at the end.
    _agg_start = time.monotonic()
    # NOTE(review): a window missing from WINDOW_DURATIONS silently gets a
    # 7-day lookback here.
    duration = WINDOW_DURATIONS.get(window, timedelta(days=7))
    window_start = reference_time - duration
    # 1. Fetch impact records
    impacts = await fetch_impact_records(pool, ticker, window_start, reference_time)
    # 2. Fetch market context
    market_ctx = await fetch_market_context(pool, ticker, window, reference_time)
    # 3. Build weighted signals
    signals = build_weighted_signals(
        impacts, reference_time, window, market_ctx, scoring_cfg,
    )
    # 4. Assemble trend summary with evidence details
    assembled = assemble_trend_with_evidence(
        ticker=ticker,
        window=window,
        signals=signals,
        impacts=impacts,
        # Scoring above sees the context object as fetched; the summary only
        # stores it when it actually carries data.
        market_ctx=market_ctx if market_ctx.has_data else None,
        max_evidence=cfg.max_evidence,
        reference_time=reference_time,
    )
    summary = assembled.summary
    # 5. Persist trend window
    trend_id = await persist_trend_summary(pool, summary)
    # 6. Persist evidence mappings
    # NOTE(review): steps 5 and 6 are separate pool calls, so they do not
    # appear to share a transaction — confirm whether that is intended.
    evidence_count = await persist_trend_evidence(
        pool, trend_id,
        assembled.supporting_evidence,
        assembled.opposing_evidence,
    )
    logger.info(
        "Persisted trend %s for %s/%s: direction=%s strength=%.3f confidence=%.3f signals=%d evidence=%d",
        trend_id, ticker, window, summary.trend_direction.value,
        summary.trend_strength, summary.confidence, len(signals), evidence_count,
    )
    # Prometheus metrics (reached only when every step above succeeded)
    AGGREGATION_WINDOWS_COMPUTED.labels(window=window).inc()
    AGGREGATION_SIGNALS_PROCESSED.labels(window=window).inc(len(signals))
    AGGREGATION_CONTRADICTION_SCORE.observe(summary.contradiction_score)
    AGGREGATION_DURATION.labels(window=window).observe(time.monotonic() - _agg_start)
    return summary
# ---------------------------------------------------------------------------
# Aggregate all windows for a single ticker
# ---------------------------------------------------------------------------
async def aggregate_company(
    pool: asyncpg.Pool,
    ticker: str,
    reference_time: datetime | None = None,
    config: AggregationConfig | None = None,
) -> list[TrendSummary]:
    """Aggregate every configured window for one ticker.

    Windows are processed one after another, all anchored at the same
    reference time.

    Args:
        pool: Connection pool used for all reads and writes.
        ticker: Ticker symbol to aggregate.
        reference_time: Shared anchor time; defaults to the current UTC time.
        config: Aggregation settings; defaults to ``AggregationConfig()``.

    Returns:
        One ``TrendSummary`` per configured window, in configuration order.
    """
    cfg = config or AggregationConfig()
    # Resolve the anchor once so every window uses the identical timestamp.
    anchor = reference_time if reference_time is not None else datetime.now(timezone.utc)
    return [
        await aggregate_company_window(pool, ticker, window, anchor, cfg)
        for window in cfg.effective_windows()
    ]
+1507 -1
View File
File diff suppressed because it is too large Load Diff
+268
View File
@@ -0,0 +1,268 @@
"""Ollama client wrapper using structured output format.
Sends documents to a local Ollama instance via the /api/chat endpoint
with the ``format`` parameter set to the extraction JSON schema, ensuring
the model returns schema-compliant JSON.
Includes retry logic for invalid or incomplete model responses with
exponential backoff, error classification, and full audit preservation.
Requirements: 5.1, 5.2, 5.4
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
from dataclasses import dataclass, field
import httpx
from services.extractor.prompts import (
build_extraction_prompt,
get_json_schema,
get_prompt_metadata,
)
from services.extractor.schemas import ExtractionResult, ValidationReport, validate_extraction
from services.shared.config import OllamaConfig
logger = logging.getLogger("ollama_client")
# Errors that should NOT be retried — the request itself is bad.
# Entries use the "http_<status code>" form that OllamaClient._call_ollama
# assigns to an attempt when the server answers with an error status.
_NON_RETRYABLE_ERRORS = frozenset({
    "http_400",
    "http_401",
    "http_403",
    "http_404",
    "http_422",
})
def _is_retryable(error: str | None) -> bool:
    """Tell whether a failed attempt is worth repeating.

    Args:
        error: Error code or message of the attempt, or ``None``.

    Returns:
        ``False`` when there is no error or the error is one of the
        non-retryable HTTP errors; ``True`` for every other error.
    """
    return error is not None and error not in _NON_RETRYABLE_ERRORS
@dataclass
class ExtractionAttempt:
    """Audit record describing one call to the model."""

    # Text returned for the attempt (empty when nothing was received).
    raw_output: str = ""
    # Schema validation report, when the output got as far as validation.
    validation: ValidationReport | None = None
    # Error code or message; None when the attempt did not fail.
    error: str | None = None
    # Wall-clock duration of the HTTP call in milliseconds.
    duration_ms: int = 0
    # Name of the model that served the attempt.
    model: str = ""
    # Whether a failure of this attempt may be retried.
    retryable: bool = True
@dataclass
class ExtractionResponse:
    """Outcome of an extraction call, with every attempt kept for audit."""

    # True once an attempt produced a valid extraction.
    success: bool = False
    # Parsed extraction of the successful attempt; None on failure.
    result: ExtractionResult | None = None
    # All attempts in the order they were made.
    attempts: list[ExtractionAttempt] = field(default_factory=list)
    # Metadata describing the prompt that was used.
    prompt_metadata: dict[str, str] = field(default_factory=dict)
    # Name of the model that was queried.
    model: str = ""
    # Total time spent across attempts and back-off waits, in milliseconds.
    total_duration_ms: int = 0
def _compute_backoff(
    attempt_num: int,
    base_delay: float,
    max_delay: float,
    multiplier: float,
) -> float:
    """Return the wait time before the retry that follows ``attempt_num``.

    Args:
        attempt_num: Zero-based index of the attempt that just failed.
        base_delay: Delay used after the first attempt.
        max_delay: Upper bound on the returned delay.
        multiplier: Growth factor applied once per attempt.

    Returns:
        ``base_delay * multiplier ** attempt_num``, capped at ``max_delay``.
    """
    return min(base_delay * (multiplier ** attempt_num), max_delay)
class OllamaClient:
    """Async client for Ollama structured extraction.

    The client can be used directly (call ``close()`` when done) or as an
    async context manager, which closes it on exit.

    Usage::

        config = OllamaConfig(base_url="http://localhost:11434", model="llama3.1:8b")
        async with OllamaClient(config) as client:
            response = await client.extract(
                document_text="Apple reported record earnings...",
                document_type="article",
                document_id="abc-123",
            )
        if response.success:
            print(response.result)
    """

    _config: OllamaConfig
    _max_retries: int
    _base_delay: float
    _max_delay: float
    _backoff_multiplier: float
    _owns_client: bool
    _http: httpx.AsyncClient

    def __init__(
        self,
        config: OllamaConfig,
        max_retries: int | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create a client.

        Args:
            config: Connection, model, timeout and retry settings.
            max_retries: Overrides ``config.max_retries`` when given.
            http_client: Existing HTTP client to use. When omitted, the
                client creates its own and closes it in ``close()``.
        """
        self._config = config
        self._max_retries = max_retries if max_retries is not None else config.max_retries
        self._base_delay = config.retry_base_delay
        self._max_delay = config.retry_max_delay
        self._backoff_multiplier = config.retry_backoff_multiplier
        # Only close the HTTP client if we created it ourselves.
        self._owns_client = http_client is None
        self._http = http_client or httpx.AsyncClient(timeout=config.timeout)

    async def __aenter__(self) -> OllamaClient:
        """Enter the async context; returns the client itself."""
        return self

    async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None:
        """Close the client when the async context exits."""
        await self.close()

    async def close(self) -> None:
        """Close the underlying HTTP client if we own it."""
        if self._owns_client:
            await self._http.aclose()

    async def extract(
        self,
        document_text: str,
        document_type: str = "article",
        document_id: str = "",
        known_tickers: list[str] | None = None,
    ) -> ExtractionResponse:
        """Send a document to Ollama for structured intelligence extraction.

        Retries up to ``max_retries`` times when the model returns invalid
        or incomplete JSON. Uses exponential backoff between retries.
        Non-retryable errors (e.g. HTTP 400) stop retries immediately.
        Each attempt and its validation result are preserved for audit.

        Args:
            document_text: Normalized text content of the document.
            document_type: One of article, filing, transcript, press_release.
            document_id: Optional document ID for traceability.
            known_tickers: Optional ticker hints for the model.

        Returns:
            An ``ExtractionResponse`` with the parsed result on success.
        """
        prompts = build_extraction_prompt(
            document_text=document_text,
            document_type=document_type,
            document_id=document_id,
            known_tickers=known_tickers,
        )
        json_schema = get_json_schema()
        prompt_meta = get_prompt_metadata()
        response = ExtractionResponse(
            prompt_metadata=prompt_meta,
            model=self._config.model,
        )
        total_start = time.monotonic()
        # One initial attempt plus up to ``max_retries`` retries.
        for attempt_num in range(self._max_retries + 1):
            attempt = await self._call_ollama(prompts, json_schema, document_text)
            response.attempts.append(attempt)
            if attempt.error is None and attempt.validation and attempt.validation.valid:
                response.success = True
                response.result = attempt.validation.parsed
                break
            # Check if the error is non-retryable — stop immediately
            if not _is_retryable(attempt.error):
                attempt.retryable = False
                logger.warning(
                    "Non-retryable error for doc %s: %s — stopping retries",
                    document_id or "unknown",
                    attempt.error,
                )
                break
            # No sleep after the final attempt: there is nothing left to wait for.
            if attempt_num < self._max_retries:
                delay = _compute_backoff(
                    attempt_num,
                    self._base_delay,
                    self._max_delay,
                    self._backoff_multiplier,
                )
                logger.warning(
                    "Extraction attempt %d/%d failed for doc %s: %s — retrying in %.1fs",
                    attempt_num + 1,
                    self._max_retries + 1,
                    document_id or "unknown",
                    attempt.error or "validation failed",
                    delay,
                )
                await asyncio.sleep(delay)
        response.total_duration_ms = int((time.monotonic() - total_start) * 1000)
        return response

    async def _call_ollama(
        self,
        prompts: dict[str, str],
        json_schema: dict[str, object],
        document_text: str = "",
    ) -> ExtractionAttempt:
        """Make a single call to the Ollama /api/chat endpoint.

        Never raises for transport, HTTP-status or response-shape problems;
        they are reported through ``ExtractionAttempt.error`` instead:
        ``timeout``, ``http_<status>``, ``connection_error: ...``,
        ``invalid_response_json`` (body is not JSON or not a JSON object),
        ``empty_model_response`` (no usable message content), or the joined
        validation errors.

        Args:
            prompts: Mapping with the ``system`` and ``user`` prompt texts.
            json_schema: Schema passed as the ``format`` parameter.
            document_text: Source text handed to the validator.

        Returns:
            The attempt record, with ``validation`` set when the model
            output reached schema validation.
        """
        attempt = ExtractionAttempt(model=self._config.model)
        start = time.monotonic()
        payload = {
            "model": self._config.model,
            "messages": [
                {"role": "system", "content": prompts["system"]},
                {"role": "user", "content": prompts["user"]},
            ],
            "format": json_schema,
            "stream": False,
        }
        try:
            resp = await self._http.post(
                f"{self._config.base_url}/api/chat",
                json=payload,
            )
            _ = resp.raise_for_status()
        except httpx.TimeoutException:
            attempt.error = "timeout"
            attempt.duration_ms = int((time.monotonic() - start) * 1000)
            return attempt
        except httpx.HTTPStatusError as exc:
            attempt.error = f"http_{exc.response.status_code}"
            attempt.retryable = _is_retryable(attempt.error)
            attempt.duration_ms = int((time.monotonic() - start) * 1000)
            return attempt
        except httpx.HTTPError as exc:
            attempt.error = f"connection_error: {exc}"
            attempt.duration_ms = int((time.monotonic() - start) * 1000)
            return attempt
        attempt.duration_ms = int((time.monotonic() - start) * 1000)

        # Parse the Ollama response envelope
        try:
            body = resp.json()
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            attempt.error = "invalid_response_json"
            attempt.raw_output = resp.text
            return attempt
        if not isinstance(body, dict):
            # Valid JSON but not an object (e.g. a list or null): there is no
            # "message" to read, so report it like any other broken envelope.
            attempt.error = "invalid_response_json"
            attempt.raw_output = resp.text
            return attempt

        msg = body.get("message")
        content = msg.get("content") if isinstance(msg, dict) else None
        if not isinstance(content, str):
            # Missing or non-text content is treated as an empty reply so that
            # raw_output always stays a string.
            content = ""
        attempt.raw_output = content
        if not content:
            attempt.error = "empty_model_response"
            return attempt

        # Validate against extraction schema
        attempt.validation = validate_extraction(content, document_text=document_text)
        if not attempt.validation.valid:
            attempt.error = "; ".join(attempt.validation.errors)
        return attempt
+72
View File
@@ -0,0 +1,72 @@
"""Extractor worker entrypoint - polls Redis for extraction jobs."""
from __future__ import annotations
import asyncio
import logging
import asyncpg
from minio import Minio
from services.extractor.client import OllamaClient
from services.extractor.worker import persist_extraction
from services.shared.config import load_config
from services.shared.logging import setup_logging
from services.shared.redis_keys import QUEUE_EXTRACTION, queue_key
logger = logging.getLogger("extractor_main")
async def main() -> None:
    """Run the extractor worker loop.

    Loads configuration, connects to PostgreSQL, MinIO, Ollama and Redis,
    then polls the extraction queue forever. Each job is a JSON object with
    ``document_id``, ``ticker`` and ``text``; the text is sent to Ollama and
    the response is handed to ``persist_extraction``.

    A job that cannot be decoded, is not a JSON object, or fails during
    extraction or persistence is logged and skipped so that one bad message
    cannot stop the worker. The Ollama client, the database pool and the
    Redis client are all closed when the loop exits.
    """
    # Function-scope imports, as in the original; done first so that a
    # missing dependency fails before any connection is opened.
    import json

    import redis.asyncio as aioredis

    config = load_config()
    setup_logging("extractor", level=config.log_level, json_output=config.json_logs)
    pool = await asyncpg.create_pool(dsn=config.postgres.dsn, min_size=2, max_size=8)
    minio_client = Minio(
        config.minio.endpoint,
        access_key=config.minio.access_key,
        secret_key=config.minio.secret_key,
        secure=config.minio.secure,
    )
    ollama = OllamaClient(config.ollama)
    redis_client = aioredis.from_url(config.redis.url)
    queue = queue_key(QUEUE_EXTRACTION)
    logger.info("Extractor worker started, polling %s", queue)
    try:
        while True:
            raw = await redis_client.lpop(queue)
            if raw is None:
                # Queue is empty: wait a moment before polling again.
                await asyncio.sleep(1)
                continue

            # The job has already been popped, so a bad payload can only be
            # logged and dropped; letting it raise would kill the worker.
            try:
                job = json.loads(raw)
            except (TypeError, ValueError):
                logger.exception("Discarding malformed extraction job payload: %r", raw)
                continue
            if not isinstance(job, dict):
                logger.error("Discarding extraction job that is not a JSON object: %r", raw)
                continue

            document_id = job.get("document_id", "")
            ticker = job.get("ticker", "")
            text = job.get("text", "")
            logger.info("Processing extraction job for doc %s / %s", document_id, ticker)
            try:
                # Pass the document id along so the client's prompt and retry
                # logs can be traced back to this job.
                extraction_response = await ollama.extract(text, document_id=document_id)
                await persist_extraction(
                    pool=pool,
                    minio_client=minio_client,
                    document_id=document_id,
                    ticker=ticker,
                    extraction_response=extraction_response,
                    document_text_length=len(text),
                )
            except Exception:
                # Top-level job boundary: log with traceback and keep polling.
                logger.exception("Extraction failed for doc %s", document_id)
    finally:
        # Nested so that a failure while closing one resource does not leave
        # the others open.
        try:
            await ollama.close()
        finally:
            try:
                await pool.close()
            finally:
                await redis_client.close()


if __name__ == "__main__":
    asyncio.run(main())
+250
View File
@@ -0,0 +1,250 @@
"""Model performance metrics collection and persistence.
Tracks extraction success/failure rates, latency percentiles, retry counts,
validation error distributions, confidence scores, and token usage estimates.
Metrics are persisted to PostgreSQL for operational dashboards and published
to the analytical lake for Trino/Superset queries.
Requirements: 5.2, 5.4, 12.1, 12.2
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
import asyncpg
from services.extractor.client import ExtractionResponse
logger = logging.getLogger("extractor_metrics")
# Rough token estimate: ~4 chars per token for English text
_CHARS_PER_TOKEN = 4
@dataclass
class ExtractionMetrics:
    """Metrics describing a single extraction run."""

    # --- Identification -----------------------------------------------------
    document_id: str = ""
    ticker: str = ""
    model_name: str = ""
    prompt_version: str = ""
    schema_version: str = ""

    # --- Outcome and timing (durations in milliseconds) ----------------------
    success: bool = False
    attempt_count: int = 0
    total_duration_ms: int = 0
    first_attempt_duration_ms: int = 0
    final_attempt_duration_ms: int = 0

    # --- Quality and validation ----------------------------------------------
    confidence: float = 0.0
    validation_status: str = "unknown"
    validation_error_count: int = 0
    validation_warning_count: int = 0
    validation_errors: list[str] = field(default_factory=list)
    retry_count: int = 0

    # --- Size estimates --------------------------------------------------------
    input_token_estimate: int = 0
    output_token_estimate: int = 0
    company_count: int = 0

    # Creation time of the record, timezone-aware (UTC).
    recorded_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
def collect_metrics(
    extraction_response: ExtractionResponse,
    *,
    document_id: str = "",
    ticker: str = "",
    document_text_length: int = 0,
) -> ExtractionMetrics:
    """Summarise one extraction run as an ``ExtractionMetrics`` record.

    Args:
        extraction_response: The full response from OllamaClient.extract().
        document_id: UUID of the source document.
        ticker: Primary ticker symbol.
        document_text_length: Length of the input document text in characters.

    Returns:
        An ExtractionMetrics instance; ``recorded_at`` keeps its default.
    """
    attempt_log = extraction_response.attempts
    last_attempt = attempt_log[-1] if attempt_log else None

    # Per-attempt timings (0 when no attempt was recorded at all).
    first_duration = attempt_log[0].duration_ms if attempt_log else 0
    final_duration = attempt_log[-1].duration_ms if attempt_log else 0

    # Validation details are taken from the final attempt only.
    error_list: list[str] = []
    warning_list: list[str] = []
    final_validation = last_attempt.validation if last_attempt else None
    if final_validation:
        error_list = final_validation.errors
        warning_list = final_validation.warnings

    # "valid" on success, "failed" if anything was attempted, else "unknown".
    if extraction_response.success:
        status = "valid"
    else:
        status = "failed" if attempt_log else "unknown"

    # Confidence and company count come from the parsed result, if present.
    extracted = extraction_response.result
    result_confidence = extracted.confidence if extracted else 0.0
    companies_found = len(extracted.companies) if extracted else 0

    # Character-count based token estimates.
    tokens_in = max(document_text_length, 0) // _CHARS_PER_TOKEN
    raw_text = last_attempt.raw_output if last_attempt else None
    tokens_out = len(raw_text) // _CHARS_PER_TOKEN if raw_text else 0

    prompt_meta = extraction_response.prompt_metadata
    return ExtractionMetrics(
        document_id=document_id,
        ticker=ticker,
        model_name=extraction_response.model,
        prompt_version=prompt_meta.get("prompt_version", ""),
        schema_version=prompt_meta.get("schema_version", ""),
        success=extraction_response.success,
        attempt_count=len(attempt_log),
        total_duration_ms=extraction_response.total_duration_ms,
        first_attempt_duration_ms=first_duration,
        final_attempt_duration_ms=final_duration,
        confidence=result_confidence,
        validation_status=status,
        validation_error_count=len(error_list),
        validation_warning_count=len(warning_list),
        validation_errors=error_list,
        # Every attempt after the first counts as a retry.
        retry_count=max(0, len(attempt_log) - 1),
        input_token_estimate=tokens_in,
        output_token_estimate=tokens_out,
        company_count=companies_found,
    )
async def persist_metrics(
    pool: asyncpg.Pool,
    metrics: ExtractionMetrics,
) -> str:
    """Insert one metrics row into ``model_performance_metrics``.

    Args:
        pool: PostgreSQL connection pool.
        metrics: Collected metrics from an extraction run.

    Returns:
        The ``id`` returned by the INSERT, converted to ``str``.
    """
    insert_sql = """INSERT INTO model_performance_metrics
        (document_id, ticker, model_name, prompt_version, schema_version,
        success, attempt_count, total_duration_ms,
        first_attempt_duration_ms, final_attempt_duration_ms,
        confidence, validation_status, validation_error_count,
        validation_warning_count, validation_errors, retry_count,
        input_token_estimate, output_token_estimate, company_count,
        recorded_at)
        VALUES ($1::uuid, $2, $3, $4, $5, $6, $7, $8, $9, $10,
        $11, $12, $13, $14, $15::jsonb, $16, $17, $18, $19, $20)
        RETURNING id"""
    # Positional bind values, in the same order as the column list above.
    bind_values = (
        metrics.document_id,
        metrics.ticker,
        metrics.model_name,
        metrics.prompt_version,
        metrics.schema_version,
        metrics.success,
        metrics.attempt_count,
        metrics.total_duration_ms,
        metrics.first_attempt_duration_ms,
        metrics.final_attempt_duration_ms,
        metrics.confidence,
        metrics.validation_status,
        metrics.validation_error_count,
        metrics.validation_warning_count,
        # Serialised to JSON text because the column is cast with ::jsonb.
        json.dumps(metrics.validation_errors),
        metrics.retry_count,
        metrics.input_token_estimate,
        metrics.output_token_estimate,
        metrics.company_count,
        metrics.recorded_at,
    )
    inserted_id = await pool.fetchval(insert_sql, *bind_values)
    logger.info(
        "Persisted extraction metrics %s for doc %s: success=%s duration=%dms retries=%d",
        inserted_id,
        metrics.document_id,
        metrics.success,
        metrics.total_duration_ms,
        metrics.retry_count,
    )
    return str(inserted_id)
async def get_model_performance_summary(
    pool: asyncpg.Pool,
    *,
    model_name: str | None = None,
    hours: int = 24,
) -> dict[str, object]:
    """Aggregate model performance metrics over a lookback window.

    Args:
        pool: PostgreSQL connection pool.
        model_name: Optional filter by model name.
        hours: Lookback window in hours (default 24).

    Returns:
        A dict of aggregates (counts, success rate, latency percentiles,
        averages, token totals) plus the ``hours`` used. When no rows fall
        in the window, only ``total_extractions`` and ``success_rate`` are
        returned.
    """
    # $1 is always the window size; $2 exists only when filtering by model.
    if model_name:
        model_filter = "AND model_name = $2"
        params: list[object] = [hours, model_name]
    else:
        model_filter = ""
        params = [hours]

    # Only the fixed filter fragment is interpolated; values are bound.
    query = f"""SELECT
        COUNT(*) AS total_extractions,
        COUNT(*) FILTER (WHERE success) AS successful,
        COUNT(*) FILTER (WHERE NOT success) AS failed,
        ROUND(AVG(total_duration_ms)::numeric, 1) AS avg_duration_ms,
        ROUND(PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 1) AS p50_duration_ms,
        ROUND(PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 1) AS p95_duration_ms,
        ROUND(PERCENTILE_CONT(0.99) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 1) AS p99_duration_ms,
        ROUND(AVG(retry_count)::numeric, 2) AS avg_retries,
        ROUND(AVG(confidence)::numeric, 3) AS avg_confidence,
        SUM(input_token_estimate) AS total_input_tokens,
        SUM(output_token_estimate) AS total_output_tokens,
        ROUND(AVG(company_count)::numeric, 2) AS avg_companies_per_doc,
        ROUND(AVG(validation_error_count)::numeric, 2) AS avg_validation_errors,
        ROUND(AVG(validation_warning_count)::numeric, 2) AS avg_validation_warnings
        FROM model_performance_metrics
        WHERE recorded_at >= NOW() - INTERVAL '1 hour' * $1
        {model_filter}"""
    row = await pool.fetchrow(query, *params)

    if not row or row["total_extractions"] == 0:
        return {"total_extractions": 0, "success_rate": 0.0}

    total = row["total_extractions"]
    successful = row["successful"]
    summary: dict[str, object] = {
        "total_extractions": total,
        "successful": successful,
        "failed": row["failed"],
        "success_rate": round(successful / total, 4) if total > 0 else 0.0,
    }
    # NULL aggregates are reported as zero; key order matches the SELECT list.
    for column in (
        "avg_duration_ms",
        "p50_duration_ms",
        "p95_duration_ms",
        "p99_duration_ms",
        "avg_retries",
        "avg_confidence",
    ):
        summary[column] = float(row[column] or 0)
    for column in ("total_input_tokens", "total_output_tokens"):
        summary[column] = int(row[column] or 0)
    for column in (
        "avg_companies_per_doc",
        "avg_validation_errors",
        "avg_validation_warnings",
    ):
        summary[column] = float(row[column] or 0)
    summary["hours"] = hours
    return summary
+149
View File
@@ -0,0 +1,149 @@
"""Extraction prompt templates with anti-hallucination instructions.
Builds structured prompts for Ollama document intelligence extraction.
Each prompt includes the target JSON schema, anti-hallucination rules,
and document-type-specific guidance.
Requirements: 5.1, 5.2, 5.3, 5.4, 5.5
"""
from __future__ import annotations
import json
from typing import Any
from services.extractor.schemas import generate_json_schema, SCHEMA_VERSION
from services.shared.schemas import (
DocumentType,
)
# Identifier of this prompt revision; reported by get_prompt_metadata().
PROMPT_VERSION = "document-intel-v1"
# --- JSON schema for structured output (generated from Pydantic models) ---
# Built once at import time and shared by every prompt built in this module.
EXTRACTION_JSON_SCHEMA: dict[str, Any] = generate_json_schema()
# --- Anti-hallucination system prompt ---
# Returned unchanged as the "system" entry by build_extraction_prompt().
# A backslash at the end of a line inside the literal is a line continuation:
# that line is joined to the next one without a newline.
SYSTEM_PROMPT = """\
You are a financial document analysis system. You extract structured intelligence \
from financial documents into JSON.
STRICT RULES — VIOLATIONS WILL INVALIDATE YOUR OUTPUT:
1. ONLY extract information explicitly stated in the document text provided.
2. NEVER fabricate facts, quotes, numbers, dates, or company names.
3. NEVER infer information that is not directly supported by the text.
4. If the document does not mention a company, do NOT include that company.
5. If the document is ambiguous about sentiment or impact, use "neutral" or "mixed" \
and set confidence lower.
6. evidence_spans MUST be short verbatim quotes copied from the document. \
Do NOT paraphrase or invent quotes.
7. key_facts MUST be directly stated in the document. Do NOT add external knowledge.
8. If you are uncertain about any field, lower the confidence score and add a warning \
to extraction_warnings.
9. If the document text is too short, garbled, or uninformative, return an empty \
companies array, set confidence below 0.3, and add "insufficient_content" to warnings.
10. Return ONLY valid JSON matching the provided schema. No commentary, no markdown fences."""
# --- Document-type-specific guidance ---
# One guidance paragraph per DocumentType member; inserted into the user prompt.
# NOTE(review): lookups are done with a plain ``str`` document_type, which
# presumably works because DocumentType members hash/compare like their string
# values — confirm against services.shared.schemas.
_DOCTYPE_GUIDANCE: dict[str, str] = {
    DocumentType.ARTICLE: (
        "This is a news article. Focus on reported facts, quoted sources, and stated "
        "analyst opinions. Distinguish between the journalist's framing and actual "
        "company developments. Do not treat speculative language as confirmed fact."
    ),
    DocumentType.FILING: (
        "This is a regulatory filing (e.g. SEC 10-K, 10-Q, 8-K). Extract concrete "
        "financial figures, risk factors, and material events as stated. Filings use "
        "precise legal language — preserve that precision in your extraction."
    ),
    DocumentType.TRANSCRIPT: (
        "This is an earnings call or event transcript. Distinguish between management "
        "forward-looking statements and reported results. Flag forward-looking language "
        "as lower confidence. Extract specific guidance numbers when stated."
    ),
    DocumentType.PRESS_RELEASE: (
        "This is a company press release. Be aware that press releases are promotional. "
        "Extract stated facts and figures but note that sentiment may be biased positive. "
        "Look for concrete metrics rather than marketing language."
    ),
}
def _get_doctype_guidance(document_type: str) -> str:
    """Look up the guidance paragraph for ``document_type``.

    Unknown document types fall back to the news-article guidance.
    """
    article_guidance = _DOCTYPE_GUIDANCE[DocumentType.ARTICLE]
    return _DOCTYPE_GUIDANCE.get(document_type, article_guidance)
# --- Prompt builder ---
def build_extraction_prompt(
    document_text: str,
    document_type: str = DocumentType.ARTICLE,
    known_tickers: list[str] | None = None,
    document_id: str = "",
) -> dict[str, str]:
    """Build system and user prompts for Ollama structured extraction.

    The user prompt embeds, in order: an optional document-ID line, the
    document type, type-specific guidance, an optional ticker hint, the JSON
    schema (pretty-printed), a reminder of the grounding rules, and finally
    the document text between explicit start/end markers.

    Args:
        document_text: Normalized text content of the document.
        document_type: One of the DocumentType enum values.
        known_tickers: Optional list of tickers the document may reference.
            Helps the model focus but does NOT mean all tickers are relevant.
        document_id: Optional document ID for traceability.

    Returns:
        Dict with 'system' and 'user' prompt strings.
    """
    # Unknown document types fall back to the article guidance.
    doctype_guidance = _get_doctype_guidance(document_type)
    # The hint stays empty when no tickers are supplied (None or empty list).
    ticker_hint = ""
    if known_tickers:
        tickers_str = ", ".join(known_tickers)
        ticker_hint = (
            f"\nThe following tickers may be referenced in this document: {tickers_str}\n"
            "Only include a ticker in your output if the document actually discusses that company. "
            "Do NOT include a ticker just because it appears in this hint."
        )
    # The schema is re-serialised on every call from the module-level dict.
    schema_str = json.dumps(EXTRACTION_JSON_SCHEMA, indent=2)
    # Omitted entirely (no blank line) when no document ID is given.
    doc_id_line = f"Document ID: {document_id}\n" if document_id else ""
    # NOTE(review): {document_type} is rendered with format(); if an enum
    # member is passed (including the default), newer Python versions may
    # render a mixed-in enum as "DocumentType.X" rather than its value —
    # confirm how DocumentType is declared.
    user_prompt = f"""\
Extract structured intelligence from the following document.
{doc_id_line}Document type: {document_type}
{doctype_guidance}
{ticker_hint}
Your output MUST be a single JSON object conforming to this schema:
{schema_str}
REMEMBER:
- Only extract what is explicitly in the text below.
- evidence_spans must be verbatim quotes from the text.
- If the text is insufficient, return empty companies and low confidence.
- Return ONLY the JSON object. No other text.
--- DOCUMENT TEXT ---
{document_text}
--- END DOCUMENT TEXT ---"""
    return {
        "system": SYSTEM_PROMPT,
        "user": user_prompt,
    }
def get_prompt_metadata() -> dict[str, str]:
    """Report the prompt and schema versions in use, for audit trails.

    Returns:
        A new dict with ``prompt_version`` and ``schema_version`` keys.
    """
    return dict(
        prompt_version=PROMPT_VERSION,
        schema_version=SCHEMA_VERSION,
    )
def get_json_schema() -> dict[str, Any]:
    """Return the extraction JSON schema for Ollama's structured-output ``format`` parameter.

    The module-level schema dict itself is returned, not a copy, so callers
    should treat it as read-only.
    """
    return EXTRACTION_JSON_SCHEMA
+250
View File
@@ -0,0 +1,250 @@
"""Replay dataset loader and runner for deterministic extraction testing.
Loads archived document fixtures from JSON files, validates their expected
extraction outputs against the current schema, and provides a runner that
can compare live Ollama extraction results against expected baselines.
This enables:
- Schema regression testing: verify expected outputs still pass validation
- Prompt regression testing: detect drift when prompts or schemas change
- End-to-end replay: run fixtures through a live Ollama and compare
Requirements: 5.1, 5.2, 5.3, 5.4, 5.5
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from services.extractor.schemas import (
ExtractionResult,
ValidationReport,
get_schema_version,
validate_extraction,
)
# Module-level logger; records are emitted under the "extractor_replay" name.
logger = logging.getLogger("extractor_replay")
# Default fixture location: three directory levels above this file, then
# tests/replay_fixtures. Resolved to an absolute path at import time.
FIXTURES_DIR = Path(__file__).resolve().parent.parent.parent / "tests" / "replay_fixtures"
@dataclass
class ReplayFixture:
    """One archived document plus its expected extraction, as loaded from disk."""

    # Identity and content of the archived document.
    document_id: str
    document_type: str
    document_text: str
    known_tickers: list[str]
    # Expected extraction payload, kept as the raw parsed-JSON dict.
    expected_extraction: dict[str, Any]
    metadata: dict[str, str]
    # Path of the file the fixture was read from ("" when not loaded from disk).
    source_path: str = ""

    @property
    def expected_result(self) -> ExtractionResult:
        """Validate ``expected_extraction`` and return it as an ExtractionResult.

        Validation runs on every access; nothing is cached.
        """
        baseline_payload = self.expected_extraction
        return ExtractionResult.model_validate(baseline_payload)
@dataclass
class ReplayValidationResult:
    """Result of validating a single fixture against the current schema."""

    # document_id of the fixture that was validated.
    fixture_id: str
    # True only when validation ran and the report was valid.
    schema_valid: bool = False
    # Full report from validate_extraction(); None if validation raised.
    validation_report: ValidationReport | None = None
    # Schema version in effect when the check ran.
    schema_version: str = ""
    # String form of an exception raised during validation, if any.
    error: str | None = None
@dataclass
class ReplayComparisonResult:
    """Result of comparing a live extraction against the expected baseline."""

    # document_id of the fixture being compared.
    fixture_id: str
    # Sorted ticker lists; the match flag compares them as sets.
    expected_companies: list[str] = field(default_factory=list)
    actual_companies: list[str] = field(default_factory=list)
    companies_match: bool = False
    # ticker -> sentiment, for expected and actual results.
    expected_sentiment_map: dict[str, str] = field(default_factory=dict)
    actual_sentiment_map: dict[str, str] = field(default_factory=dict)
    sentiment_match: bool = False
    # ticker -> catalyst type, for expected and actual results.
    expected_catalyst_map: dict[str, str] = field(default_factory=dict)
    actual_catalyst_map: dict[str, str] = field(default_factory=dict)
    catalyst_match: bool = False
    # Whether the live result itself passes validate_extraction().
    actual_schema_valid: bool = False
    # Validation warnings for the live result, plus any company-mismatch note.
    warnings: list[str] = field(default_factory=list)
def load_fixture(path: Path) -> ReplayFixture:
    """Load a single replay fixture from a JSON file.

    Args:
        path: Path to the fixture JSON file.

    Returns:
        A ReplayFixture with all fields populated. ``known_tickers`` and
        ``metadata`` default to empty containers when absent from the file.

    Raises:
        ValueError: If the top-level JSON value is not an object, or the
            fixture is missing required fields.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Read as UTF-8 explicitly so decoding does not depend on the platform
    # locale's default encoding.
    with open(path, encoding="utf-8") as f:
        data = json.load(f)

    # A JSON array/scalar has no keys; report it as an invalid fixture rather
    # than failing later with an AttributeError.
    if not isinstance(data, dict):
        raise ValueError(
            f"Fixture {path.name} must contain a JSON object at top level, "
            f"got {type(data).__name__}"
        )

    required = {"document_id", "document_type", "document_text", "expected_extraction"}
    missing = required - set(data.keys())
    if missing:
        raise ValueError(f"Fixture {path.name} missing required fields: {missing}")

    return ReplayFixture(
        document_id=data["document_id"],
        document_type=data["document_type"],
        document_text=data["document_text"],
        known_tickers=data.get("known_tickers", []),
        expected_extraction=data["expected_extraction"],
        metadata=data.get("metadata", {}),
        source_path=str(path),
    )
def load_all_fixtures(fixtures_dir: Path | None = None) -> list[ReplayFixture]:
    """Load all replay fixtures from the fixtures directory.

    Files that cannot be read or parsed are skipped with a warning so that
    one bad file does not prevent the rest from loading.

    Args:
        fixtures_dir: Override path to fixtures directory.
            Defaults to tests/replay_fixtures/.

    Returns:
        List of loaded ReplayFixture objects in sorted file-path order.
        Empty if the directory does not exist.
    """
    directory = fixtures_dir or FIXTURES_DIR
    if not directory.is_dir():
        logger.warning("Fixtures directory not found: %s", directory)
        return []

    fixtures: list[ReplayFixture] = []
    for path in sorted(directory.glob("*.json")):
        try:
            fixtures.append(load_fixture(path))
        # ValueError also covers json.JSONDecodeError and UnicodeDecodeError;
        # OSError covers files that vanish or are unreadable.
        except (OSError, ValueError) as exc:
            logger.warning("Skipping invalid fixture %s: %s", path.name, exc)

    logger.info("Loaded %d replay fixtures from %s", len(fixtures), directory)
    return fixtures
def validate_fixture(fixture: ReplayFixture) -> ReplayValidationResult:
    """Check that a fixture's expected extraction still validates.

    This is the deterministic regression check: the stored expected output
    is run through the current schema and semantic validation. A failure
    means either the fixture is stale or the schema has regressed.

    Args:
        fixture: The replay fixture to validate.

    Returns:
        A ReplayValidationResult; ``error`` is set if validation raised.
    """
    outcome = ReplayValidationResult(
        fixture_id=fixture.document_id,
        schema_version=get_schema_version(),
    )
    try:
        report = validate_extraction(
            fixture.expected_extraction,
            document_text=fixture.document_text,
        )
        outcome.validation_report = report
        outcome.schema_valid = report.valid
    except Exception as failure:  # noqa: BLE001
        # Any exception is recorded on the result instead of propagating.
        outcome.error = str(failure)
        outcome.schema_valid = False
    return outcome
def validate_all_fixtures(
    fixtures_dir: Path | None = None,
) -> list[ReplayValidationResult]:
    """Load every fixture and validate each against the current schema.

    Args:
        fixtures_dir: Override path to fixtures directory.

    Returns:
        One validation result per loaded fixture, in load order.
    """
    outcomes: list[ReplayValidationResult] = []
    for fixture in load_all_fixtures(fixtures_dir):
        outcomes.append(validate_fixture(fixture))
    return outcomes
def compare_extraction(
    fixture: ReplayFixture,
    actual_result: ExtractionResult,
) -> ReplayComparisonResult:
    """Compare a live extraction with the fixture's expected output.

    The comparison is structural (same tickers, same sentiment per ticker,
    same catalyst type per ticker), not an exact text match, since LLM
    wording varies between runs.

    Args:
        fixture: The replay fixture with expected output.
        actual_result: The ExtractionResult from a live extraction.

    Returns:
        A ReplayComparisonResult with match details.
    """
    baseline = fixture.expected_result

    def _by_ticker(companies, attribute):
        # Later entries win if a ticker appears more than once.
        return {company.ticker: getattr(company, attribute) for company in companies}

    # Tickers are stored sorted but compared as sets.
    expected_tickers = sorted(c.ticker for c in baseline.companies)
    actual_tickers = sorted(c.ticker for c in actual_result.companies)

    expected_sentiments = _by_ticker(baseline.companies, "sentiment")
    actual_sentiments = _by_ticker(actual_result.companies, "sentiment")

    expected_catalysts = _by_ticker(baseline.companies, "catalyst_type")
    actual_catalysts = _by_ticker(actual_result.companies, "catalyst_type")

    # The live result must also validate on its own.
    live_report = validate_extraction(
        actual_result.model_dump(mode="json"),
        document_text=fixture.document_text,
    )

    outcome = ReplayComparisonResult(
        fixture_id=fixture.document_id,
        expected_companies=expected_tickers,
        actual_companies=actual_tickers,
        companies_match=set(expected_tickers) == set(actual_tickers),
        expected_sentiment_map=expected_sentiments,
        actual_sentiment_map=actual_sentiments,
        sentiment_match=expected_sentiments == actual_sentiments,
        expected_catalyst_map=expected_catalysts,
        actual_catalyst_map=actual_catalysts,
        catalyst_match=expected_catalysts == actual_catalysts,
        actual_schema_valid=live_report.valid,
    )
    if live_report.warnings:
        outcome.warnings = live_report.warnings
    if not outcome.companies_match:
        outcome.warnings.append(
            f"company_mismatch: expected={outcome.expected_companies} actual={outcome.actual_companies}"
        )
    return outcome
+316
View File
@@ -0,0 +1,316 @@
"""JSON schema definitions for document intelligence extraction.
Generates Ollama-compatible JSON schemas from Pydantic models so the
extraction contract stays in sync with the shared data models. Also
provides schema validation and semantic validation helpers.
Requirements: 5.1, 5.2, 5.3, 5.4, 5.5
"""
from __future__ import annotations
import json
import re
from typing import Any
from pydantic import BaseModel, Field
from services.shared.schemas import (
CatalystType,
Sentiment,
)
# Version string of the extraction output contract; returned by get_schema_version().
SCHEMA_VERSION = "2.0.0"
# ---------------------------------------------------------------------------
# Pydantic model that mirrors the Ollama extraction output contract.
# This is the *response* shape we ask the model to produce — it intentionally
# omits server-side fields like document_id, source_credibility, and model
# metadata that are attached after extraction.
# ---------------------------------------------------------------------------
class CompanyExtractionItem(BaseModel):
    """Per-company extraction output expected from the model.

    All fields are required (no defaults) so the generated JSON schema
    forces the model to produce every field explicitly.
    """

    # The Field descriptions below are emitted into the generated JSON schema.
    ticker: str = Field(description="Stock ticker symbol mentioned in the document.")
    company_name: str = Field(description="Full company name as referenced in the document.")
    # Bounded to [0, 1] by the ge/le constraints.
    relevance: float = Field(
        ge=0,
        le=1,
        description="How relevant the document is to this company. 0=tangential, 1=primary subject.",
    )
    sentiment: Sentiment = Field(description="Overall sentiment toward this company in the document.")
    # Bounded to [0, 1] by the ge/le constraints.
    impact_score: float = Field(
        ge=0,
        le=1,
        description="Estimated magnitude of impact. 0=negligible, 1=highly material.",
    )
    # Typed as a plain string here; the allowed values are enforced later by
    # the semantic checks against VALID_IMPACT_HORIZONS.
    impact_horizon: str = Field(
        description="One of: intraday, 1d, 1d_7d, 1d_30d, 30d_90d, 90d_plus",
    )
    catalyst_type: CatalystType = Field(description="Primary catalyst category.")
    key_facts: list[str] = Field(
        description="Facts explicitly stated in the document. Do NOT infer or fabricate.",
    )
    risks: list[str] = Field(
        description="Risks explicitly mentioned in the document.",
    )
    evidence_spans: list[str] = Field(
        description="Short verbatim quotes from the document supporting the analysis.",
    )
class ExtractionResult(BaseModel):
    """Top-level structured output the model must return.

    All fields are required (no defaults) so the generated JSON schema
    forces the model to produce every field explicitly.
    """

    # The Field descriptions below are emitted into the generated JSON schema.
    summary: str = Field(
        description="A concise 1-3 sentence summary of the document's main point.",
    )
    # May be an empty list; each entry is validated as a CompanyExtractionItem.
    companies: list[CompanyExtractionItem] = Field(
        description="Per-company intelligence extracted from the document.",
    )
    macro_themes: list[str] = Field(
        description="Broad economic or market themes mentioned (e.g. rates, inflation, ai_capex).",
    )
    # Bounded to [0, 1] by the ge/le constraints.
    novelty_score: float = Field(
        ge=0,
        le=1,
        description="How novel or surprising the information is. 0=routine, 1=highly novel.",
    )
    # Bounded to [0, 1]; the semantic checks warn when this is below 0.3
    # while companies are present.
    confidence: float = Field(
        ge=0,
        le=1,
        description="Model confidence in the accuracy of this extraction. Lower if text is ambiguous.",
    )
    extraction_warnings: list[str] = Field(
        description="Any issues encountered: ambiguous_ticker, incomplete_text, low_confidence, etc.",
    )
# ---------------------------------------------------------------------------
# Schema generation
# ---------------------------------------------------------------------------
def generate_json_schema() -> dict[str, Any]:
    """Build the extraction JSON schema from ``ExtractionResult``.

    The Pydantic-generated schema has its ``$defs`` inlined so the returned
    dict is self-contained and usable as Ollama's ``format`` parameter.
    """
    # Flatten $defs/$ref so the schema needs no internal references.
    return _inline_defs(ExtractionResult.model_json_schema())
def get_schema_version() -> str:
    """Return the version string of the extraction schema (``SCHEMA_VERSION``)."""
    return SCHEMA_VERSION
# ---------------------------------------------------------------------------
# Validation helpers
# ---------------------------------------------------------------------------
class ValidationReport(BaseModel):
    """Result of validating a raw model response."""

    # True only when no errors (structural or semantic) were found.
    valid: bool = False
    # Problems severe enough to make the report invalid.
    errors: list[str] = Field(default_factory=list)
    # Informational findings; these do not affect ``valid``.
    warnings: list[str] = Field(default_factory=list)
    # The parsed result; left as None when JSON or schema parsing failed.
    parsed: ExtractionResult | None = None
def validate_extraction(
    raw_json: str | dict[str, Any],
    *,
    document_text: str = "",
) -> ValidationReport:
    """Validate raw model output against the extraction schema.

    Runs in three stages: JSON decoding (for string input), Pydantic
    structural validation, then semantic checks. The first two stages return
    early with a single error; the last collects errors and warnings.

    Args:
        raw_json: Either a JSON string or an already-parsed dict.
        document_text: Optional original document text used for evidence
            span verification.

    Returns:
        A ``ValidationReport``; ``parsed`` is set once structural validation
        passes, even if semantic errors make the report invalid.
    """
    # --- Stage 1: decode JSON text if needed ---
    payload = raw_json
    if isinstance(payload, str):
        try:
            payload = json.loads(payload)
        except json.JSONDecodeError as exc:
            return ValidationReport(valid=False, errors=[f"Invalid JSON: {exc}"])
    if not isinstance(payload, dict):
        return ValidationReport(valid=False, errors=["Expected a JSON object at top level."])

    # --- Stage 2: structural validation ---
    try:
        parsed_result = ExtractionResult.model_validate(payload)
    except Exception as exc:  # noqa: BLE001
        return ValidationReport(valid=False, errors=[f"Schema validation failed: {exc}"])

    # --- Stage 3: semantic checks ---
    sem_errors, sem_warnings = _semantic_checks(parsed_result, document_text)

    # Semantic errors make the report invalid — the caller should retry.
    return ValidationReport(
        valid=not sem_errors,
        errors=list(sem_errors),
        warnings=list(sem_warnings),
        parsed=parsed_result,
    )
# ---------------------------------------------------------------------------
# Known valid impact horizons
# ---------------------------------------------------------------------------
# Allowed values for a company's impact_horizon; any other value is reported
# as a semantic error.
VALID_IMPACT_HORIZONS = frozenset({
    "intraday",
    "1d",
    "1d_7d",
    "1d_30d",
    "30d_90d",
    "90d_plus",
})
# Ticker: 1-5 uppercase letters (covers NYSE, NASDAQ, etc.)
# Symbols containing digits or punctuation (e.g. "BRK.B") do not match.
# NOTE: "$" also matches just before a trailing newline, so a strict check
# needs fullmatch() rather than match().
_TICKER_RE = re.compile(r"^[A-Z]{1,5}$")
# Evidence span length bounds (characters)
# Spans outside these bounds produce warnings, not errors.
_MIN_EVIDENCE_LEN = 8
_MAX_EVIDENCE_LEN = 500
# ---------------------------------------------------------------------------
# Semantic validation rules
# ---------------------------------------------------------------------------
def _semantic_checks(
    result: ExtractionResult,
    document_text: str = "",
) -> tuple[list[str], list[str]]:
    """Run semantic checks on a parsed extraction.

    Args:
        result: A structurally valid extraction.
        document_text: Original document text. When non-empty, every evidence
            span is checked (case-insensitively) for presence in it.

    Returns:
        A tuple of (errors, warnings). Errors are issues severe enough
        to warrant a retry; warnings are informational.
    """
    errors: list[str] = []
    warnings: list[str] = []

    # --- Top-level checks ---
    if not result.summary:
        warnings.append("empty_summary")
    if result.confidence < 0.3 and len(result.companies) > 0:
        warnings.append("low_confidence_with_companies")

    # Duplicate tickers across company entries. A set keeps the membership
    # test O(1); one error is emitted per repeated occurrence.
    tickers_seen: set[str] = set()
    for comp in result.companies:
        if comp.ticker in tickers_seen:
            errors.append(f"duplicate_ticker_{comp.ticker}")
        tickers_seen.add(comp.ticker)

    # --- Per-company checks ---
    for comp in result.companies:
        tag = comp.ticker or "unknown"

        # Ticker format. fullmatch() is used because "$" in the pattern also
        # matches before a trailing newline, which match() would let through.
        if not comp.ticker:
            errors.append("company_missing_ticker")
        elif not _TICKER_RE.fullmatch(comp.ticker):
            warnings.append(f"invalid_ticker_format_{tag}")

        # Impact horizon must be a known value
        if comp.impact_horizon not in VALID_IMPACT_HORIZONS:
            errors.append(f"invalid_impact_horizon_{comp.impact_horizon}_for_{tag}")

        # Evidence spans
        if not comp.evidence_spans:
            warnings.append(f"no_evidence_spans_for_{tag}")
        else:
            for idx, span in enumerate(comp.evidence_spans):
                if len(span) < _MIN_EVIDENCE_LEN:
                    warnings.append(f"evidence_span_too_short_for_{tag}_{idx}")
                if len(span) > _MAX_EVIDENCE_LEN:
                    warnings.append(f"evidence_span_too_long_for_{tag}_{idx}")

        # Cross-field: high impact but no facts
        if not comp.key_facts and comp.impact_score > 0.5:
            warnings.append(f"high_impact_no_facts_for_{tag}")

        # Cross-field: very low relevance
        if comp.relevance < 0.2:
            warnings.append(f"very_low_relevance_for_{tag}")

        # Cross-field: strong sentiment but low impact
        if comp.sentiment in (Sentiment.POSITIVE, Sentiment.NEGATIVE) and comp.impact_score < 0.1:
            warnings.append(f"strong_sentiment_low_impact_for_{tag}")

    # --- Evidence grounding check (when source text is available) ---
    if document_text:
        # Lower-case the document once, outside the loops.
        doc_lower = document_text.lower()
        for comp in result.companies:
            for idx, span in enumerate(comp.evidence_spans):
                if span.lower() not in doc_lower:
                    warnings.append(
                        f"evidence_span_not_found_in_document_for_{comp.ticker or 'unknown'}_{idx}"
                    )

    return errors, warnings
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _inline_defs(schema: dict[str, Any]) -> dict[str, Any]:
    """Recursively inline ``$defs`` / ``$ref`` so the schema is self-contained.

    The input dict is not modified; a new schema without the top-level
    ``$defs`` key is returned.
    """
    defs = schema.get("$defs", {})
    # Exclude "$defs" from a copy rather than popping it from the caller's dict.
    body = {key: value for key, value in schema.items() if key != "$defs"}
    return _resolve_refs(body, defs)
def _resolve_refs(node: Any, defs: dict[str, Any]) -> Any:
"""Walk the schema tree and replace ``$ref`` pointers with their definitions."""
if isinstance(node, dict):
if "$ref" in node:
ref_path = node["$ref"] # e.g. "#/$defs/CompanyExtractionItem"
ref_name = ref_path.rsplit("/", 1)[-1]
if ref_name in defs:
resolved = defs[ref_name].copy()
# The resolved def may itself contain refs
return _resolve_refs(resolved, defs)
return node # unresolvable ref, leave as-is
return {k: _resolve_refs(v, defs) for k, v in node.items()}
if isinstance(node, list):
return [_resolve_refs(item, defs) for item in node]
return node
+291 -1
View File
@@ -1 +1,291 @@
"""Extraction worker - sends documents to Ollama for structured intelligence extraction."""
"""Extraction worker - sends documents to Ollama for structured intelligence extraction.
Orchestrates the full extraction pipeline for a single document:
1. Calls OllamaClient to get structured extraction
2. Uploads prompts, raw outputs, and validation reports to MinIO
3. Persists the final intelligence object and per-company impact records to PostgreSQL
4. Updates document status
Requirements: 5.1, 5.2, 5.3, 5.4, 5.5, 9.1, 9.2
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from datetime import datetime, timezone
import asyncpg
from minio import Minio
from services.extractor.client import ExtractionResponse
from services.extractor.metrics import collect_metrics, persist_metrics
from services.shared.metadata import (
persist_document_impact,
persist_document_intelligence,
update_document_status,
)
from services.shared.storage import (
upload_extraction_intelligence,
upload_extraction_prompt,
upload_extraction_raw_output,
upload_extraction_validation,
)
from services.shared.logging import Span
from services.shared.metrics import (
EXTRACTION_ATTEMPTS,
EXTRACTION_CONFIDENCE,
EXTRACTION_DURATION,
EXTRACTION_JOBS_TOTAL,
EXTRACTION_RETRIES,
EXTRACTION_TOKEN_ESTIMATE,
EXTRACTION_VALIDATION_ERRORS,
)
# Module-level logger; records are emitted under the "extractor_worker" name.
logger = logging.getLogger("extractor_worker")
@dataclass
class ExtractionPersistResult:
    """Result of persisting an extraction to storage and database.

    Every field defaults to "not persisted" (None / False) and is filled in
    as the corresponding step completes.
    """

    # Identifier of the persisted intelligence record, if any.
    intelligence_id: str | None = None
    # Storage references returned by the upload helpers. prompt_ref is
    # assigned from upload_extraction_prompt(); the others are presumably set
    # by the matching upload_extraction_* helpers — confirm in persist_extraction.
    prompt_ref: str | None = None
    raw_output_ref: str | None = None
    validation_ref: str | None = None
    intelligence_ref: str | None = None
    # Identifiers of per-company impact records; None until populated.
    impact_ids: list[str] | None = None
    # Identifier of the persisted performance-metrics row, if any.
    metrics_id: str | None = None
    success: bool = False
async def persist_extraction(
    *,
    pool: asyncpg.Pool,
    minio_client: Minio,
    document_id: str,
    ticker: str,
    extraction_response: ExtractionResponse,
    company_id_map: dict[str, str] | None = None,
    source_credibility: float = 0.5,
    timestamp: datetime | None = None,
    document_text_length: int = 0,
) -> ExtractionPersistResult:
    """Persist all extraction artifacts to MinIO and PostgreSQL.

    Uploads prompts, raw model outputs, validation reports, and the final
    intelligence object to MinIO. Persists the intelligence record and
    per-company impact records to PostgreSQL. Updates document status.
    Also collects and persists model performance metrics and updates the
    Prometheus counters.

    An extraction counts as successful only when the response is flagged
    ``success`` AND carries a ``result``; the same predicate drives both the
    database status and the Prometheus job counter so they cannot disagree.

    Args:
        pool: PostgreSQL connection pool.
        minio_client: MinIO client.
        document_id: UUID of the source document.
        ticker: Primary ticker for path construction.
        extraction_response: Full response from OllamaClient.extract().
        company_id_map: Optional mapping of ticker -> company UUID for impact records.
        source_credibility: Credibility score to attach to the intelligence record.
        timestamp: Override timestamp for MinIO paths (defaults to UTC now).
        document_text_length: Length of the input document text for token estimation.

    Returns:
        ExtractionPersistResult with references to all persisted artifacts.
    """
    ts = timestamp or datetime.now(timezone.utc)
    result = ExtractionPersistResult()
    company_id_map = company_id_map or {}

    attempts = extraction_response.attempts
    attempt_count = len(attempts)
    final_attempt = attempts[-1] if attempts else None
    final_validation = final_attempt.validation if final_attempt else None
    # Retries are the attempts beyond the first one, regardless of outcome.
    retry_count = max(0, attempt_count - 1)
    # Single source of truth for "did this extraction succeed".
    succeeded = bool(extraction_response.success and extraction_response.result)

    # 1. Upload prompt metadata to MinIO
    prompt_payload = json.dumps({
        "prompt_metadata": extraction_response.prompt_metadata,
        "model": extraction_response.model,
    }, indent=2).encode()
    result.prompt_ref = upload_extraction_prompt(
        minio_client, ticker, document_id, prompt_payload, timestamp=ts,
    )

    # 2. Upload raw outputs for each attempt
    attempts_data: list[dict[str, object]] = []
    for idx, attempt in enumerate(attempts):
        attempt_record: dict[str, object] = {
            "attempt_index": idx,
            "raw_output": attempt.raw_output,
            "error": attempt.error,
            "duration_ms": attempt.duration_ms,
            "model": attempt.model,
            "retryable": attempt.retryable,
        }
        if attempt.validation:
            attempt_record["validation"] = {
                "valid": attempt.validation.valid,
                "errors": attempt.validation.errors,
                "warnings": attempt.validation.warnings,
            }
        attempts_data.append(attempt_record)

    raw_output_payload = json.dumps({
        "document_id": document_id,
        "attempts": attempts_data,
        "total_duration_ms": extraction_response.total_duration_ms,
        "success": extraction_response.success,
    }, indent=2).encode()
    result.raw_output_ref = upload_extraction_raw_output(
        minio_client, ticker, document_id, raw_output_payload, timestamp=ts,
    )

    # 3. Upload validation report (final_validation is null when no attempt ran)
    validation_payload = json.dumps({
        "document_id": document_id,
        "success": extraction_response.success,
        "attempt_count": attempt_count,
        "final_validation": {
            "valid": final_validation.valid if final_validation else False,
            "errors": final_validation.errors if final_validation else [],
            "warnings": final_validation.warnings if final_validation else [],
        } if final_attempt else None,
    }, indent=2).encode()
    result.validation_ref = upload_extraction_validation(
        minio_client, ticker, document_id, validation_payload, timestamp=ts,
    )

    # 4. Determine validation status and persist intelligence
    if succeeded:
        extraction = extraction_response.result
        validation_status = "valid"
        validation_errors: list[str] = []

        # Upload final intelligence object to MinIO
        intelligence_payload = json.dumps(
            extraction.model_dump(mode="json"), indent=2,
        ).encode()
        result.intelligence_ref = upload_extraction_intelligence(
            minio_client, ticker, document_id, intelligence_payload, timestamp=ts,
        )

        # Persist to PostgreSQL
        intel_id = await persist_document_intelligence(
            pool,
            document_id=document_id,
            summary=extraction.summary,
            macro_themes=extraction.macro_themes,
            novelty_score=extraction.novelty_score,
            source_credibility=source_credibility,
            extraction_warnings=extraction.extraction_warnings,
            confidence=extraction.confidence,
            model_provider="ollama",
            model_name=extraction_response.model,
            prompt_version=extraction_response.prompt_metadata.get("prompt_version", ""),
            schema_version=extraction_response.prompt_metadata.get("schema_version", ""),
            raw_output_ref=result.raw_output_ref,
            prompt_ref=result.prompt_ref,
            validation_status=validation_status,
            validation_errors=validation_errors,
            retry_count=retry_count,
        )
        result.intelligence_id = intel_id

        # Persist per-company impact records
        result.impact_ids = []
        for company in extraction.companies:
            cid = company_id_map.get(company.ticker)
            if not cid:
                logger.warning(
                    "No company_id for ticker %s in doc %s, skipping impact record",
                    company.ticker, document_id,
                )
                continue
            impact_id = await persist_document_impact(
                pool,
                intelligence_id=intel_id,
                company_id=cid,
                ticker=company.ticker,
                relevance=company.relevance,
                sentiment=company.sentiment,
                impact_score=company.impact_score,
                impact_horizon=company.impact_horizon,
                catalyst_type=company.catalyst_type,
                key_facts=company.key_facts,
                risks=company.risks,
                evidence_spans=company.evidence_spans,
            )
            result.impact_ids.append(impact_id)

        await update_document_status(pool, document_id=document_id, status="extracted")
        result.success = True
        logger.info(
            "Extraction persisted for doc %s: intel=%s, impacts=%d",
            document_id, intel_id, len(result.impact_ids),
        )
    else:
        # Failed extraction — still persist the attempt data
        all_errors: list[str] = [attempt.error for attempt in attempts if attempt.error]
        intel_id = await persist_document_intelligence(
            pool,
            document_id=document_id,
            summary="",
            macro_themes=[],
            novelty_score=0.0,
            source_credibility=source_credibility,
            extraction_warnings=["extraction_failed"],
            confidence=0.0,
            model_provider="ollama",
            model_name=extraction_response.model,
            prompt_version=extraction_response.prompt_metadata.get("prompt_version", ""),
            schema_version=extraction_response.prompt_metadata.get("schema_version", ""),
            raw_output_ref=result.raw_output_ref,
            prompt_ref=result.prompt_ref,
            validation_status="failed",
            validation_errors=all_errors,
            # Same definition as the success path: attempts beyond the first.
            retry_count=retry_count,
        )
        result.intelligence_id = intel_id
        await update_document_status(pool, document_id=document_id, status="extraction_failed")
        logger.warning(
            "Extraction failed for doc %s after %d attempts: %s",
            document_id, attempt_count, "; ".join(all_errors),
        )

    # Collect and persist model performance metrics. This is best-effort:
    # a metrics failure must not undo an already persisted extraction.
    try:
        metrics = collect_metrics(
            extraction_response,
            document_id=document_id,
            ticker=ticker,
            document_text_length=document_text_length,
        )
        metrics.recorded_at = ts
        metrics_id = await persist_metrics(pool, metrics)
        result.metrics_id = metrics_id
    except Exception:
        logger.exception("Failed to persist extraction metrics for doc %s", document_id)

    # Prometheus metrics
    EXTRACTION_ATTEMPTS.inc(attempt_count)
    EXTRACTION_DURATION.observe(extraction_response.total_duration_ms / 1000.0)
    if retry_count > 0:
        EXTRACTION_RETRIES.inc(retry_count)
    if succeeded:
        EXTRACTION_JOBS_TOTAL.labels(status="success").inc()
        EXTRACTION_CONFIDENCE.observe(extraction_response.result.confidence)
    else:
        EXTRACTION_JOBS_TOTAL.labels(status="failed").inc()

    # Count validation errors from final attempt
    if final_validation and final_validation.errors:
        EXTRACTION_VALIDATION_ERRORS.inc(len(final_validation.errors))

    # Token estimates (rough heuristic: four characters per token)
    if document_text_length > 0:
        EXTRACTION_TOKEN_ESTIMATE.labels(direction="input").inc(document_text_length // 4)
    if final_attempt and final_attempt.raw_output:
        EXTRACTION_TOKEN_ESTIMATE.labels(direction="output").inc(len(final_attempt.raw_output) // 4)

    return result
+151 -80
View File
@@ -1,47 +1,50 @@
"""Ingestion worker - processes jobs from the ingestion queue."""
import asyncio
import hashlib
import io
import json
import logging
from datetime import datetime
import asyncpg
import redis.asyncio as aioredis
from minio import Minio
from services.adapters.base import AdapterResult
from services.adapters.filings_adapter import FilingsAdapter
from services.adapters.market_adapter import MarketDataAdapter
from services.adapters.news_adapter import NewsApiAdapter
from services.adapters.broker_adapter import AlpacaBrokerAdapter, TradingMode
from services.adapters.filings_adapter import SECEdgarAdapter
from services.adapters.market_adapter import PolygonMarketAdapter
from services.adapters.news_adapter import PolygonNewsAdapter
from services.adapters.web_scrape_adapter import WebScrapeAdapter
from services.shared.config import load_config
from services.shared.db import get_minio, get_pg_pool, get_redis
from services.shared.dedupe import dedupe_items, mark_as_seen
from services.shared.metadata import (
persist_ingestion_items,
record_retrieval_failure,
reset_source_retry_state,
)
from services.shared.redis_keys import (
QUEUE_INGESTION,
QUEUE_PARSING,
dedupe_key,
queue_key,
)
from services.shared.logging import Span, extract_trace_context, inject_trace_context, new_trace_id, set_trace_context, setup_logging
from services.shared.metrics import (
ACTIVE_JOBS,
INGESTION_ADAPTER_DURATION,
INGESTION_ERRORS,
INGESTION_ITEMS_DEDUPED,
INGESTION_ITEMS_FETCHED,
INGESTION_ITEMS_NEW,
INGESTION_JOBS_TOTAL,
)
from services.shared.storage import (
bucket_for_source,
ensure_buckets,
upload_raw_artifact,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ingestion_worker")
BUCKET_MAP = {
"market_api": "stonks-raw-market",
"news_api": "stonks-raw-news",
"filings_api": "stonks-raw-filings",
"broker": "stonks-raw-market",
}
def build_storage_path(source_type: str, ticker: str, doc_id: str) -> str:
    """Build the object key for a raw JSON artifact.

    The key is laid out as
    ``<source_type>/<ticker>/<YYYY>/<MM>/<DD>/<doc_id>/raw.json`` using the
    current UTC date.

    Args:
        source_type: Ingestion source type used as the top-level prefix.
        ticker: Ticker the artifact belongs to.
        doc_id: Identifier of the document / ingestion run.

    Returns:
        The object key, relative to the bucket root.
    """
    # Local import: the module only imports ``datetime`` at file level.
    from datetime import timezone

    # datetime.utcnow() is deprecated and returns a naive value; an aware UTC
    # "now" yields the same year/month/day components.
    now = datetime.now(timezone.utc)
    return f"{source_type}/{ticker}/{now.year}/{now.month:02d}/{now.day:02d}/{doc_id}/raw.json"
async def store_raw_artifact(minio_client: Minio, bucket: str, path: str, data: bytes):
    """Write ``data`` to ``bucket``/``path`` as an ``application/json`` object.

    The payload is wrapped in an in-memory stream and handed to the client's
    ``put_object`` together with its byte length.
    """
    payload_stream = io.BytesIO(data)
    payload_size = len(data)
    minio_client.put_object(
        bucket,
        path,
        payload_stream,
        payload_size,
        content_type="application/json",
    )
async def process_job(
job: dict,
@@ -55,9 +58,11 @@ async def process_job(
source_id = job["source_id"]
config = job.get("config", {})
set_trace_context(trace_id=job.get("_trace_id") or new_trace_id())
adapter = adapters.get(source_type)
if not adapter:
logger.warning(f"No adapter for source_type={source_type}")
logger.warning("No adapter for source_type=%s", source_type)
return
# Record ingestion run
@@ -68,25 +73,37 @@ async def process_job(
)
try:
result: AdapterResult = await adapter.fetch(ticker, config)
with Span("adapter_fetch", ticker=ticker, source_type=source_type):
with INGESTION_ADAPTER_DURATION.labels(source_type=source_type).time():
result: AdapterResult = await adapter.fetch(ticker, config)
if result.error:
await pool.execute(
"UPDATE ingestion_runs SET status='failed', error_message=$2, completed_at=NOW() WHERE id=$1",
run_id, result.error,
INGESTION_JOBS_TOTAL.labels(source_type=source_type, status="error").inc()
await record_retrieval_failure(
pool,
run_id=str(run_id),
source_id=source_id,
error_message=result.error,
)
return
# Store raw payload
bucket = BUCKET_MAP.get(source_type, "stonks-raw-market")
storage_path = build_storage_path(source_type, ticker, str(run_id))
await store_raw_artifact(minio_client, bucket, storage_path, result.raw_payload)
# Store raw payload in MinIO
bucket = bucket_for_source(source_type)
artifact_type = "raw_html" if source_type == "web_scrape" else "raw_json"
storage_uri = upload_raw_artifact(
minio_client,
source_type=source_type,
ticker=ticker,
document_id=str(run_id),
data=result.raw_payload,
artifact_type=artifact_type,
)
# Dedupe check
# Dedupe check on the overall payload hash
if result.content_hash:
already_seen = await rds.get(dedupe_key(result.content_hash))
if already_seen:
logger.info(f"Duplicate content for {ticker}, skipping")
logger.info("Duplicate content for %s, skipping", ticker)
await pool.execute(
"UPDATE ingestion_runs SET status='completed', items_fetched=$2, items_new=0, completed_at=NOW() WHERE id=$1",
run_id, len(result.items),
@@ -94,72 +111,126 @@ async def process_job(
return
await rds.set(dedupe_key(result.content_hash), "1", ex=86400)
new_items = 0
for item in result.items:
item_json = json.dumps(item)
item_hash = hashlib.sha256(item_json.encode()).hexdigest()
# Cross-source dedupe on individual document items (news, filings, web_scrape)
items_to_persist = result.items
deduped_count = 0
if source_type not in ("market_api", "broker"):
items_to_persist, dup_items = await dedupe_items(pool, rds, result.items)
deduped_count = len(dup_items)
if deduped_count:
INGESTION_ITEMS_DEDUPED.labels(source_type=source_type).inc(deduped_count)
logger.info(
"Deduped %d/%d items for %s/%s",
deduped_count, len(result.items), ticker, source_type,
)
# Check if document already exists
exists = await pool.fetchval("SELECT 1 FROM documents WHERE content_hash = $1", item_hash)
if exists:
continue
# Persist metadata via the unified metadata module
new_items, new_ids = await persist_ingestion_items(
pool,
source_type=source_type,
ticker=ticker,
company_id=job.get("company_id"),
items=items_to_persist,
storage_ref=storage_uri,
adapter_metadata=result.metadata,
content_hash=result.content_hash,
)
title = item.get("title", item.get("name", ""))
url = item.get("url", item.get("link", ""))
published = item.get("publishedAt", item.get("published_at"))
# Enqueue new document items for parsing (not market/broker)
if source_type not in ("market_api", "broker"):
for doc_id in new_ids:
await rds.rpush(queue_key(QUEUE_PARSING), json.dumps(inject_trace_context({
"document_id": doc_id,
"ticker": ticker,
"source_type": source_type,
})))
doc_id = await pool.fetchval(
"""INSERT INTO documents (document_type, source_type, publisher, url, title, published_at, content_hash, raw_storage_ref, status)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'ingested')
RETURNING id""",
"article" if source_type == "news_api" else "filing" if source_type == "filings_api" else "article",
source_type,
item.get("source", {}).get("name", "") if isinstance(item.get("source"), dict) else str(item.get("source", "")),
url, title,
datetime.fromisoformat(published.replace("Z", "+00:00")) if published else None,
item_hash,
f"s3://{bucket}/{storage_path}",
)
# Mark newly persisted documents in Redis for fast future dedupe
for item, doc_id in zip(items_to_persist, new_ids):
await mark_as_seen(
rds,
content_hash=item.get("content_hash", ""),
canonical_url=item.get("canonical_url"),
document_id=doc_id,
)
# Enqueue for parsing
await rds.rpush(queue_key(QUEUE_PARSING), json.dumps({
"document_id": str(doc_id),
"ticker": ticker,
"source_type": source_type,
"url": url,
}))
new_items += 1
# Link duplicate documents to this company if not already linked
company_id = job.get("company_id")
if company_id and deduped_count:
from services.shared.metadata import persist_document_company_mention
for dup in dup_items:
existing_id = dup.get("_dedupe_existing_id")
if existing_id:
try:
await persist_document_company_mention(
pool,
document_id=existing_id,
company_id=company_id,
ticker=ticker,
mention_type="cross_source",
)
except Exception:
# Duplicate mention link — safe to ignore
pass
await pool.execute(
"UPDATE ingestion_runs SET status='completed', items_fetched=$2, items_new=$3, completed_at=NOW() WHERE id=$1",
run_id, len(result.items), new_items,
)
logger.info(f"Ingested {ticker}/{source_type}: {len(result.items)} fetched, {new_items} new")
# Clear any accumulated retry backoff after success
await reset_source_retry_state(pool, source_id)
INGESTION_ITEMS_FETCHED.labels(source_type=source_type).inc(len(result.items))
INGESTION_ITEMS_NEW.labels(source_type=source_type).inc(new_items)
INGESTION_JOBS_TOTAL.labels(source_type=source_type, status="success").inc()
logger.info(
"Ingested %s/%s: %d fetched, %d new",
ticker, source_type, len(result.items), new_items,
extra={"ticker": ticker, "source_type": source_type, "count": new_items},
)
except Exception as e:
logger.error(f"Ingestion error for {ticker}: {e}")
await pool.execute(
"UPDATE ingestion_runs SET status='failed', error_message=$2, completed_at=NOW() WHERE id=$1",
run_id, str(e),
INGESTION_ERRORS.labels(source_type=source_type).inc()
INGESTION_JOBS_TOTAL.labels(source_type=source_type, status="error").inc()
logger.error(
"Ingestion error for %s: %s", ticker, e,
extra={"ticker": ticker, "source_type": source_type, "error": str(e)},
)
await record_retrieval_failure(
pool,
run_id=str(run_id),
source_id=source_id,
error_message=str(e),
)
async def main():
config = load_config()
pool = await get_pg_pool(config)
rds = get_redis(config)
minio_client = get_minio(config)
cfg = load_config()
setup_logging("ingestion_worker", level=cfg.log_level, json_output=cfg.json_logs)
pool = await get_pg_pool(cfg)
rds = get_redis(cfg)
minio_client = get_minio(cfg)
# Ensure all required buckets exist
ensure_buckets(minio_client)
adapters = {
"market_api": MarketDataAdapter(
api_key=config.broker.api_key or "",
"market_api": PolygonMarketAdapter(
api_key=cfg.market_data.api_key,
base_url=cfg.market_data.base_url,
),
"news_api": PolygonNewsAdapter(
api_key=cfg.market_data.api_key,
base_url="https://api.polygon.io",
),
"news_api": NewsApiAdapter(
api_key="",
base_url="https://newsapi.org",
"filings_api": SECEdgarAdapter(),
"web_scrape": WebScrapeAdapter(),
"broker": AlpacaBrokerAdapter(
api_key=cfg.broker.api_key or "",
api_secret=cfg.broker.api_secret or "",
mode=TradingMode.LIVE if cfg.broker.mode == "live" else TradingMode.PAPER,
base_url=cfg.broker.base_url,
),
"filings_api": FilingsAdapter(),
}
logger.info("Ingestion worker started")
+1 -1
View File
@@ -1 +1 @@
# Lake Publisher - transforms operational data into analytical fact datasets
"""Lake publisher — writes partitioned Parquet facts to MinIO for Trino/Superset."""
+39
View File
@@ -0,0 +1,39 @@
"""Helpers for enqueuing lake publish jobs from upstream workers.
Other services import these helpers to push jobs onto the QUEUE_LAKE_PUBLISH
Redis queue. The lake publisher worker (jobs.py) consumes them.
Usage:
await enqueue_lake_job(rds, "document", document_id)
await enqueue_lake_job(rds, "trade_order", order_id)
await enqueue_lake_job(rds, "bulk_documents", since=cutoff.isoformat())
"""
from __future__ import annotations
import json
import redis.asyncio as aioredis
from services.shared.redis_keys import QUEUE_LAKE_PUBLISH, queue_key
async def enqueue_lake_job(
    rds: aioredis.Redis,
    job_type: str,
    entity_id: str = "",
    since: str | None = None,
) -> None:
    """Append a lake publish job to the lake-publish Redis queue.

    The job is serialized as a JSON object with ``job_type`` and
    ``entity_id`` keys; a ``since`` key is added only when a non-empty
    cutoff is supplied.

    Args:
        rds: Async Redis client.
        job_type: One of the supported job types (document, document_extraction,
            market_snapshot, trade_order, trade_fill, positions_snapshot,
            pnl_snapshot, bulk_documents, bulk_extractions).
        entity_id: UUID or identifier for the entity to publish.
        since: ISO datetime string for bulk jobs (cutoff timestamp).
    """
    optional_fields: dict[str, str] = {"since": since} if since else {}
    message: dict[str, str] = {
        "job_type": job_type,
        "entity_id": entity_id,
        **optional_fields,
    }
    target_queue = queue_key(QUEUE_LAKE_PUBLISH)
    await rds.rpush(target_queue, json.dumps(message))  # type: ignore[misc]
+420
View File
@@ -0,0 +1,420 @@
"""Iceberg table creation and metadata management for analytical datasets.
Manages Iceberg tables in Trino's Iceberg catalog, providing:
- Table creation with proper schemas and partition specs
- Schema synchronization between PyArrow definitions and Iceberg tables
- Table metadata inspection (existence checks, schema retrieval, partition listing)
The Iceberg catalog complements the existing Hive-compatible partition layout.
Parquet files written by the lake publisher are stored in the same MinIO paths,
but Iceberg metadata enables schema evolution, snapshot isolation, and better
partition pruning via Trino's Iceberg connector.
Requirements: 9.4, 9.5, 10.1, N4, N6
Design ref: Section 5.3 (Lakehouse model), Section 4.12 (SQL Query Engine)
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Any
import pyarrow as pa
from trino.dbapi import connect as trino_connect
from services.lake_publisher.partitions import (
LAKEHOUSE_BUCKET,
TABLE_PARTITIONS,
WAREHOUSE_PREFIX,
PartitionSpec,
)
from services.lake_publisher.worker import (
COMPANY_EVENTS_SCHEMA,
DOCUMENTS_SCHEMA,
DOCUMENT_EXTRACTIONS_SCHEMA,
MARKET_BARS_SCHEMA,
MARKET_QUOTES_SCHEMA,
MODEL_PERFORMANCE_SCHEMA,
PNL_DAILY_SCHEMA,
POSITIONS_DAILY_SCHEMA,
PREDICTION_VS_OUTCOME_SCHEMA,
TRADE_FILLS_SCHEMA,
TRADE_ORDERS_SCHEMA,
TRADE_SIGNALS_SCHEMA,
)
logger = logging.getLogger(__name__)
# Default Trino catalog and schema names for the Iceberg tables in this module.
# They are the defaults for IcebergManager and the prefix of
# IcebergTableDef.qualified_name.
ICEBERG_CATALOG = "iceberg"
ICEBERG_SCHEMA = "stonks"
def _get_iceberg_catalog() -> str:
    """Resolve the Iceberg catalog name.

    Reads the ``TRINO_ICEBERG_CATALOG`` environment variable and falls back
    to the module default when it is not set.
    """
    from os import environ

    return environ.get("TRINO_ICEBERG_CATALOG", ICEBERG_CATALOG)
# Map PyArrow types to Trino/Iceberg SQL types.
# Keys are the string form of a PyArrow type (``str(arrow_type)``), which is
# how _arrow_type_to_trino looks them up. Timestamp types are not listed here
# because that function handles them before consulting this table.
_ARROW_TO_TRINO: dict[str, str] = {
    "string": "VARCHAR",
    "utf8": "VARCHAR",
    "large_string": "VARCHAR",
    "large_utf8": "VARCHAR",
    "float64": "DOUBLE",
    "double": "DOUBLE",
    "float32": "REAL",
    "float": "REAL",
    "int8": "TINYINT",
    "int16": "SMALLINT",
    "int32": "INTEGER",
    "int64": "BIGINT",
    "bool": "BOOLEAN",
    "date32": "DATE",
    "date32[day]": "DATE",
    "date64": "DATE",
}
def _arrow_type_to_trino(arrow_type: pa.DataType) -> str:
"""Convert a PyArrow data type to a Trino SQL type string."""
type_str = str(arrow_type)
# Handle timestamp types (with or without timezone)
if type_str.startswith("timestamp"):
if "tz=" in type_str:
return "TIMESTAMP(6) WITH TIME ZONE"
return "TIMESTAMP(6)"
# Direct lookup
result = _ARROW_TO_TRINO.get(type_str)
if result:
return result
# Fallback for type IDs
if pa.types.is_string(arrow_type) or pa.types.is_large_string(arrow_type):
return "VARCHAR"
if pa.types.is_floating(arrow_type):
return "DOUBLE"
if pa.types.is_integer(arrow_type):
return "BIGINT"
if pa.types.is_boolean(arrow_type):
return "BOOLEAN"
if pa.types.is_date(arrow_type):
return "DATE"
if pa.types.is_timestamp(arrow_type):
return "TIMESTAMP(6) WITH TIME ZONE"
raise ValueError(f"Unsupported PyArrow type for Iceberg DDL: {arrow_type}")
# Registry mapping table names to their PyArrow schemas.
# get_all_table_defs and get_table_def pair each entry with the partition spec
# registered under the same name in TABLE_PARTITIONS; a table without an entry
# here cannot be turned into an IcebergTableDef.
TABLE_SCHEMAS: dict[str, pa.Schema] = {
    "market_bars": MARKET_BARS_SCHEMA,
    "market_quotes": MARKET_QUOTES_SCHEMA,
    "company_events": COMPANY_EVENTS_SCHEMA,
    "documents": DOCUMENTS_SCHEMA,
    "document_extractions": DOCUMENT_EXTRACTIONS_SCHEMA,
    "trade_signals": TRADE_SIGNALS_SCHEMA,
    "trade_orders": TRADE_ORDERS_SCHEMA,
    "trade_fills": TRADE_FILLS_SCHEMA,
    "positions_daily": POSITIONS_DAILY_SCHEMA,
    "pnl_daily": PNL_DAILY_SCHEMA,
    "prediction_vs_outcome": PREDICTION_VS_OUTCOME_SCHEMA,
    "model_performance": MODEL_PERFORMANCE_SCHEMA,
}
@dataclass(frozen=True)
class IcebergTableDef:
    """Immutable description of one Iceberg table.

    Bundles the table name, its PyArrow schema and its partition spec, and
    renders the pieces of the ``CREATE TABLE`` statement from them.
    """

    table_name: str
    schema: pa.Schema
    partition_spec: PartitionSpec

    @property
    def qualified_name(self) -> str:
        """Fully qualified ``catalog.schema.table`` name (module defaults)."""
        return ".".join((ICEBERG_CATALOG, ICEBERG_SCHEMA, self.table_name))

    @property
    def location(self) -> str:
        """``s3a://`` URI of the table directory inside the lakehouse bucket."""
        table_dir = f"{WAREHOUSE_PREFIX}/{self.table_name}"
        return f"s3a://{LAKEHOUSE_BUCKET}/{table_dir}/"

    def column_defs_sql(self) -> list[str]:
        """Render one ``<name> <TRINO TYPE>`` line per schema field.

        Every field of the PyArrow schema is emitted, in schema order.
        """
        return [
            f" {field.name} {_arrow_type_to_trino(field.type)}"
            for field in self.schema
        ]

    def partition_keys_sql(self) -> str:
        """Render the ``partitioning = ARRAY[...]`` table property.

        Returns an empty string when the partition spec has no keys.
        """
        literals = [f"'{key}'" for key in self.partition_spec.all_keys]
        if literals:
            return "partitioning = ARRAY[" + ", ".join(literals) + "]"
        return ""

    def create_table_sql(self) -> str:
        """Render the full ``CREATE TABLE IF NOT EXISTS`` statement."""
        columns_sql = ",\n".join(self.column_defs_sql())

        # Table properties: format and location always, partitioning if any.
        properties = ["format = 'PARQUET'", f"location = '{self.location}'"]
        partitioning = self.partition_keys_sql()
        if partitioning:
            properties.append(partitioning)
        properties_sql = ",\n ".join(properties)

        statement_lines = [
            f"CREATE TABLE IF NOT EXISTS {self.qualified_name} (",
            columns_sql,
            ") WITH (",
            f" {properties_sql}",
            ")",
        ]
        return "\n".join(statement_lines)
def get_all_table_defs() -> list[IcebergTableDef]:
    """Return an IcebergTableDef for each partitioned table that has a schema.

    Tables listed in ``TABLE_PARTITIONS`` without a matching entry in
    ``TABLE_SCHEMAS`` are logged and left out.
    """
    table_defs: list[IcebergTableDef] = []
    for name, spec in TABLE_PARTITIONS.items():
        arrow_schema = TABLE_SCHEMAS.get(name)
        if arrow_schema is not None:
            table_defs.append(
                IcebergTableDef(table_name=name, schema=arrow_schema, partition_spec=spec)
            )
        else:
            logger.warning("No PyArrow schema for table %s, skipping", name)
    return table_defs
def get_table_def(table_name: str) -> IcebergTableDef:
    """Look up the IcebergTableDef for ``table_name``.

    Raises:
        ValueError: If the table has no partition spec, or has a partition
            spec but no registered PyArrow schema.
    """
    try:
        spec = TABLE_PARTITIONS[table_name]
    except KeyError:
        raise ValueError(f"Unknown table: {table_name}") from None

    arrow_schema = TABLE_SCHEMAS.get(table_name)
    if arrow_schema is None:
        raise ValueError(f"No PyArrow schema registered for table: {table_name}")

    return IcebergTableDef(table_name=table_name, schema=arrow_schema, partition_spec=spec)
@dataclass
class IcebergManager:
    """Manages Iceberg tables via Trino's Iceberg catalog.

    Provides table creation, existence checks, schema inspection,
    and metadata operations against the Trino Iceberg connector.
    Every statement is issued against this manager's own ``catalog`` and
    ``schema``, on a fresh connection that is closed afterwards.
    """

    host: str = "localhost"
    port: int = 8080
    user: str = "stonks"
    catalog: str = ICEBERG_CATALOG
    schema: str = ICEBERG_SCHEMA

    def _qualified(self, table_name: str) -> str:
        """Return ``catalog.schema.table`` for this manager's catalog/schema."""
        return f"{self.catalog}.{self.schema}.{table_name}"

    def _metadata_table(self, table_name: str, suffix: str) -> str:
        """Return the reference to an Iceberg metadata table such as ``$snapshots``.

        Only the final ``table$suffix`` part is double-quoted. Quoting the
        whole dotted name would make Trino treat it as a single identifier.
        """
        quoted_table = f'"{table_name}${suffix}"'
        return f"{self.catalog}.{self.schema}.{quoted_table}"

    def _get_connection(self) -> Any:
        """Create a Trino DBAPI connection."""
        return trino_connect(
            host=self.host,
            port=self.port,
            user=self.user,
            catalog=self.catalog,
            schema=self.schema,
        )

    def _execute(self, sql: str) -> list[list[Any]]:
        """Execute a SQL statement and return all rows."""
        conn = self._get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(sql)
            return cursor.fetchall()
        finally:
            conn.close()

    def _execute_no_fetch(self, sql: str) -> None:
        """Execute a DDL statement that returns no rows.

        Errors raised by ``execute`` propagate. Errors raised while draining
        the (empty) result are logged instead of raised, keeping the original
        best-effort behavior without hiding them completely.
        """
        conn = self._get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(sql)
            # DDL statements in Trino still need fetchall to complete
            try:
                cursor.fetchall()
            except Exception:
                logger.warning(
                    "Draining the result of a statement failed: %s", sql, exc_info=True,
                )
        finally:
            conn.close()

    def ensure_schema(self) -> None:
        """Create the Iceberg schema if it doesn't exist."""
        sql = f"CREATE SCHEMA IF NOT EXISTS {self.catalog}.{self.schema}"
        logger.info("Ensuring Iceberg schema: %s.%s", self.catalog, self.schema)
        self._execute_no_fetch(sql)

    def table_exists(self, table_name: str) -> bool:
        """Check if an Iceberg table exists."""
        sql = (
            f"SELECT table_name FROM {self.catalog}.information_schema.tables "
            f"WHERE table_schema = '{self.schema}' AND table_name = '{table_name}'"
        )
        rows = self._execute(sql)
        return len(rows) > 0

    def create_table(self, table_name: str) -> bool:
        """Create a single Iceberg table if it doesn't exist.

        Issues ``CREATE TABLE IF NOT EXISTS``, so an already existing table is
        left untouched.

        Returns:
            True once the statement has been issued without raising, whether
            or not the table already existed.

        Raises:
            ValueError: If ``table_name`` is not a registered table.
        """
        table_def = get_table_def(table_name)
        qualified = self._qualified(table_name)
        # The table definition renders its DDL against the module default
        # catalog/schema; retarget it to this manager's catalog/schema so the
        # table lands where ensure_schema/table_exists look for it. With the
        # default catalog/schema this is a no-op.
        ddl = table_def.create_table_sql().replace(table_def.qualified_name, qualified, 1)
        logger.info("Creating Iceberg table: %s", qualified)
        self._execute_no_fetch(ddl)
        logger.info("Iceberg table ready: %s", qualified)
        return True

    def create_all_tables(self) -> dict[str, bool]:
        """Create all registered Iceberg tables.

        Returns a dict mapping table_name -> True (created or already
        present) or False (error).
        """
        self.ensure_schema()
        results: dict[str, bool] = {}
        for table_def in get_all_table_defs():
            try:
                self.create_table(table_def.table_name)
                results[table_def.table_name] = True
            except Exception:
                logger.exception("Failed to create Iceberg table: %s", table_def.table_name)
                results[table_def.table_name] = False
        return results

    def get_table_schema(self, table_name: str) -> list[dict[str, str]]:
        """Retrieve the column schema of an Iceberg table from Trino.

        Returns a list of dicts with 'column_name', 'data_type', and 'is_nullable'.
        """
        sql = (
            f"SELECT column_name, data_type, is_nullable "
            f"FROM {self.catalog}.information_schema.columns "
            f"WHERE table_schema = '{self.schema}' AND table_name = '{table_name}' "
            f"ORDER BY ordinal_position"
        )
        rows = self._execute(sql)
        return [
            {"column_name": r[0], "data_type": r[1], "is_nullable": r[2]}
            for r in rows
        ]

    def get_table_snapshots(self, table_name: str) -> list[dict[str, Any]]:
        """List Iceberg snapshots for a table (useful for auditing and rollback).

        Returns snapshot metadata from Trino's $snapshots metadata table, or
        an empty list if the metadata table cannot be read.
        """
        # Columns are selected by name: the metadata table starts with
        # committed_at, so positional access on SELECT * would mislabel fields.
        sql = (
            "SELECT snapshot_id, parent_id, operation, manifest_list, summary "
            f"FROM {self._metadata_table(table_name, 'snapshots')}"
        )
        try:
            rows = self._execute(sql)
        except Exception:
            logger.debug(
                "Could not read snapshots for %s (table may be empty)",
                table_name, exc_info=True,
            )
            return []
        return [
            {"snapshot_id": r[0], "parent_id": r[1], "operation": r[2],
             "manifest_list": r[3], "summary": r[4]}
            for r in rows
        ]

    def get_table_partitions(self, table_name: str) -> list[dict[str, Any]]:
        """List partition values for an Iceberg table.

        Returns partition metadata from Trino's $partitions metadata table,
        one ``{"row": <raw row>}`` dict per partition, or an empty list if the
        metadata table cannot be read.
        """
        sql = f"SELECT * FROM {self._metadata_table(table_name, 'partitions')}"
        try:
            rows = self._execute(sql)
        except Exception:
            logger.debug(
                "Could not read partitions for %s (table may be empty)",
                table_name, exc_info=True,
            )
            return []
        return [{"row": r} for r in rows]

    def list_tables(self) -> list[str]:
        """List all tables in the Iceberg schema."""
        sql = (
            f"SELECT table_name FROM {self.catalog}.information_schema.tables "
            f"WHERE table_schema = '{self.schema}' ORDER BY table_name"
        )
        rows = self._execute(sql)
        return [r[0] for r in rows]

    def drop_table(self, table_name: str) -> None:
        """Drop an Iceberg table (for testing/reset purposes)."""
        qualified = self._qualified(table_name)
        logger.warning("Dropping Iceberg table: %s", qualified)
        self._execute_no_fetch(f"DROP TABLE IF EXISTS {qualified}")

    def sync_table_schema(self, table_name: str) -> list[str]:
        """Compare the expected PyArrow schema with the actual Iceberg table schema.

        If columns are missing from the Iceberg table, adds them via ALTER TABLE.
        Returns a list of columns that were added.
        This supports forward-only schema evolution — columns are never dropped.
        """
        table_def = get_table_def(table_name)
        existing_names = {col["column_name"] for col in self.get_table_schema(table_name)}
        # Alter the same table whose columns were just inspected.
        qualified = self._qualified(table_name)

        added: list[str] = []
        for field in table_def.schema:
            if field.name in existing_names:
                continue
            trino_type = _arrow_type_to_trino(field.type)
            alter_sql = f"ALTER TABLE {qualified} ADD COLUMN {field.name} {trino_type}"
            logger.info("Adding column %s to %s", field.name, qualified)
            self._execute_no_fetch(alter_sql)
            added.append(field.name)
        return added

    def sync_all_schemas(self) -> dict[str, list[str]]:
        """Sync schemas for all registered tables. Returns table_name -> added columns.

        Tables that do not exist yet, and tables whose sync raised, map to an
        empty list.
        """
        results: dict[str, list[str]] = {}
        for table_def in get_all_table_defs():
            try:
                if self.table_exists(table_def.table_name):
                    added = self.sync_table_schema(table_def.table_name)
                    results[table_def.table_name] = added
                else:
                    logger.info("Table %s doesn't exist yet, skipping sync", table_def.table_name)
                    results[table_def.table_name] = []
            except Exception:
                logger.exception("Failed to sync schema for %s", table_def.table_name)
                results[table_def.table_name] = []
        return results
def create_iceberg_manager_from_config(
    host: str = "localhost",
    port: int = 8080,
    user: str = "stonks",
) -> IcebergManager:
    """Factory that creates an IcebergManager from explicit connection params.

    Only ``host``, ``port`` and ``user`` are forwarded; the manager's catalog
    and schema keep their dataclass defaults.
    """
    return IcebergManager(host=host, port=port, user=user)
+673
View File
@@ -0,0 +1,673 @@
"""Lake publisher async job runner — transforms operational data into analytical facts.
Reads jobs from the QUEUE_LAKE_PUBLISH Redis queue, queries PostgreSQL for
operational records, and publishes them as partitioned Parquet files to MinIO
via the existing publish_* functions in worker.py.
Job message format:
{"job_type": "<table_name>", "entity_id": "<uuid or ticker>", "dt": "2026-04-11T..."}
Supported job types:
- document: publish a single document metadata fact
- document_extraction: publish extraction facts for a document
- market_snapshot: publish market bars/quotes from a snapshot
- trade_order: publish an order fact
- trade_fill: publish fill facts for an order
- positions_snapshot: publish daily position snapshots for a broker account
- pnl_snapshot: publish daily PnL for a broker account
- company_event: publish a company event fact (listed in JOB_TYPES, but dispatch_job has no handler for it yet)
- bulk_documents: publish all unpublished documents since a cutoff
- bulk_extractions: publish all unpublished extractions since a cutoff
Requirements: 9.4, 9.5, 10.1
Design ref: Section 4.10 (Lake Publisher), Section 8.4 (Lake publication flow)
"""
from __future__ import annotations
import asyncio
import json
import logging
from datetime import datetime, timezone
import asyncpg
import redis.asyncio as aioredis
from minio import Minio
from services.lake_publisher.worker import (
publish_document_extraction,
publish_document_fact,
publish_market_bar,
publish_market_quote,
publish_trade_order,
publish_trade_fill,
publish_pnl_daily,
publish_documents_batch,
publish_document_extractions_batch,
publish_positions_daily_batch,
)
from services.lake_publisher.partitions import partition_values
from services.shared.config import load_config
from services.shared.db import get_minio, get_pg_pool, get_redis
from services.shared.logging import setup_logging
from services.shared.redis_keys import QUEUE_LAKE_PUBLISH, queue_key
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# SQL queries for fetching operational data
# ---------------------------------------------------------------------------
# Single-document lookup; $1 is the document id, cast to uuid.
# ``ticker`` is taken from one document_company_mentions row picked by
# LIMIT 1 without an ORDER BY (so which mention wins is unspecified), or ''
# when the document has no mentions.
_FETCH_DOCUMENT = """
SELECT
d.id, d.document_type, d.source_type, d.publisher, d.title,
d.url, d.canonical_url, d.language, d.published_at, d.retrieved_at,
d.content_hash, d.parse_quality_score,
COALESCE(
(SELECT dcm.ticker FROM document_company_mentions dcm
WHERE dcm.document_id = d.id LIMIT 1),
''
) AS ticker
FROM documents d
WHERE d.id = $1::uuid
"""
# Extraction rows for one document ($1 = document id, cast to uuid).
# Yields one row per document_impact_records entry joined to its
# document_intelligence record; only intelligence rows with
# validation_status = 'valid' are returned. ``company_name`` is '' when no
# companies row matches.
_FETCH_EXTRACTIONS = """
SELECT
di.document_id, dir.ticker, dir.relevance, dir.sentiment,
dir.impact_score, dir.impact_horizon, dir.catalyst_type,
di.confidence, di.novelty_score, di.source_credibility,
dir.key_facts, dir.risks, di.macro_themes,
di.model_name, di.prompt_version, di.schema_version,
di.created_at AS extraction_at,
COALESCE(c.legal_name, '') AS company_name
FROM document_intelligence di
JOIN document_impact_records dir ON dir.intelligence_id = di.id
LEFT JOIN companies c ON c.id = dir.company_id
WHERE di.document_id = $1::uuid
AND di.validation_status = 'valid'
"""
# Single market_snapshots row by id ($1, cast to uuid). ``data`` is the raw
# snapshot payload that publish_market_snapshot_job decodes.
_FETCH_MARKET_SNAPSHOT = """
SELECT
ms.ticker, ms.snapshot_type, ms.data, ms.source_provider, ms.captured_at
FROM market_snapshots ms
WHERE ms.id = $1::uuid
"""
# Single order by id ($1, cast to uuid) with its broker account.
# When no broker_accounts row joins, broker_account falls back to '' and
# execution_mode to 'paper'.
_FETCH_ORDER = """
SELECT
o.id, o.recommendation_id, o.ticker, o.side, o.order_type,
o.quantity, o.limit_price, o.status, o.submitted_at,
o.fill_price, o.fill_quantity, o.filled_at,
COALESCE(ba.account_id, '') AS broker_account,
COALESCE(ba.mode, 'paper') AS execution_mode
FROM orders o
LEFT JOIN broker_accounts ba ON ba.id = o.broker_account_id
WHERE o.id = $1::uuid
"""
# Fill events for one order ($1 = order id, cast to uuid): order_events rows
# with event_type = 'fill', joined to the order for ticker/side and to the
# broker account (broker_account is '' when none joins).
_FETCH_ORDER_FILLS = """
SELECT
oe.id AS fill_id, oe.order_id, oe.data, oe.broker_timestamp,
o.ticker, o.side,
COALESCE(ba.account_id, '') AS broker_account
FROM order_events oe
JOIN orders o ON o.id = oe.order_id
LEFT JOIN broker_accounts ba ON ba.id = o.broker_account_id
WHERE oe.order_id = $1::uuid AND oe.event_type = 'fill'
"""
# Non-zero positions for one broker account ($1 = broker_accounts id, cast
# to uuid). Shared by publish_positions_job and publish_pnl_job.
_FETCH_POSITIONS = """
SELECT
p.ticker, p.quantity, p.avg_entry_price, p.current_price,
p.unrealized_pnl, p.realized_pnl,
COALESCE(ba.account_id, '') AS broker_account,
COALESCE(ba.mode, 'paper') AS execution_mode
FROM positions p
LEFT JOIN broker_accounts ba ON ba.id = p.broker_account_id
WHERE p.broker_account_id = $1::uuid AND p.quantity != 0
"""
# Bulk variant of _FETCH_DOCUMENT: documents with created_at >= $1 whose
# status is 'parsed' or 'extracted', oldest first.
# NOTE(review): capped at 500 rows and the visible caller does not page, so
# a single job leaves anything beyond the first 500 rows unpublished.
_FETCH_BULK_DOCUMENTS = """
SELECT
d.id, d.document_type, d.source_type, d.publisher, d.title,
d.url, d.canonical_url, d.language, d.published_at, d.retrieved_at,
d.content_hash, d.parse_quality_score,
COALESCE(
(SELECT dcm.ticker FROM document_company_mentions dcm
WHERE dcm.document_id = d.id LIMIT 1),
''
) AS ticker
FROM documents d
WHERE d.created_at >= $1
AND d.status IN ('parsed', 'extracted')
ORDER BY d.created_at
LIMIT 500
"""
# Bulk variant of _FETCH_EXTRACTIONS: valid extractions whose intelligence
# record has created_at >= $1, oldest first.
# NOTE(review): capped at 500 rows and the visible caller does not page, so
# a single job leaves anything beyond the first 500 rows unpublished.
_FETCH_BULK_EXTRACTIONS = """
SELECT
di.document_id, dir.ticker, dir.relevance, dir.sentiment,
dir.impact_score, dir.impact_horizon, dir.catalyst_type,
di.confidence, di.novelty_score, di.source_credibility,
dir.key_facts, dir.risks, di.macro_themes,
di.model_name, di.prompt_version, di.schema_version,
di.created_at AS extraction_at,
COALESCE(c.legal_name, '') AS company_name
FROM document_intelligence di
JOIN document_impact_records dir ON dir.intelligence_id = di.id
LEFT JOIN companies c ON c.id = dir.company_id
WHERE di.created_at >= $1
AND di.validation_status = 'valid'
ORDER BY di.created_at
LIMIT 500
"""
# ---------------------------------------------------------------------------
# Job handlers — each transforms operational rows into lake facts
# ---------------------------------------------------------------------------
def _jsonb_to_str(val: object) -> str:
    """Flatten a JSONB column value into a display string.

    * ``None`` becomes ``""``.
    * A list (or a string that decodes to a JSON list) becomes its items
      joined with ``", "``.
    * Any other string is returned unchanged, whether or not it is valid JSON.
    * Anything else is passed through ``str()``.
    """
    if val is None:
        return ""

    items: object = val
    if isinstance(val, str):
        try:
            items = json.loads(val)
        except (json.JSONDecodeError, TypeError):
            return val
        if not isinstance(items, list):
            # Valid JSON, but not a list: keep the original text verbatim.
            return val

    if isinstance(items, list):
        return ", ".join(str(item) for item in items)
    return str(val)
async def publish_document_job(
    pool: asyncpg.Pool,
    minio_client: Minio,
    entity_id: str,
) -> str:
    """Publish one document's metadata row as a lake fact.

    Args:
        pool: PostgreSQL pool used to run the document lookup.
        minio_client: Client handed to ``publish_document_fact``.
        entity_id: Document id bound to the lookup query.

    Returns:
        The value returned by ``publish_document_fact``, or ``""`` when the
        document does not exist.
    """
    record = await pool.fetchrow(_FETCH_DOCUMENT, entity_id)
    if record is None:
        logger.warning("Document %s not found, skipping lake publish", entity_id)
        return ""

    retrieved = record["retrieved_at"]
    fact = {
        "document_id": str(record["id"]),
        "document_type": record["document_type"],
        "source_type": record["source_type"],
        "ticker": record["ticker"] or "",
        "publisher": record["publisher"] or "",
        "title": record["title"] or "",
        # Fall back to the retrieval time when no publication time is stored.
        "published_at": record["published_at"] or retrieved,
        "content_hash": record["content_hash"],
        "url": record["url"] or "",
        "canonical_url": record["canonical_url"] or "",
        "language": record["language"] or "en",
        "confidence": float(record["parse_quality_score"] or 0.0),
        "retrieved_at": retrieved,
    }
    return publish_document_fact(client=minio_client, **fact)
async def publish_extraction_job(
    pool: asyncpg.Pool,
    minio_client: Minio,
    entity_id: str,
) -> list[str]:
    """Publish every valid extraction of one document as lake facts.

    Args:
        pool: PostgreSQL pool used to fetch the extraction rows.
        minio_client: Client handed to ``publish_document_extraction``.
        entity_id: Document id bound to the lookup query.

    Returns:
        One ``publish_document_extraction`` result per extraction row, in
        query order; empty when the document has no valid extractions.
    """
    extraction_rows = await pool.fetch(_FETCH_EXTRACTIONS, entity_id)
    if not extraction_rows:
        logger.info("No valid extractions for document %s", entity_id)
        return []

    # Null columns fall back to neutral defaults ("" / 0.0 / "neutral" / "other").
    return [
        publish_document_extraction(
            client=minio_client,
            document_id=str(item["document_id"]),
            ticker=item["ticker"],
            sentiment=item["sentiment"] or "neutral",
            impact_score=float(item["impact_score"] or 0.0),
            catalyst_type=item["catalyst_type"] or "other",
            confidence=float(item["confidence"] or 0.0),
            extraction_at=item["extraction_at"],
            model_name=item["model_name"] or "",
            prompt_version=item["prompt_version"] or "",
            company_name=item["company_name"] or "",
            relevance=float(item["relevance"] or 0.0),
            impact_horizon=item["impact_horizon"] or "",
            novelty_score=float(item["novelty_score"] or 0.0),
            source_credibility=float(item["source_credibility"] or 0.0),
            key_facts=_jsonb_to_str(item["key_facts"]),
            risks=_jsonb_to_str(item["risks"]),
            macro_themes=_jsonb_to_str(item["macro_themes"]),
            schema_version=item["schema_version"] or "",
        )
        for item in extraction_rows
    ]
async def publish_market_snapshot_job(
    pool: asyncpg.Pool,
    minio_client: Minio,
    entity_id: str,
) -> list[str]:
    """Publish market bar/quote facts from one ``market_snapshots`` row.

    Args:
        pool: PostgreSQL pool used to fetch the snapshot.
        minio_client: Client handed to the ``publish_market_*`` writers.
        entity_id: Snapshot id bound to the lookup query.

    Returns:
        One writer result per published bar (``bar``/``bars`` snapshots) or a
        single result (``quote``/``quotes`` snapshots). Empty when the
        snapshot does not exist or its type is not supported.
    """
    row = await pool.fetchrow(_FETCH_MARKET_SNAPSHOT, entity_id)
    if row is None:
        logger.warning("Market snapshot %s not found", entity_id)
        return []

    ticker = row["ticker"]
    payload = row["data"]
    data = payload if isinstance(payload, dict) else json.loads(payload)
    source = row["source_provider"] or ""
    captured_at = row["captured_at"]
    snapshot_type = row["snapshot_type"]

    if snapshot_type in ("bar", "bars"):
        # Either a {"bars": [...]} wrapper or a single bar object.
        bars = data.get("bars", [data])
        # NOTE(review): every bar is stamped with the snapshot's captured_at,
        # not a per-bar timestamp - confirm this is intended for multi-bar
        # snapshots.
        return [
            publish_market_bar(
                client=minio_client,
                ticker=ticker,
                # Long key names first, single-letter provider keys as fallback.
                open_price=float(bar.get("open", bar.get("o", 0))),
                high_price=float(bar.get("high", bar.get("h", 0))),
                low_price=float(bar.get("low", bar.get("l", 0))),
                close_price=float(bar.get("close", bar.get("c", 0))),
                volume=int(bar.get("volume", bar.get("v", 0))),
                bar_timestamp=captured_at,
                source=source,
                vwap=float(bar.get("vwap", bar.get("vw", 0))),
                trade_count=int(bar.get("trade_count", bar.get("n", 0))),
                bar_interval=bar.get("interval", "1d"),
            )
            for bar in bars
        ]

    if snapshot_type in ("quote", "quotes"):
        return [
            publish_market_quote(
                client=minio_client,
                ticker=ticker,
                bid_price=float(data.get("bid_price", data.get("bp", 0))),
                ask_price=float(data.get("ask_price", data.get("ap", 0))),
                last_price=float(data.get("last_price", data.get("lp", 0))),
                quote_at=captured_at,
                source=source,
                bid_size=int(data.get("bid_size", data.get("bs", 0))),
                ask_size=int(data.get("ask_size", data.get("as", 0))),
                last_size=int(data.get("last_size", data.get("ls", 0))),
            )
        ]

    # Previously an unrecognised type produced an empty result with no trace.
    logger.warning(
        "Unsupported snapshot_type %r for market snapshot %s; nothing published",
        snapshot_type, entity_id,
    )
    return []
async def publish_order_job(
    pool: asyncpg.Pool,
    minio_client: Minio,
    entity_id: str,
) -> str:
    """Publish one order as a trade-order lake fact.

    Args:
        pool: PostgreSQL pool used to fetch the order.
        minio_client: Client handed to ``publish_trade_order``.
        entity_id: Order id bound to the lookup query.

    Returns:
        The value returned by ``publish_trade_order``, or ``""`` when the
        order does not exist.
    """
    order = await pool.fetchrow(_FETCH_ORDER, entity_id)
    if order is None:
        logger.warning("Order %s not found", entity_id)
        return ""

    limit_price = order["limit_price"]
    recommendation = order["recommendation_id"]
    return publish_trade_order(
        client=minio_client,
        order_id=str(order["id"]),
        ticker=order["ticker"],
        side=order["side"],
        order_type=order["order_type"],
        quantity=float(order["quantity"]),
        # A missing (or zero) limit price is published as None.
        limit_price=float(limit_price) if limit_price else None,
        status=order["status"],
        broker_account=order["broker_account"],
        # Orders without a submission time are stamped with the current UTC time.
        submitted_at=order["submitted_at"] or datetime.now(timezone.utc),
        recommendation_id=str(recommendation) if recommendation else "",
        execution_mode=order["execution_mode"],
    )
async def publish_fills_job(
    pool: asyncpg.Pool,
    minio_client: Minio,
    entity_id: str,
) -> list[str]:
    """Publish the fill events of one order as trade-fill lake facts.

    Args:
        pool: PostgreSQL pool used to fetch the fill events.
        minio_client: Client handed to ``publish_trade_fill``.
        entity_id: Order id bound to the lookup query.

    Returns:
        One ``publish_trade_fill`` result per fill event; empty when the
        order has no fill events.
    """
    fill_events = await pool.fetch(_FETCH_ORDER_FILLS, entity_id)
    if not fill_events:
        logger.info("No fill events for order %s", entity_id)
        return []

    published: list[str] = []
    for event in fill_events:
        payload = event["data"]
        if not isinstance(payload, dict):
            # Stored as text (or NULL): decode, treating NULL as an empty object.
            payload = json.loads(payload or "{}")
        published.append(
            publish_trade_fill(
                client=minio_client,
                fill_id=str(event["fill_id"]),
                order_id=str(event["order_id"]),
                ticker=event["ticker"],
                side=event["side"],
                fill_price=float(payload.get("fill_price", payload.get("price", 0))),
                fill_quantity=float(payload.get("fill_quantity", payload.get("qty", 0))),
                broker_account=event["broker_account"],
                # Events without a broker timestamp get the current UTC time.
                filled_at=event["broker_timestamp"] or datetime.now(timezone.utc),
                commission=float(payload.get("commission", 0)),
            )
        )
    return published
async def publish_positions_job(
    pool: asyncpg.Pool,
    minio_client: Minio,
    entity_id: str,
) -> str:
    """Publish a snapshot of one broker account's open positions.

    Args:
        pool: PostgreSQL pool used to fetch the positions.
        minio_client: Client handed to ``publish_positions_daily_batch``.
        entity_id: Broker-account id bound to the lookup query.

    Returns:
        The value returned by ``publish_positions_daily_batch``, or ``""``
        when the account has no non-zero positions.
    """
    holdings = await pool.fetch(_FETCH_POSITIONS, entity_id)
    if not holdings:
        logger.info("No open positions for account %s", entity_id)
        return ""

    taken_at = datetime.now(timezone.utc)
    payload = []
    for holding in holdings:
        payload.append({
            "ticker": holding["ticker"],
            "quantity": float(holding["quantity"]),
            "avg_entry_price": float(holding["avg_entry_price"] or 0),
            # The position's current price is published as its close price.
            "close_price": float(holding["current_price"] or 0),
            "unrealized_pnl": float(holding["unrealized_pnl"] or 0),
        })

    # Every row belongs to the same account, so the first row's value is used.
    return publish_positions_daily_batch(
        client=minio_client,
        positions=payload,
        broker_account=holdings[0]["broker_account"],
        snapshot_at=taken_at,
    )
async def publish_pnl_job(
    pool: asyncpg.Pool,
    minio_client: Minio,
    entity_id: str,
) -> list[str]:
    """Publish one PnL fact per open position of a broker account.

    Args:
        pool: PostgreSQL pool used to fetch the positions.
        minio_client: Client handed to ``publish_pnl_daily``.
        entity_id: Broker-account id bound to the lookup query.

    Returns:
        One ``publish_pnl_daily`` result per position; empty when the
        account has no non-zero positions.
    """
    holdings = await pool.fetch(_FETCH_POSITIONS, entity_id)
    if not holdings:
        logger.info("No positions for PnL snapshot, account %s", entity_id)
        return []

    as_of = datetime.now(timezone.utc)

    def _publish(holding) -> str:
        realized = float(holding["realized_pnl"] or 0)
        unrealized = float(holding["unrealized_pnl"] or 0)
        return publish_pnl_daily(
            client=minio_client,
            ticker=holding["ticker"],
            realized_pnl=realized,
            unrealized_pnl=unrealized,
            total_pnl=realized + unrealized,
            broker_account=holding["broker_account"],
            dt=as_of,
            execution_mode=holding["execution_mode"],
        )

    return [_publish(holding) for holding in holdings]
async def publish_bulk_documents_job(
    pool: asyncpg.Pool,
    minio_client: Minio,
    since: datetime,
) -> list[str]:
    """Publish the documents created at or after ``since`` as one batch.

    Args:
        pool: PostgreSQL pool used to fetch the documents.
        minio_client: Client handed to ``publish_documents_batch``.
        since: Cutoff bound to the bulk query; also passed to the batch writer.

    Returns:
        A one-element list with the batch writer's result, or an empty list
        when nothing matched or the writer returned a falsy value.
    """
    documents = await pool.fetch(_FETCH_BULK_DOCUMENTS, since)
    if not documents:
        logger.info("No documents to bulk-publish since %s", since)
        return []

    def _as_fact(doc) -> dict[str, object]:
        # Fall back to the retrieval time when no publication time is stored.
        effective_published = doc["published_at"] or doc["retrieved_at"]
        return {
            "document_id": str(doc["id"]),
            "document_type": doc["document_type"],
            "source_type": doc["source_type"],
            "ticker": doc["ticker"] or "",
            "publisher": doc["publisher"] or "",
            "title": doc["title"] or "",
            "url": doc["url"] or "",
            "canonical_url": doc["canonical_url"] or "",
            "language": doc["language"] or "en",
            "published_at": effective_published,
            "retrieved_at": doc["retrieved_at"],
            "content_hash": doc["content_hash"],
            "confidence": float(doc["parse_quality_score"] or 0.0),
            # Partition columns are embedded last in each row.
            **partition_values(effective_published),
        }

    batch_ref = publish_documents_batch(
        minio_client, [_as_fact(doc) for doc in documents], since
    )
    if not batch_ref:
        return []
    return [batch_ref]
async def publish_bulk_extractions_job(
    pool: asyncpg.Pool,
    minio_client: Minio,
    since: datetime,
) -> list[str]:
    """Publish the extractions created at or after ``since``, batched per model version.

    Each row's ``model_version`` partition value is its ``schema_version``,
    falling back to ``prompt_version`` and then ``""``. Rows are grouped by
    that value and each group is written with its own
    ``publish_document_extractions_batch`` call, so the ``model_version``
    passed to the writer always matches the value embedded in the rows.
    (Previously the whole batch was written under the first row's version.)

    Args:
        pool: PostgreSQL pool used to fetch the extractions.
        minio_client: Client handed to the batch writer.
        since: Cutoff bound to the bulk query; also passed to the batch writer.

    Returns:
        One writer result per model-version group (falsy results are
        dropped); empty when nothing matched.
    """
    rows = await pool.fetch(_FETCH_BULK_EXTRACTIONS, since)
    if not rows:
        logger.info("No extractions to bulk-publish since %s", since)
        return []

    # Group rows by model version; dicts keep first-seen order.
    batches: dict[str, list[dict[str, object]]] = {}
    for row in rows:
        model_ver = row["schema_version"] or row["prompt_version"] or ""
        batches.setdefault(str(model_ver), []).append({
            "document_id": str(row["document_id"]),
            "ticker": row["ticker"],
            "company_name": row["company_name"] or "",
            "relevance": float(row["relevance"] or 0.0),
            "sentiment": row["sentiment"] or "neutral",
            "impact_score": float(row["impact_score"] or 0.0),
            "impact_horizon": row["impact_horizon"] or "",
            "catalyst_type": row["catalyst_type"] or "other",
            "confidence": float(row["confidence"] or 0.0),
            "novelty_score": float(row["novelty_score"] or 0.0),
            "source_credibility": float(row["source_credibility"] or 0.0),
            "key_facts": _jsonb_to_str(row["key_facts"]),
            "risks": _jsonb_to_str(row["risks"]),
            "macro_themes": _jsonb_to_str(row["macro_themes"]),
            "model_name": row["model_name"] or "",
            "prompt_version": row["prompt_version"] or "",
            "schema_version": row["schema_version"] or "",
            "extraction_at": row["extraction_at"],
            **partition_values(row["extraction_at"], {"model_version": model_ver}),
        })

    refs: list[str] = []
    for version, extraction_rows in batches.items():
        ref = publish_document_extractions_batch(
            minio_client, extraction_rows, since,
            model_version=version,
        )
        if ref:
            refs.append(ref)
    return refs
# ---------------------------------------------------------------------------
# Job dispatcher
# ---------------------------------------------------------------------------
# Job type names this module advertises.
# NOTE(review): "company_event" is listed here but dispatch_job has no branch
# for it, and dispatch_job does not consult this set when routing.
JOB_TYPES = {
    "document",
    "document_extraction",
    "market_snapshot",
    "trade_order",
    "trade_fill",
    "positions_snapshot",
    "pnl_snapshot",
    "company_event",
    "bulk_documents",
    "bulk_extractions",
}
def _parse_since(since_str: str | None) -> datetime:
    """Turn a job's ``since`` value into a timezone-aware datetime.

    A trailing ``Z`` is accepted (``datetime.fromisoformat`` rejects it before
    Python 3.11), and a value without an offset is taken to be UTC so the
    cutoff is aware, like the default below.
    """
    if not since_str:
        # NOTE(review): with no ``since`` the cutoff is "now", so the bulk
        # queries (created_at >= cutoff) will normally match nothing.
        return datetime.now(timezone.utc)
    text = since_str.strip()
    if text.endswith(("Z", "z")):
        text = f"{text[:-1]}+00:00"
    parsed = datetime.fromisoformat(text)
    if parsed.tzinfo is None:
        # NOTE(review): assumes created_at columns are timestamptz - confirm.
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed


async def dispatch_job(
    pool: asyncpg.Pool,
    minio_client: Minio,
    job: dict[str, str],
) -> dict[str, object]:
    """Route a lake publish job to its handler and report the outcome.

    Args:
        pool: PostgreSQL connection pool.
        minio_client: MinIO client passed through to the handler.
        job: Job dict with 'job_type' and 'entity_id'; bulk jobs may also
            carry an ISO-8601 'since' cutoff.

    Returns:
        A result dict with 'job_type', 'entity_id', 'refs' (list of handler
        results) and 'error' (None on success, otherwise a message). Handler
        exceptions are caught and reported through 'error', never raised.
    """
    if not isinstance(job, dict):
        # A JSON array or scalar used to raise AttributeError outside the try.
        logger.warning("Lake publish job payload is not an object: %r", job)
        return {
            "job_type": "",
            "entity_id": "",
            "refs": [],
            "error": "Job payload must be a JSON object",
        }

    job_type = job.get("job_type", "")
    entity_id = job.get("entity_id", "")
    since_str = job.get("since")
    result: dict[str, object] = {
        "job_type": job_type,
        "entity_id": entity_id,
        "refs": [],
        "error": None,
    }

    # Handlers that return a single ref ("" meaning nothing was published).
    single_ref_handlers = {
        "document": publish_document_job,
        "trade_order": publish_order_job,
        "positions_snapshot": publish_positions_job,
    }
    # Handlers that return a list of refs.
    multi_ref_handlers = {
        "document_extraction": publish_extraction_job,
        "market_snapshot": publish_market_snapshot_job,
        "trade_fill": publish_fills_job,
        "pnl_snapshot": publish_pnl_job,
    }
    # Handlers keyed by a time cutoff instead of an entity id.
    bulk_handlers = {
        "bulk_documents": publish_bulk_documents_job,
        "bulk_extractions": publish_bulk_extractions_job,
    }

    try:
        if job_type in single_ref_handlers:
            ref = await single_ref_handlers[job_type](pool, minio_client, entity_id)
            result["refs"] = [ref] if ref else []
        elif job_type in multi_ref_handlers:
            result["refs"] = await multi_ref_handlers[job_type](pool, minio_client, entity_id)
        elif job_type in bulk_handlers:
            since = _parse_since(since_str)
            result["refs"] = await bulk_handlers[job_type](pool, minio_client, since)
        elif job_type in JOB_TYPES:
            # Registered (e.g. "company_event") but no handler exists yet.
            result["error"] = f"No handler implemented for job_type: {job_type}"
            logger.warning("No handler for lake publish job type: %s", job_type)
        else:
            result["error"] = f"Unknown job_type: {job_type}"
            logger.warning("Unknown lake publish job type: %s", job_type)
    except Exception as exc:
        # Top-level boundary for a single job: report, don't crash the worker.
        result["error"] = str(exc)
        logger.exception("Lake publish job failed: %s/%s", job_type, entity_id)
    return result
# ---------------------------------------------------------------------------
# Async worker loop
# ---------------------------------------------------------------------------
async def run_worker(
    pool: asyncpg.Pool,
    rds: aioredis.Redis,
    minio_client: Minio,
    poll_interval: float = 2.0,
) -> None:
    """Main worker loop - pop jobs from Redis and dispatch them one at a time.

    Runs until cancelled. Jobs are processed sequentially. Payloads that are
    not valid JSON objects are logged and dropped.

    Args:
        pool: PostgreSQL pool passed to ``dispatch_job``.
        rds: Redis client the queue is read from.
        minio_client: MinIO client passed to ``dispatch_job``.
        poll_interval: Seconds to sleep when the queue is empty.
    """
    queue = queue_key(QUEUE_LAKE_PUBLISH)
    logger.info("Lake publisher worker started, listening on %s", queue)
    while True:
        raw = await rds.lpop(queue)  # type: ignore[misc]
        if raw is None:
            await asyncio.sleep(poll_interval)
            continue

        try:
            # A client without response decoding returns bytes; str(bytes)
            # would yield "b'...'", which never parses as JSON.
            payload = raw.decode("utf-8") if isinstance(raw, (bytes, bytearray)) else raw
            job = json.loads(payload)
        except (ValueError, TypeError):
            # JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            logger.error("Invalid lake publish job payload: %s", raw)
            continue
        if not isinstance(job, dict):
            # Valid JSON but not an object (e.g. a list) cannot be dispatched.
            logger.error("Invalid lake publish job payload: %s", raw)
            continue

        result = await dispatch_job(pool, minio_client, job)
        refs = result.get("refs") or []
        error = result.get("error")
        if error:
            logger.error(
                "Lake publish job %s/%s failed: %s",
                result["job_type"], result["entity_id"], error,
            )
        else:
            ref_count = len(refs) if isinstance(refs, list) else 0
            logger.info(
                "Lake publish job %s/%s completed: %d facts written",
                result["job_type"], result["entity_id"], ref_count,
            )
async def main() -> None:
    """Create the PostgreSQL, Redis and MinIO clients and run the worker.

    The PostgreSQL pool and the Redis client are closed when the worker loop
    exits, whether it returned or raised.
    """
    settings = load_config()
    pg_pool = await get_pg_pool(settings)
    redis_client = get_redis(settings)
    object_store = get_minio(settings)
    try:
        await run_worker(pg_pool, redis_client, object_store)
    finally:
        await pg_pool.close()
        await redis_client.close()
if __name__ == "__main__":
    # Config is loaded here only to configure logging before the event loop
    # starts; main() calls load_config() again for the service clients.
    cfg = load_config()
    setup_logging("lake_publisher", level=cfg.log_level, json_output=cfg.json_logs)
    asyncio.run(main())
+128
View File
@@ -0,0 +1,128 @@
"""Hive-compatible partition layout conventions for the MinIO lakehouse.
Centralizes partition path generation, partition column injection, and
bucket provisioning so that all lake publisher writers produce layouts
that Trino's Hive and Iceberg connectors can discover and prune.
Design ref: Section 5.2, 5.3 (Lakehouse model)
Requirements: 9.4, 9.5, N4, N6
Layout convention:
s3://stonks-lakehouse/warehouse/{table_name}/dt={YYYY-MM-DD}[/{extra_key}={value}]/part-{uuid}.parquet
Rules:
- Every fact table is partitioned by ``dt`` (DATE) derived from the row timestamp.
- Some tables have a second partition key (e.g. ``model_version``).
- Partition columns MUST appear in the Parquet file so Trino can read them
without relying solely on path parsing.
- File names use a UUID suffix to avoid collisions on concurrent writes.
"""
from __future__ import annotations
import uuid
from dataclasses import dataclass, field
from datetime import date, datetime, timezone
# Bucket name that s3_uri() puts in front of every object path.
LAKEHOUSE_BUCKET = "stonks-lakehouse"
# First segment of every object key built by partition_path().
WAREHOUSE_PREFIX = "warehouse"
@dataclass(frozen=True)
class PartitionSpec:
    """Partition layout of one fact table: ``dt`` plus optional extra keys."""

    table_name: str
    # Partition keys that follow ``dt``, in directory order.
    extra_keys: tuple[str, ...] = field(default_factory=tuple)

    @property
    def all_keys(self) -> tuple[str, ...]:
        """All partition keys in path order: ``dt`` first, then the extras."""
        return ("dt",) + tuple(self.extra_keys)
# Registry of every analytical fact table and its partition keys.
# This is the single source of truth — DDL, publisher, and tests should agree.
# partition_path() rejects any table name that is not registered here.
# Tables declared with extra_keys=("model_version",) get a second partition
# directory after dt=.
TABLE_PARTITIONS: dict[str, PartitionSpec] = {
    "market_bars": PartitionSpec("market_bars"),
    "market_quotes": PartitionSpec("market_quotes"),
    "company_events": PartitionSpec("company_events"),
    "documents": PartitionSpec("documents"),
    "document_extractions": PartitionSpec("document_extractions", extra_keys=("model_version",)),
    "trade_signals": PartitionSpec("trade_signals"),
    "trade_orders": PartitionSpec("trade_orders"),
    "trade_fills": PartitionSpec("trade_fills"),
    "positions_daily": PartitionSpec("positions_daily"),
    "pnl_daily": PartitionSpec("pnl_daily"),
    "prediction_vs_outcome": PartitionSpec("prediction_vs_outcome", extra_keys=("model_version",)),
    "model_performance": PartitionSpec("model_performance", extra_keys=("model_version",)),
}
def partition_path(
    table_name: str,
    dt: datetime | date,
    extra_partitions: dict[str, str] | None = None,
    file_id: str | None = None,
) -> str:
    """Build the Hive-style object key for one Parquet file.

    Args:
        table_name: Fact table name; must be registered in TABLE_PARTITIONS.
        dt: Timestamp or date that supplies the ``dt=`` partition value.
        extra_partitions: Values for the table's extra partition keys. Keys
            the table does not declare are ignored; declared keys that are
            missing get the placeholder ``__NONE__``.
        file_id: File-name suffix; a random 16-hex-character id when omitted
            or empty.

    Returns:
        Key relative to the bucket root, e.g.
        ``warehouse/trade_signals/dt=2026-04-11/part-<id>.parquet``.

    Raises:
        ValueError: If ``table_name`` is not registered.
    """
    if table_name not in TABLE_PARTITIONS:
        raise ValueError(f"Unknown table: {table_name}. Register it in TABLE_PARTITIONS.")
    spec = TABLE_PARTITIONS[table_name]

    dt_str = dt.strftime("%Y-%m-%d") if isinstance(dt, datetime) else dt.isoformat()

    # Extra partition directories follow the order declared by the spec.
    supplied = extra_partitions or {}
    resolved = [(key, supplied.get(key, "__NONE__")) for key in spec.extra_keys]

    pieces = [WAREHOUSE_PREFIX, table_name, f"dt={dt_str}"]
    pieces += [f"{key}={value}" for key, value in resolved]

    suffix = file_id if file_id else uuid.uuid4().hex[:16]
    pieces.append(f"part-{suffix}.parquet")
    return "/".join(pieces)
def partition_values(
dt: datetime | date,
extra_partitions: dict[str, str] | None = None,
) -> dict[str, object]:
"""Return partition column values to inject into Parquet row data.
Trino's Hive connector can read partition values from the directory path,
but embedding them in the Parquet file as well ensures compatibility with
engines that don't parse Hive paths (e.g. plain PyArrow reads, DuckDB).
Returns a dict like ``{"dt": date(2026, 4, 11), "model_version": "v2"}``.
"""
if isinstance(dt, datetime):
dt_date = dt.date()
else:
dt_date = dt
values: dict[str, object] = {"dt": dt_date}
if extra_partitions:
values.update(extra_partitions)
return values
def s3_uri(path: str) -> str:
    """Build an s3:// URI from a bucket-relative object path.

    The path is appended to ``LAKEHOUSE_BUCKET`` verbatim; no slashes are
    added or stripped.
    """
    return f"s3://{LAKEHOUSE_BUCKET}/{path}"
File diff suppressed because it is too large Load Diff
+858
View File
@@ -0,0 +1,858 @@
"""HTML-to-text parsing pipeline using BeautifulSoup.
Provides structured HTML parsing with boilerplate removal, metadata extraction,
outbound link extraction, and quality scoring. Inspired by Noctipede crawler
patterns: BeautifulSoup + content hashing, boilerplate stripping, quality scoring.
Requirements: 4.1, 4.2, 4.3
"""
from __future__ import annotations
import json
import logging
import math
import re
from dataclasses import dataclass, field
from urllib.parse import urlparse
from bs4 import BeautifulSoup, Tag
logger = logging.getLogger("html_parser")
# Tags that never contain useful article content.
# _strip_boilerplate_tags() removes every element with one of these names,
# together with everything nested inside it.
STRIP_TAGS = [
    "script", "style", "nav", "footer", "header", "aside",
    "iframe", "noscript", "svg", "form", "button",
]
# CSS class / id substrings that signal boilerplate containers.
# _is_boilerplate_container() lower-cases the element's class and id and does
# a plain substring test, so "comment" also matches e.g. "commentary".
BOILERPLATE_SIGNALS = [
    "sidebar", "widget", "advert", "promo", "newsletter",
    "social-share", "share-bar", "related-posts", "comment",
    "cookie", "popup", "modal", "banner", "breadcrumb",
    "pagination", "nav-", "menu", "toolbar", "signup",
    "subscribe", "follow-us", "social-media", "share-button",
    "ad-slot", "ad-container", "sponsored",
]
# Regex patterns for residual boilerplate in extracted text.
# _reduce_boilerplate_text() applies them in list order, replacing each match
# with the empty string.
#
# Patterns that use ^ or $ carry the (?m) flag so they work on every line of a
# multi-line text, not only at its very start or end, and use [ \t]* instead
# of \s* so they cannot swallow neighbouring lines.
BOILERPLATE_TEXT_PATTERNS = [
    re.compile(r"(?i)subscribe to our newsletter.*?(?:\n|$)"),
    re.compile(r"(?i)click here to read more.*?(?:\n|$)"),
    re.compile(r"(?i)advertisement\s*\n?"),
    re.compile(r"(?i)copyright ©.*?(?:\n|$)"),
    re.compile(r"(?i)all rights reserved.*?(?:\n|$)"),
    re.compile(r"(?i)terms of (use|service).*?(?:\n|$)"),
    re.compile(r"(?i)privacy policy.*?(?:\n|$)"),
    # Bracketed ad markers such as "[Ad]" or "[sponsored ad]". "ad" must be a
    # whole word so that "[NASDAQ: ADBE]" or "[Read more]" are left alone, and
    # only leading blanks are consumed so the surrounding words stay separated.
    re.compile(
        r"[ \t]*\[[^\[\]\n]*\b(?:ads?|advert(?:s|isements?|ising)?)\b[^\[\]\n]*\]",
        re.IGNORECASE,
    ),
    re.compile(r"(?i)sign up for .*?(?:\n|$)"),
    re.compile(r"(?i)follow us on .*?(?:\n|$)"),
    re.compile(r"(?i)share this (article|story|post).*?(?:\n|$)"),
    re.compile(r"(?im)read more:?[ \t]*$"),
    re.compile(r"(?i)recommended for you.*?(?:\n|$)"),
    re.compile(r"(?i)you may also like.*?(?:\n|$)"),
    re.compile(r"(?i)trending now.*?(?:\n|$)"),
    re.compile(r"(?i)most (popular|read).*?(?:\n|$)"),
    re.compile(r"(?im)^tags:[ \t]*$"),
    re.compile(r"(?im)^[ \t]*photo[ \t]*:.*?(?:\n|$)"),
    re.compile(r"(?im)^[ \t]*image[ \t]*(credit|source|courtesy)[ \t]*:.*?(?:\n|$)"),
]
# Selectors for article body candidates, in priority order.
# _find_article_body() tries them top to bottom and takes the first match
# that holds at least _MIN_BLOCK_WORDS words.
ARTICLE_SELECTORS = [
    "article",
    "[role='main']",
    ".article-body",
    ".post-content",
    ".entry-content",
    ".story-body",
    ".article-content",
    "#article-body",
    "#story-body",
    ".article-text",
    ".post-body",
    ".content-body",
    "main",
]
# Minimum text density (text chars / total chars including markup) for a block
# to be considered content-rich rather than boilerplate.
# NOTE(review): not referenced by any function visible in this part of the file.
_MIN_TEXT_DENSITY = 0.25
# Minimum word count for a block to be a viable body candidate; used by
# _block_score() and _find_article_body().
_MIN_BLOCK_WORDS = 20
@dataclass
class QualitySignals:
    """Per-dimension quality signals behind the overall parse score.

    Every signal is a float in [0, 1] describing how well the parsed content
    does on that dimension.

    Requirements: 4.3
    """

    word_count_signal: float = 0.0
    diversity_signal: float = 0.0
    sentence_signal: float = 0.0
    paragraph_signal: float = 0.0
    body_found_signal: float = 0.0
    metadata_signal: float = 0.0

    def as_dict(self) -> dict[str, float]:
        """Return the signals keyed by their short name (no ``_signal`` suffix)."""
        short_names = (
            "word_count",
            "diversity",
            "sentence",
            "paragraph",
            "body_found",
            "metadata",
        )
        return {name: getattr(self, f"{name}_signal") for name in short_names}
@dataclass
class CompanyMention:
    """A detected company mention in parsed text.

    Requirements: 1.3, 4.1
    """

    # Identifier of the mentioned company (kept as a string).
    company_id: str
    # Ticker symbol of the mentioned company.
    ticker: str
    # How the mention was matched.
    mention_type: str  # ticker, legal_name, alias, brand
    # Confidence attached to the match.
    confidence: float
    # Number of matches; defaults to a single match.
    match_count: int = 1
@dataclass
class ParsedDocument:
    """Result of HTML-to-text parsing pipeline.

    Every field has a default, so an empty instance represents "nothing
    extracted". List-valued fields get a fresh list per instance.
    """

    # Extracted text and page metadata.
    body_text: str = ""
    title: str = ""
    author: str = ""
    publisher: str = ""
    # Kept as a string; None when no publication time is available.
    published_at: str | None = None
    canonical_url: str | None = None
    language: str = "en"
    description: str = ""
    document_type: str = "article"
    outbound_links: list[str] = field(default_factory=list)
    tags: list[str] = field(default_factory=list)
    mentioned_companies: list[CompanyMention] = field(default_factory=list)
    # Quality assessment of the parse.
    quality_score: float = 0.0
    confidence: str = "low"
    word_count: int = 0
    quality_signals: QualitySignals = field(default_factory=QualitySignals)
    low_quality_flag: bool = False
    quality_warnings: list[str] = field(default_factory=list)
def _attr_str(tag: Tag, attr: str) -> str:
    """Return a tag attribute as one string.

    List values (for example ``class``) are joined with single spaces;
    missing or empty values give ``""``.
    """
    raw = tag.get(attr, "")
    if isinstance(raw, list):
        return " ".join(raw)
    if not raw:
        return ""
    return str(raw)
def _is_boilerplate_container(tag: Tag) -> bool:
    """Check if a tag looks like a boilerplate container by class/id.

    The tag's lower-cased ``class`` and ``id`` are searched for any
    BOILERPLATE_SIGNALS substring. ``<html>`` and ``<body>`` are never
    reported: their classes often carry layout flags such as "has-sidebar"
    or "modal-open", and removing them would delete the whole document.
    """
    if tag.name in ("html", "body"):
        return False
    cls = _attr_str(tag, "class").lower()
    tag_id = _attr_str(tag, "id").lower()
    combined = f"{cls} {tag_id}"
    return any(sig in combined for sig in BOILERPLATE_SIGNALS)
def _strip_boilerplate_tags(soup: BeautifulSoup) -> None:
    """Remove known non-content tags and boilerplate containers in-place.

    Pass 1 removes every element named in STRIP_TAGS. Pass 2 removes every
    element whose class/id marks it as a boilerplate container.
    """
    for tag_name in STRIP_TAGS:
        for tag in soup.find_all(tag_name):
            # A nested match may already be gone with its ancestor.
            if not tag.decomposed:
                tag.decompose()

    # Decide first, destroy afterwards: decompose() also destroys a tag's
    # descendants, and reading attributes of a destroyed tag fails. Checking
    # every tag while the tree is still intact avoids touching dead nodes.
    doomed = [tag for tag in soup.find_all(True) if _is_boilerplate_container(tag)]
    for tag in doomed:
        if not tag.decomposed:
            tag.decompose()
def _reduce_boilerplate_text(text: str) -> str:
    """Delete every BOILERPLATE_TEXT_PATTERNS match from ``text``.

    Patterns run in list order, each on the output of the previous one; the
    final text is stripped of leading and trailing whitespace.
    """
    cleaned = text
    for boilerplate in BOILERPLATE_TEXT_PATTERNS:
        cleaned = boilerplate.sub("", cleaned)
    return cleaned.strip()
def _text_density(tag: Tag) -> float:
    """Return the share of a tag's serialized markup that is visible text.

    Computed as ``len(stripped text) / len(str(tag))``; an empty
    serialization gives 0.0. Higher values mean more text relative to HTML
    structure.

    Requirements: 4.2
    """
    total_chars = len(str(tag))
    if not total_chars:
        return 0.0
    visible_chars = len(tag.get_text(strip=True))
    return visible_chars / total_chars
def _link_density(tag: Tag) -> float:
    """Return the share of a tag's text that sits inside ``<a>`` elements.

    A tag with no text at all counts as fully link-dense (1.0). High values
    point to navigation-style blocks, low values to content paragraphs.

    Requirements: 4.2
    """
    overall = len(tag.get_text(strip=True))
    if overall == 0:
        return 1.0
    anchored = 0
    for anchor in tag.find_all("a"):
        anchored += len(anchor.get_text(strip=True))
    return anchored / overall
def _block_score(tag: Tag) -> float:
    """Score a block element as an article-body candidate (higher is better).

    Blocks with fewer than _MIN_BLOCK_WORDS words score 0.0. Otherwise the
    score is text density discounted by link density, plus a bonus for
    paragraph-rich blocks and a log-scaled bonus for word count.

    Requirements: 4.2
    """
    words = len(tag.get_text(strip=True).split())
    if words < _MIN_BLOCK_WORDS:
        return 0.0

    paragraphs = len(tag.find_all("p"))

    # Text density (0-1), reduced in proportion to how much of it is links.
    score = _text_density(tag) * (1.0 - _link_density(tag))
    if paragraphs >= 2:
        # The paragraph bonus is capped at ten paragraphs.
        score += 0.1 * min(paragraphs, 10)
    # Log scaling keeps very long blocks from dominating.
    score += 0.05 * math.log(max(words, 1))
    return score
def _find_article_body(soup: BeautifulSoup) -> Tag | None:
    """Locate the element most likely to hold the article body.

    Semantic selectors (ARTICLE_SELECTORS) are tried first, in order; the
    first match with at least _MIN_BLOCK_WORDS words wins. If none qualifies,
    every ``div``/``section``/``td`` is scored with ``_block_score`` and the
    highest-scoring one is returned (the earliest one on ties).

    Returns:
        The chosen element, or ``None`` when nothing scores above zero.

    Requirements: 4.2
    """
    # Priority 1: semantic selectors.
    for selector in ARTICLE_SELECTORS:
        match = soup.select_one(selector)
        if not match:
            continue
        if len(match.get_text(strip=True).split()) >= _MIN_BLOCK_WORDS:
            return match

    # Priority 2: density scoring over generic block containers.
    best_score = 0.0
    best_block: Tag | None = None
    for block in soup.find_all(["div", "section", "td"]):
        score = _block_score(block)
        # Strict comparison keeps the first block among equal scores.
        if score > best_score:
            best_score = score
            best_block = block
    return best_block
def _collapse_whitespace(text: str) -> str:
"""Collapse runs of blank lines into single separators."""
lines = [line.strip() for line in text.splitlines()]
result: list[str] = []
prev_blank = False
for line in lines:
if not line:
if not prev_blank:
result.append("")
prev_blank = True
else:
result.append(line)
prev_blank = False
return "\n".join(result).strip()
def _remove_short_orphan_lines(text: str, min_words: int = 3) -> str:
"""Remove very short orphan lines that are likely UI fragments or captions.
Lines shorter than min_words that don't end with sentence punctuation
are stripped. This catches leftover button labels, image captions,
and navigation fragments.
Requirements: 4.2
"""
lines = text.splitlines()
kept: list[str] = []
for line in lines:
stripped = line.strip()
words = stripped.split()
if len(words) < min_words and not stripped.endswith((".", "!", "?", ":")):
continue
kept.append(line)
return "\n".join(kept)
def _detect_repeated_blocks(text: str, min_len: int = 40) -> str:
"""Remove repeated text blocks that appear more than once.
Template text (disclaimers, repeated footers) often appears verbatim
in multiple places. This strips exact duplicate blocks.
Requirements: 4.2
"""
lines = text.splitlines()
seen: dict[str, int] = {}
for line in lines:
stripped = line.strip()
if len(stripped) >= min_len:
seen[stripped] = seen.get(stripped, 0) + 1
duplicates = {k for k, v in seen.items() if v > 1}
if not duplicates:
return text
kept: list[str] = []
emitted: set[str] = set()
for line in lines:
stripped = line.strip()
if stripped in duplicates:
if stripped not in emitted:
kept.append(line)
emitted.add(stripped)
# Skip subsequent duplicates
else:
kept.append(line)
return "\n".join(kept)
def extract_body_text(html: str) -> str:
    """Return the main body text of an HTML page with boilerplate removed.

    Steps, in order:
    1. Parse the HTML and strip boilerplate tags from the tree.
    2. Locate the article body, falling back to ``<body>`` and then to
       the whole document.
    3. Extract newline-separated text from that element.
    4. Run the text cleaners: boilerplate patterns, short orphan lines,
       repeated template lines, whitespace collapsing.

    Requirements: 4.1, 4.2
    """
    soup = BeautifulSoup(html, "html.parser")
    _strip_boilerplate_tags(soup)

    container = _find_article_body(soup)
    if not container:
        container = soup.find("body") or soup
    cleaned = container.get_text(separator="\n", strip=True)

    # Multi-stage text cleaning; the order of the stages matters.
    stages = (
        _reduce_boilerplate_text,
        _remove_short_orphan_lines,
        _detect_repeated_blocks,
        _collapse_whitespace,
    )
    for stage in stages:
        cleaned = stage(cleaned)
    return cleaned
def extract_metadata(html: str, url: str = "") -> dict[str, str | None]:
    """Collect document metadata from the HTML head.

    Returned keys: ``title``, ``author``, ``publisher``, ``published_at``,
    ``canonical_url``, ``language``, ``description`` and ``tags``. Text
    fields that cannot be found are empty strings; ``published_at`` and
    ``canonical_url`` may be ``None``. ``tags`` is the raw comma-separated
    keywords string.

    Requirements: 4.1
    """
    soup = BeautifulSoup(html, "html.parser")

    def _content_of(node, non_string_value=""):
        """Return ``(present, value)`` for a tag's non-empty ``content`` attribute."""
        raw_value = node.get("content") if node else None
        if not raw_value:
            return False, None
        if isinstance(raw_value, str):
            return True, raw_value.strip()
        return True, non_string_value

    # Title: og:title wins over <title>.
    has_og_title, title = _content_of(soup.find("meta", property="og:title"))
    if not has_og_title:
        title_text = soup.title.string if soup.title else None
        title = title_text.strip() if title_text else ""

    # Author
    has_author, author = _content_of(soup.find("meta", attrs={"name": "author"}))
    if not has_author:
        author = ""

    # Publisher: og:site_name wins over the hostname of ``url``.
    has_site, publisher = _content_of(soup.find("meta", property="og:site_name"))
    if not has_site:
        publisher = (urlparse(url).hostname or "") if url else ""

    # Published date: article:published_time wins over JSON-LD datePublished.
    has_pub_time, published_at = _content_of(
        soup.find("meta", property="article:published_time"), None
    )
    if not has_pub_time:
        published_at = _extract_jsonld_date(soup)

    # Canonical URL: <link rel=canonical>, then og:url, then the given url.
    link_tag = soup.find("link", rel="canonical")
    if link_tag and link_tag.get("href"):
        canonical_url = str(link_tag["href"])
    else:
        og_url = soup.find("meta", property="og:url")
        if og_url and og_url.get("content"):
            canonical_url = str(og_url["content"])
        else:
            canonical_url = url or None

    # Language: first five characters of <html lang>, defaulting to "en".
    html_tag = soup.find("html")
    lang_value = html_tag.get("lang") if html_tag else None
    language = str(lang_value)[:5] if lang_value else "en"

    # Description: og:description tag if present, else the name=description tag.
    description_tag = soup.find("meta", property="og:description") or soup.find(
        "meta", attrs={"name": "description"}
    )
    has_description, description = _content_of(description_tag)
    if not has_description:
        description = ""

    # Tags / keywords are kept as the raw comma-separated string.
    has_keywords, tags = _content_of(soup.find("meta", attrs={"name": "keywords"}))
    if not has_keywords:
        tags = ""

    meta: dict[str, str | None] = {
        "title": title,
        "author": author,
        "publisher": publisher,
        "published_at": published_at,
        "canonical_url": canonical_url,
        "language": language,
        "description": description,
        "tags": tags,
    }
    return meta
def _extract_jsonld_date(soup: BeautifulSoup) -> str | None:
    """Return the first top-level ``datePublished`` found in JSON-LD scripts.

    Each ``<script type="application/ld+json">`` is parsed; a top-level
    object, or the objects of a top-level array, are checked for a
    ``datePublished`` key. Scripts that fail to parse are skipped.
    Returns ``None`` when no date is found.
    """
    for script in soup.find_all("script", type="application/ld+json"):
        payload = script.string
        # Cheap substring check before paying for a JSON parse.
        if not payload or "datePublished" not in payload:
            continue
        try:
            data = json.loads(payload)
        except (json.JSONDecodeError, TypeError):
            continue  # best-effort: malformed JSON-LD is ignored

        if isinstance(data, dict):
            nodes = [data]
        elif isinstance(data, list):
            nodes = data
        else:
            nodes = []
        for node in nodes:
            if isinstance(node, dict) and "datePublished" in node:
                return str(node["datePublished"])
    return None
def extract_outbound_links(html: str, base_url: str = "") -> list[str]:
    """Extract outbound links from HTML, filtering out self-references.

    Only absolute ``http``/``https`` links whose host differs from the
    host of ``base_url`` are returned. Fragment-only and ``javascript:``
    hrefs are ignored. Duplicates are removed, keeping first-seen order.

    Hrefs come from untrusted markup, so an href that ``urlparse`` rejects
    (for example an unbalanced IPv6 bracket) is skipped instead of
    aborting extraction for the whole document.

    Args:
        html: Raw HTML of the page.
        base_url: URL of the page itself; links to its host are excluded.

    Returns:
        De-duplicated outbound link URLs in document order.

    Requirements: 4.1
    """
    soup = BeautifulSoup(html, "html.parser")
    base_host = urlparse(base_url).hostname or "" if base_url else ""

    seen: set[str] = set()
    unique: list[str] = []
    for a_tag in soup.find_all("a", href=True):
        href = str(a_tag["href"]).strip()
        if not href or href.startswith(("#", "javascript:")):
            continue
        try:
            parsed = urlparse(href)
        except ValueError:
            # Malformed URL in the page; ignore just this link.
            continue
        # Only include absolute URLs that point to different hosts.
        if parsed.scheme not in ("http", "https") or not parsed.hostname:
            continue
        if parsed.hostname == base_host or href in seen:
            continue
        seen.add(href)
        unique.append(href)
    return unique
def _count_sentences(text: str) -> int:
"""Count approximate sentence count by terminal punctuation."""
return len(re.findall(r"[.!?]+(?:\s|$)", text))
def _count_paragraphs(text: str) -> int:
"""Count non-empty paragraph blocks separated by blank lines."""
blocks = re.split(r"\n\s*\n", text.strip())
return sum(1 for b in blocks if len(b.strip().split()) >= 5)
def score_parse_quality(
    text: str,
    *,
    body_found: bool = True,
    has_title: bool = False,
    has_author: bool = False,
    has_publisher: bool = False,
    has_published_at: bool = False,
) -> tuple[float, str, QualitySignals, list[str]]:
    """Score a parse from content and metadata signals.

    Returns ``(score, confidence_label, signals, warnings)``.

    Signals and their weight in the composite score:
    - word_count_signal (0.30): length of the extracted text
    - diversity_signal (0.15): unique words / total words
    - sentence_signal (0.15): number of sentence endings
    - paragraph_signal (0.10): number of substantial paragraphs
    - body_found_signal (0.20): whether an article body was located
    - metadata_signal (0.10): share of title/author/publisher/date present

    The score is capped at 0.95 and rounded to two decimals. The label is
    ``"low"`` below 0.35, ``"medium"`` below 0.65, otherwise ``"high"``.

    Requirements: 4.3
    """
    warnings: list[str] = []
    tokens = text.split()
    n_words = len(tokens)

    # --- word count signal ---
    if n_words >= 300:
        wc_sig = 1.0
    elif n_words >= 150:
        wc_sig = 0.8
    elif n_words >= 50:
        wc_sig = 0.6
    elif n_words >= 20:
        wc_sig = 0.3
        warnings.append("short_text")
    else:
        wc_sig = 0.1
        warnings.append("very_short_text")

    # --- diversity signal ---
    diversity = len({w.lower() for w in tokens}) / n_words if n_words > 0 else 0.0
    if diversity >= 0.4:
        div_sig = 1.0
    elif diversity >= 0.2:
        div_sig = 0.5
    else:
        div_sig = 0.2
        # Very short texts already carry a length warning.
        if n_words >= 20:
            warnings.append("low_vocabulary_diversity")

    # --- sentence signal ---
    n_sentences = _count_sentences(text)
    if n_sentences >= 3:
        sent_sig = 1.0
    elif n_sentences >= 1:
        sent_sig = 0.5
    else:
        sent_sig = 0.1
        if n_words >= 20:
            warnings.append("no_sentence_structure")

    # --- paragraph signal ---
    n_paragraphs = _count_paragraphs(text)
    if n_paragraphs >= 2:
        para_sig = 1.0
    elif n_paragraphs == 1:
        para_sig = 0.5
    else:
        para_sig = 0.2

    # --- body found signal ---
    if body_found:
        body_sig = 1.0
    else:
        body_sig = 0.3
        warnings.append("no_article_body_found")

    # --- metadata signal ---
    meta_sig = sum([has_title, has_author, has_publisher, has_published_at]) / 4.0

    signals = QualitySignals(
        word_count_signal=wc_sig,
        diversity_signal=div_sig,
        sentence_signal=sent_sig,
        paragraph_signal=para_sig,
        body_found_signal=body_sig,
        metadata_signal=meta_sig,
    )

    # Weighted composite score. The terms stay in this order so the
    # floating-point result is unchanged.
    composite = (
        0.30 * wc_sig
        + 0.15 * div_sig
        + 0.15 * sent_sig
        + 0.10 * para_sig
        + 0.20 * body_sig
        + 0.10 * meta_sig
    )
    composite = round(min(composite, 0.95), 2)

    # Confidence label
    if composite >= 0.65:
        label = "high"
    elif composite >= 0.35:
        label = "medium"
    else:
        label = "low"
    return composite, label, signals, warnings
def score_quality(text: str) -> tuple[float, str]:
    """Return ``(score, confidence_label)`` for extracted text.

    Backward-compatible thin wrapper: delegates to ``score_parse_quality``
    with its default signal arguments and drops the detailed signals and
    warnings. The label is one of low/medium/high.

    Requirements: 4.3
    """
    score, confidence, _signals, _warnings = score_parse_quality(text)
    summary = (score, confidence)
    return summary
def infer_document_type(html: str, url: str = "") -> str:
    """Classify a document from keywords in its URL.

    Rules are checked in order (filing, transcript, press_release) against
    the lowercased URL; the first rule with a matching keyword wins.
    Anything else is an ``"article"``. ``html`` is accepted but not
    inspected yet.

    Requirements: 4.1
    """
    # html reserved for future content-based inference
    _ = html

    rules = (
        ("filing", ("sec.gov", "edgar", "filing", "10-k", "10-q", "8-k")),
        ("transcript", ("transcript", "earnings-call", "earnings_call")),
        ("press_release", ("press-release", "press_release", "newsroom")),
    )
    lowered = url.lower()
    for label, keywords in rules:
        for keyword in keywords:
            if keyword in lowered:
                return label
    return "article"
def parse_html(html: str, url: str = "", aliases: list[dict[str, str]] | None = None) -> ParsedDocument:
    """Run the full HTML-to-text pipeline and return a ``ParsedDocument``.

    Covers body extraction and cleaning, metadata extraction, outbound
    link extraction, document type inference, quality scoring, and, when
    ``aliases`` is given and text was extracted, company mention
    detection over the title plus body.

    Requirements: 1.3, 4.1, 4.2, 4.3
    """
    soup = BeautifulSoup(html, "html.parser")
    _strip_boilerplate_tags(soup)

    article = _find_article_body(soup)
    body_found = article is not None
    source_node = article if article else (soup.find("body") or soup)
    text = source_node.get_text(separator="\n", strip=True)

    # Multi-stage text cleaning
    for cleaner in (
        _reduce_boilerplate_text,
        _remove_short_orphan_lines,
        _detect_repeated_blocks,
        _collapse_whitespace,
    ):
        text = cleaner(text)

    metadata = extract_metadata(html, url)
    outbound_links = extract_outbound_links(html, url)
    doc_type = infer_document_type(html, url)
    word_count = len(text.split())

    # Keywords arrive as one comma-separated string.
    tags_raw = metadata.get("tags", "") or ""
    tags: list[str] = []
    if tags_raw:
        for piece in tags_raw.split(","):
            if piece.strip():
                tags.append(piece.strip())

    # Rich quality scoring with all available signals
    quality, confidence, signals, warnings = score_parse_quality(
        text,
        body_found=body_found,
        has_title=bool(metadata.get("title")),
        has_author=bool(metadata.get("author")),
        has_publisher=bool(metadata.get("publisher")),
        has_published_at=bool(metadata.get("published_at")),
    )

    # Company mention detection over title + body
    mentioned: list[CompanyMention] = []
    if aliases and text:
        search_text = f"{metadata.get('title', '')} {text}"
        mentioned = [
            CompanyMention(
                company_id=str(hit["company_id"]),
                ticker=str(hit["ticker"]),
                mention_type=str(hit["mention_type"]),
                confidence=float(hit["confidence"]),
                match_count=int(hit["match_count"]),
            )
            for hit in detect_company_mentions(search_text, aliases)
        ]

    return ParsedDocument(
        body_text=text,
        title=metadata.get("title", "") or "",
        author=metadata.get("author", "") or "",
        publisher=metadata.get("publisher", "") or "",
        published_at=metadata.get("published_at"),
        canonical_url=metadata.get("canonical_url"),
        language=metadata.get("language", "en") or "en",
        description=metadata.get("description", "") or "",
        document_type=doc_type,
        outbound_links=outbound_links,
        tags=tags,
        mentioned_companies=mentioned,
        quality_score=quality,
        confidence=confidence,
        word_count=word_count,
        quality_signals=signals,
        low_quality_flag=confidence == "low",
        quality_warnings=warnings,
    )
@dataclass
class AliasEntry:
    """A company alias used for mention detection.

    Typed form of one raw alias dict, built by ``_build_alias_entries``
    and consumed by ``detect_company_mentions``.
    """

    # Identifier of the company this alias belongs to; mentions are
    # grouped by this value.
    company_id: str
    # The text searched for in the document.
    alias: str
    # Kind of alias; looked up in _CONFIDENCE_BY_TYPE to pick the
    # confidence of a match.
    alias_type: str = "alias"
    # Ticker symbol copied into the resulting mention record.
    ticker: str = ""
# Confidence by alias type — tickers are most precise, brands least.
# detect_company_mentions falls back to 0.7 for alias types not listed here.
_CONFIDENCE_BY_TYPE: dict[str, float] = {
    "ticker": 0.9,
    "legal_name": 0.85,
    "alias": 0.7,
    "brand": 0.6,
}
def _build_alias_entries(aliases: list[dict[str, str]]) -> list[AliasEntry]:
    """Turn raw alias dicts into ``AliasEntry`` objects.

    Dicts with a missing or empty ``alias`` are skipped. Missing
    ``company_id`` and ``ticker`` default to ``""``, and a missing
    ``alias_type`` defaults to ``"alias"``.
    """
    return [
        AliasEntry(
            company_id=raw.get("company_id", ""),
            alias=name,
            alias_type=raw.get("alias_type", "alias"),
            ticker=raw.get("ticker", ""),
        )
        for raw in aliases
        if (name := raw.get("alias", ""))
    ]
def _count_matches(text: str, pattern: re.Pattern[str]) -> int:
"""Count non-overlapping matches of pattern in text."""
return len(pattern.findall(text))
def detect_company_mentions(
    text: str,
    aliases: list[dict[str, str]],
) -> list[dict[str, str | float | int]]:
    """Find which companies a text mentions, using their aliases.

    The match strategy depends on the alias length:
    - 1-2 chars: case-sensitive word-boundary match (avoids "A" matching "a")
    - 3-4 chars: case-insensitive word-boundary match (standard tickers)
    - 5+ chars: case-insensitive substring match (company names, brands)

    Confidence comes from the alias type (ticker > legal_name > alias >
    brand). When several aliases of one company match, a single record is
    returned for it: the highest-confidence alias supplies the confidence,
    ticker and mention type, and the match counts are summed.

    Requirements: 1.3, 4.1
    """
    if not text:
        return []

    upper_text = text.upper()
    # company_id -> (confidence, ticker, mention_type, total match count)
    best: dict[str, tuple[float, str, str, int]] = {}

    for entry in _build_alias_entries(aliases):
        name = entry.alias
        confidence = _CONFIDENCE_BY_TYPE.get(entry.alias_type, 0.7)

        if len(name) > 4:
            # Longer names: case-insensitive substring
            hits = upper_text.count(name.upper())
        elif len(name) > 2:
            # Standard ticker length: case-insensitive word boundary
            hits = _count_matches(
                upper_text, re.compile(r"\b" + re.escape(name.upper()) + r"\b")
            )
        else:
            # Very short: case-sensitive word boundary
            hits = _count_matches(text, re.compile(r"\b" + re.escape(name) + r"\b"))

        if hits == 0:
            continue

        company = entry.company_id
        previous = best.get(company)
        if previous is None:
            best[company] = (confidence, entry.ticker, entry.alias_type, hits)
            continue

        # Keep the highest-confidence alias, accumulate the match count.
        total_hits = previous[3] + hits
        if confidence > previous[0]:
            best[company] = (confidence, entry.ticker, entry.alias_type, total_hits)
        else:
            best[company] = (previous[0], previous[1], previous[2], total_hits)

    return [
        {
            "company_id": company,
            "ticker": ticker,
            "mention_type": mention_type,
            "confidence": confidence,
            "match_count": count,
        }
        for company, (confidence, ticker, mention_type, count) in best.items()
    ]
+108 -107
View File
@@ -1,84 +1,41 @@
"""Parser worker - HTML-to-text, boilerplate reduction, quality scoring."""
"""Parser worker - HTML-to-text, boilerplate reduction, quality scoring.
Uses BeautifulSoup-based parsing pipeline for structured HTML extraction,
metadata extraction, outbound link extraction, and quality scoring.
Persists normalized text and structured parser output to MinIO,
and updates document metadata in PostgreSQL.
Requirements: 4.1, 4.2, 4.3, 9.1, 9.2
"""
import asyncio
import io
import json
import logging
import re
from datetime import datetime
from typing import List, Optional, Tuple
import time
from datetime import datetime, timezone
from typing import Any, Optional
import asyncpg
import httpx
import redis.asyncio as aioredis
from minio import Minio
from services.parser.html_parser import ParsedDocument, detect_company_mentions, parse_html
from services.shared.config import load_config
from services.shared.db import get_minio, get_pg_pool, get_redis
from services.shared.logging import Span, extract_trace_context, inject_trace_context, new_trace_id, set_trace_context, setup_logging
from services.shared.metrics import (
ACTIVE_JOBS,
PARSE_DURATION,
PARSE_JOBS_TOTAL,
PARSE_LOW_QUALITY_TOTAL,
PARSE_QUALITY_SCORE,
)
from services.shared.metadata import update_document_parse_results
from services.shared.redis_keys import QUEUE_EXTRACTION, QUEUE_PARSING, queue_key
from services.shared.storage import upload_normalized_text, upload_parser_output
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("parser_worker")
# Simple boilerplate patterns to strip.
# Each pattern is removed from the text by reduce_boilerplate. Most of them
# delete from the trigger phrase to the end of that line.
# NOTE(review): the last pattern matches any bracketed span containing "ad"
# (case-insensitive), e.g. "[Read more]" — confirm this breadth is intended.
BOILERPLATE_PATTERNS = [
    re.compile(r"(?i)subscribe to our newsletter.*?(?:\n|$)"),
    re.compile(r"(?i)click here to read more.*?(?:\n|$)"),
    re.compile(r"(?i)advertisement\s*\n"),
    re.compile(r"(?i)copyright ©.*?(?:\n|$)"),
    re.compile(r"(?i)all rights reserved.*?(?:\n|$)"),
    re.compile(r"(?i)terms of (use|service).*?(?:\n|$)"),
    re.compile(r"(?i)privacy policy.*?(?:\n|$)"),
    re.compile(r"\s*\[.*?ad.*?\]\s*", re.IGNORECASE),
]
def strip_html_tags(html: str) -> str:
    """Remove HTML markup and decode a few common entities.

    ``<script>`` and ``<style>`` elements are dropped with their contents,
    remaining tags become spaces, ``&nbsp;``, ``&lt;``, ``&gt;`` and
    ``&amp;`` are decoded, numeric character references are removed, and
    runs of whitespace are collapsed to single spaces.

    Args:
        html: Raw HTML markup.

    Returns:
        Plain text with surrounding whitespace stripped.
    """
    text = re.sub(r"<script[^>]*>.*?</script>", "", html, flags=re.DOTALL | re.IGNORECASE)
    text = re.sub(r"<style[^>]*>.*?</style>", "", text, flags=re.DOTALL | re.IGNORECASE)
    text = re.sub(r"<[^>]+>", " ", text)
    text = re.sub(r"&nbsp;", " ", text)
    text = re.sub(r"&lt;", "<", text)
    text = re.sub(r"&gt;", ">", text)
    text = re.sub(r"&#\d+;", "", text)
    # Decode &amp; last. Doing it earlier would turn escaped text such as
    # "&amp;lt;" into "&lt;" and then decode it a second time into "<".
    text = re.sub(r"&amp;", "&", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text
def reduce_boilerplate(text: str) -> str:
    """Delete every ``BOILERPLATE_PATTERNS`` match from ``text`` and strip the result."""
    cleaned = text
    for boilerplate in BOILERPLATE_PATTERNS:
        cleaned = boilerplate.sub("", cleaned)
    return cleaned.strip()
def score_quality(text: str) -> Tuple[float, str]:
    """Score parse quality from the word count alone.

    Returns ``(score, confidence_label)``: under 20 words is
    ``(0.1, "low")``, under 50 is ``(0.3, "low")``, under 150 is
    ``(0.6, "medium")``, anything longer is ``(0.85, "high")``.
    """
    n_words = len(text.split())
    # (exclusive upper bound on word count, score, label), smallest first
    tiers = (
        (20, 0.1, "low"),
        (50, 0.3, "low"),
        (150, 0.6, "medium"),
    )
    for limit, score, label in tiers:
        if n_words < limit:
            return score, label
    return 0.85, "high"
def detect_company_mentions(text: str, aliases: List[dict]) -> List[dict]:
    """Return one mention record per alias found in ``text``.

    Matching is a case-insensitive substring test. Every record carries a
    fixed confidence of 0.7; ``ticker`` defaults to ``""`` and
    ``mention_type`` to ``"alias"`` when the alias dict lacks them.
    """
    haystack = text.upper()
    return [
        {
            "company_id": info["company_id"],
            "ticker": info.get("ticker", ""),
            "mention_type": info.get("alias_type", "alias"),
            "confidence": 0.7,
        }
        for info in aliases
        if info["alias"].upper() in haystack
    ]
async def fetch_html(url: str) -> Optional[str]:
"""Fetch article HTML for scraping."""
@@ -94,48 +51,65 @@ async def fetch_html(url: str) -> Optional[str]:
return None
def build_parser_output_json(parsed: ParsedDocument, mentions: list[dict[str, Any]]) -> dict[str, Any]:
    """Assemble the JSON-ready dict describing one parse result.

    The dict carries the document metadata, word count, outbound links,
    tags, quality score/label/flag/warnings, the quality signals (via
    ``as_dict()``) and the detected company mentions. The body text
    itself is not included.
    """
    # These keys are copied straight from same-named ParsedDocument attributes.
    copied_fields = (
        "title",
        "author",
        "publisher",
        "published_at",
        "canonical_url",
        "language",
        "description",
        "document_type",
        "word_count",
        "outbound_links",
        "tags",
        "quality_score",
        "confidence",
        "low_quality_flag",
        "quality_warnings",
    )
    payload: dict[str, Any] = {name: getattr(parsed, name) for name in copied_fields}
    payload["quality_signals"] = parsed.quality_signals.as_dict()
    payload["mentioned_companies"] = mentions
    return payload
async def process_job(
job: dict,
job: dict[str, Any],
pool: asyncpg.Pool,
rds: aioredis.Redis,
minio_client: Minio,
):
) -> None:
doc_id = job["document_id"]
ticker = job["ticker"]
url = job.get("url", "")
now = datetime.now(timezone.utc)
_parse_start = time.monotonic()
set_trace_context(trace_id=job.get("_trace_id") or new_trace_id())
# Fetch HTML if we have a URL
html = await fetch_html(url) if url else None
if html:
# Store raw HTML
html_bytes = html.encode("utf-8")
now = datetime.utcnow()
html_path = f"scrape/{ticker}/{now.year}/{now.month:02d}/{now.day:02d}/{doc_id}/raw.html"
minio_client.put_object(
"stonks-raw-news", html_path, io.BytesIO(html_bytes), len(html_bytes),
content_type="text/html",
)
# Parse
text = strip_html_tags(html)
text = reduce_boilerplate(text)
# Parse using BeautifulSoup pipeline
parsed = parse_html(html, url)
else:
text = ""
parsed = ParsedDocument()
quality_score, confidence = score_quality(text)
text = parsed.body_text
# Store normalized text
# Upload normalized text to MinIO
norm_ref: str | None = None
if text:
text_bytes = text.encode("utf-8")
now = datetime.utcnow()
norm_path = f"parsed/{ticker}/{now.year}/{now.month:02d}/{now.day:02d}/{doc_id}/normalized.txt"
minio_client.put_object(
"stonks-normalized", norm_path, io.BytesIO(text_bytes), len(text_bytes),
content_type="text/plain",
norm_ref = upload_normalized_text(
minio_client, ticker, doc_id,
text.encode("utf-8"), timestamp=now,
)
else:
norm_path = None
# Detect company mentions
aliases = await pool.fetch(
@@ -150,14 +124,24 @@ async def process_job(
)
mentions = detect_company_mentions(text, [dict(a) for a in aliases]) if text else []
# Update document
status = "parsed" if confidence != "low" else "low_quality"
await pool.execute(
"""UPDATE documents SET
normalized_storage_ref=$2, parse_quality_score=$3, parse_confidence=$4, status=$5, updated_at=NOW()
WHERE id=$1""",
doc_id, f"s3://stonks-normalized/{norm_path}" if norm_path else None,
quality_score, confidence, status,
# Build and upload structured parser output JSON
output_json = build_parser_output_json(parsed, mentions)
output_bytes = json.dumps(output_json, default=str, indent=2).encode("utf-8")
parser_output_ref = upload_parser_output(
minio_client, ticker, doc_id,
output_bytes, timestamp=now,
)
# Update document in PostgreSQL
status = "parsed" if parsed.confidence != "low" else "low_quality"
await update_document_parse_results(
pool,
document_id=doc_id,
normalized_storage_ref=norm_ref,
parser_output_ref=parser_output_ref,
parse_quality_score=parsed.quality_score,
parse_confidence=parsed.confidence,
status=status,
)
# Insert company mentions
@@ -169,19 +153,36 @@ async def process_job(
)
# Only enqueue for extraction if quality is acceptable
if confidence != "low":
await rds.rpush(queue_key(QUEUE_EXTRACTION), json.dumps({
if parsed.confidence != "low":
await rds.rpush(queue_key(QUEUE_EXTRACTION), json.dumps(inject_trace_context({
"document_id": doc_id,
"ticker": ticker,
"normalized_text": text[:8000], # Truncate for prompt
}))
logger.info(f"Parsed doc {doc_id} for {ticker}: quality={quality_score:.2f}, confidence={confidence}")
"normalized_text": text[:8000],
})))
PARSE_JOBS_TOTAL.labels(status="parsed").inc()
PARSE_QUALITY_SCORE.observe(parsed.quality_score)
PARSE_DURATION.observe(time.monotonic() - _parse_start)
logger.info(
"Parsed doc %s for %s: quality=%.2f, confidence=%s",
doc_id, ticker, parsed.quality_score, parsed.confidence,
extra={"ticker": ticker, "document_id": doc_id},
)
else:
logger.warning(f"Low quality parse for doc {doc_id}, skipping extraction")
PARSE_JOBS_TOTAL.labels(status="low_quality").inc()
PARSE_LOW_QUALITY_TOTAL.inc()
PARSE_QUALITY_SCORE.observe(parsed.quality_score)
PARSE_DURATION.observe(time.monotonic() - _parse_start)
logger.warning(
"Low quality parse for doc %s, skipping extraction",
doc_id,
extra={"ticker": ticker, "document_id": doc_id},
)
async def main():
async def main() -> None:
config = load_config()
setup_logging("parser_worker", level=config.log_level, json_output=config.json_logs)
pool = await get_pg_pool(config)
rds = get_redis(config)
minio_client = get_minio(config)
@@ -197,7 +198,7 @@ async def main():
try:
await process_job(job, pool, rds, minio_client)
except Exception as e:
logger.error(f"Parse error: {e}")
logger.error("Parse error: %s", e, exc_info=True)
else:
await asyncio.sleep(2)
finally:
+354
View File
@@ -0,0 +1,354 @@
"""Deterministic recommendation eligibility logic.
Evaluates trend summaries against configurable thresholds to decide:
- Whether a recommendation should be generated at all
- What action type (buy/sell/hold/watch) is appropriate
- What execution mode (informational/paper_eligible/live_eligible) is allowed
- Position sizing guidance based on portfolio rules
All decisions are rule-based with no model involvement. The LLM is only
used downstream for optional thesis wording (a separate task).
Requirements: 7.1, 7.2, 7.3, 7.4
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from services.shared.schemas import (
ActionType,
PositionSizing,
RecommendationMode,
TrendDirection,
TrendSummary,
)
class RejectionReason(str, Enum):
    """Why a trend summary was deemed ineligible for a recommendation.

    Each member corresponds to one gate check in ``_check_gates``; a
    summary can collect several reasons at once.
    """

    # summary.confidence is below config.min_confidence
    LOW_CONFIDENCE = "low_confidence"
    # summary.trend_strength is below config.min_trend_strength
    LOW_TREND_STRENGTH = "low_trend_strength"
    # summary.contradiction_score exceeds config.max_contradiction_score
    HIGH_CONTRADICTION = "high_contradiction"
    # supporting + opposing evidence items are fewer than config.min_evidence_count
    INSUFFICIENT_EVIDENCE = "insufficient_evidence"
    # summary.trend_direction is TrendDirection.NEUTRAL
    NEUTRAL_DIRECTION = "neutral_direction"
@dataclass(frozen=True)
class EligibilityConfig:
    """Tunable thresholds for recommendation eligibility.

    All thresholds are deterministic — no model inference involved.
    Instances are frozen, so a config cannot be mutated after creation.
    """

    # --- Gate thresholds (failing any of these → no recommendation) ---
    # Confidence below this is rejected.
    min_confidence: float = 0.35
    # Trend strength below this is rejected.
    min_trend_strength: float = 0.10
    # Contradiction score above this is rejected.
    max_contradiction_score: float = 0.60
    # Minimum number of evidence items, supporting + opposing combined.
    min_evidence_count: int = 2

    # --- Action mapping thresholds ---
    # Trend strength at or above this → buy/sell; below → hold/watch
    action_strength_threshold: float = 0.25
    # Confidence at or above this → hold (rather than watch) for weak signals
    hold_confidence_threshold: float = 0.50

    # --- Mode escalation thresholds ---
    # Confidence at or above this allows paper_eligible (below → informational)
    paper_confidence_threshold: float = 0.50
    # Confidence at or above this is required for live_eligible (below → paper at most)
    live_confidence_threshold: float = 0.70
    # Contradiction must be at or below this for live eligibility
    live_max_contradiction: float = 0.25
    # Minimum evidence count (supporting + opposing) for live eligibility
    live_min_evidence: int = 5

    # --- Position sizing rules (Requirement 7.3) ---
    # Base portfolio allocation percentage; the sizing floor is half of this
    base_portfolio_pct: float = 0.02
    # Maximum portfolio allocation percentage
    max_portfolio_pct: float = 0.05
    # Base max loss percentage; the sizing floor is half of this
    base_max_loss_pct: float = 0.005
    # Maximum max loss percentage
    max_max_loss_pct: float = 0.01
    # Confidence scaling: higher confidence → larger position (linear)
    confidence_sizing_weight: float = 0.5
    # Contradiction penalty: higher contradiction → smaller position
    contradiction_sizing_penalty: float = 0.3
DEFAULT_ELIGIBILITY_CONFIG = EligibilityConfig()
@dataclass
class EligibilityResult:
    """Output of the deterministic eligibility evaluation.

    Captures the decision, the reasoning, and all inputs used so the
    full decision trace is reproducible (Requirement 8.3).
    """

    # Whether a recommendation should be generated at all.
    eligible: bool
    # Action chosen for the recommendation (buy/sell/hold/watch).
    action: ActionType
    # Highest execution mode allowed (informational/paper/live eligible).
    mode: RecommendationMode
    # Position sizing guidance (portfolio and max-loss percentages).
    position_sizing: PositionSizing
    # Gate checks that failed; presumably empty when eligible — TODO confirm
    # against the evaluation function that builds this result.
    rejection_reasons: list[RejectionReason] = field(default_factory=list)
    # Time horizon label for the recommendation; defaults to "".
    time_horizon: str = ""
    # Conditions under which the recommendation should be considered invalid.
    invalidation_conditions: list[str] = field(default_factory=list)
# ---------------------------------------------------------------------------
# Gate checks
# ---------------------------------------------------------------------------
def _check_gates(
    summary: TrendSummary,
    config: EligibilityConfig,
) -> list[RejectionReason]:
    """Run every hard gate and collect the ones that fail.

    Returns the rejection reasons in a fixed order (confidence, trend
    strength, contradiction, evidence count, neutral direction). An empty
    list means the summary passed all gates.
    """
    gate_outcomes = (
        (summary.confidence < config.min_confidence, RejectionReason.LOW_CONFIDENCE),
        (summary.trend_strength < config.min_trend_strength, RejectionReason.LOW_TREND_STRENGTH),
        (summary.contradiction_score > config.max_contradiction_score, RejectionReason.HIGH_CONTRADICTION),
        (
            # Evidence is counted across both supporting and opposing items.
            len(summary.top_supporting_evidence) + len(summary.top_opposing_evidence)
            < config.min_evidence_count,
            RejectionReason.INSUFFICIENT_EVIDENCE,
        ),
        (summary.trend_direction == TrendDirection.NEUTRAL, RejectionReason.NEUTRAL_DIRECTION),
    )
    return [reason for failed, reason in gate_outcomes if failed]
# ---------------------------------------------------------------------------
# Action mapping
# ---------------------------------------------------------------------------
def _determine_action(
    summary: TrendSummary,
    config: EligibilityConfig,
) -> ActionType:
    """Translate trend direction and strength into an action.

    Bullish maps to BUY and bearish to SELL when the trend strength
    reaches ``action_strength_threshold``. A weaker directional trend is
    HOLD if confidence reaches ``hold_confidence_threshold``, otherwise
    WATCH. Mixed, neutral or any other direction is WATCH.
    """
    direction = summary.trend_direction

    if direction == TrendDirection.MIXED or direction == TrendDirection.NEUTRAL:
        return ActionType.WATCH

    if direction == TrendDirection.BULLISH:
        decisive_action = ActionType.BUY
    elif direction == TrendDirection.BEARISH:
        decisive_action = ActionType.SELL
    else:
        return ActionType.WATCH

    if summary.trend_strength >= config.action_strength_threshold:
        return decisive_action
    # Weak but directional: hold only when confidence is decent.
    if summary.confidence >= config.hold_confidence_threshold:
        return ActionType.HOLD
    return ActionType.WATCH
# ---------------------------------------------------------------------------
# Mode escalation
# ---------------------------------------------------------------------------
def _determine_mode(
    summary: TrendSummary,
    action: ActionType,
    config: EligibilityConfig,
) -> RecommendationMode:
    """Pick the highest execution mode the signal qualifies for.

    WATCH and HOLD never trigger trades, so they are always
    informational. BUY/SELL is live-eligible when the confidence,
    contradiction and evidence-count thresholds for live trading are all
    met, paper-eligible when only the paper confidence threshold is met,
    and informational otherwise.
    """
    if action in (ActionType.WATCH, ActionType.HOLD):
        return RecommendationMode.INFORMATIONAL

    evidence_total = len(summary.top_supporting_evidence) + len(summary.top_opposing_evidence)

    # Live eligibility is the strictest tier, so it is checked first.
    qualifies_for_live = (
        summary.confidence >= config.live_confidence_threshold
        and summary.contradiction_score <= config.live_max_contradiction
        and evidence_total >= config.live_min_evidence
    )
    if qualifies_for_live:
        return RecommendationMode.LIVE_ELIGIBLE

    if summary.confidence >= config.paper_confidence_threshold:
        return RecommendationMode.PAPER_ELIGIBLE
    return RecommendationMode.INFORMATIONAL
# ---------------------------------------------------------------------------
# Position sizing (Requirement 7.3)
# ---------------------------------------------------------------------------
def _compute_position_sizing(
    summary: TrendSummary,
    config: EligibilityConfig,
) -> PositionSizing:
    """Derive portfolio allocation and max-loss guidance from signal quality.

    Both figures start at their configured base, grow with confidence
    towards their configured maximum, are scaled down by the contradiction
    penalty, and are finally clamped to ``[base * 0.5, maximum]``.
    """
    # Multiplicative penalty shared by both figures.
    dampening = 1.0 - config.contradiction_sizing_penalty * summary.contradiction_score

    def _scaled(base: float, ceiling: float) -> float:
        """Scale *base* towards *ceiling* by confidence, penalise, then clamp."""
        raw = base + (
            config.confidence_sizing_weight
            * summary.confidence
            * (ceiling - base)
        )
        return max(base * 0.5, min(raw * dampening, ceiling))

    allocation_pct = _scaled(config.base_portfolio_pct, config.max_portfolio_pct)
    loss_pct = _scaled(config.base_max_loss_pct, config.max_max_loss_pct)

    return PositionSizing(
        portfolio_pct=round(allocation_pct, 6),
        max_loss_pct=round(loss_pct, 6),
    )
# ---------------------------------------------------------------------------
# Time horizon mapping
# ---------------------------------------------------------------------------
# Known trend windows and the horizon label reported for each of them.
_WINDOW_TO_HORIZON: dict[str, str] = {
    "intraday": "intraday",
    "1d": "swing_1d_3d",
    "7d": "swing_1d_10d",
    "30d": "position_10d_30d",
    "90d": "position_30d_90d",
}


def _map_time_horizon(window: str) -> str:
    """Return the horizon label for *window*, or ``window_<window>`` if unknown."""
    try:
        return _WINDOW_TO_HORIZON[window]
    except KeyError:
        return f"window_{window}"
# ---------------------------------------------------------------------------
# Invalidation conditions
# ---------------------------------------------------------------------------
def _derive_invalidation_conditions(
    summary: TrendSummary,
    action: ActionType,
) -> list[str]:
    """List the conditions under which this recommendation stops being valid.

    The output is deterministic: a trend-reversal condition for BUY/SELL,
    then a contradiction, a confidence and a material-risk condition, each
    added only when the corresponding summary field is non-zero / non-empty.
    """
    triggers: list[str] = []

    # Directional actions are invalidated by the opposite trend.
    if action == ActionType.BUY:
        triggers.append(
            f"Trend direction for {summary.entity_id} reverses to bearish"
        )
    elif action == ActionType.SELL:
        triggers.append(
            f"Trend direction for {summary.entity_id} reverses to bullish"
        )

    contradiction = summary.contradiction_score
    if contradiction > 0.0:
        triggers.append(
            f"Contradiction score exceeds 0.60 (currently {contradiction:.2f})"
        )

    confidence = summary.confidence
    if confidence > 0.0:
        # The floor is 70% of the confidence at generation time.
        triggers.append(
            f"Confidence drops below {confidence * 0.7:.2f}"
        )

    risks = summary.material_risks
    if risks:
        # Only the first listed risk is surfaced.
        triggers.append(
            f"Material risk materialises: {risks[0]}"
        )

    return triggers
# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------
def evaluate_eligibility(
    summary: TrendSummary,
    config: EligibilityConfig = DEFAULT_ELIGIBILITY_CONFIG,
) -> EligibilityResult:
    """Run the full deterministic eligibility decision for a trend summary.

    Steps, in order:
      1. Gate checks produce the list of rejection reasons.
      2. Trend direction and strength are mapped to an action.
      3. The highest execution mode allowed for that action is determined.
      4. Position sizing is computed.
      5. The time horizon and invalidation conditions are derived.

    Action, mode, sizing and conditions are computed even for rejected
    summaries so the returned EligibilityResult carries a full trace; a
    rejected summary always ends up in informational mode.
    """
    rejections = _check_gates(summary, config)

    chosen_action = _determine_action(summary, config)
    candidate_mode = _determine_mode(summary, chosen_action, config)
    sizing = _compute_position_sizing(summary, config)
    horizon = _map_time_horizon(summary.window.value)
    conditions = _derive_invalidation_conditions(summary, chosen_action)

    is_eligible = len(rejections) == 0
    # Requirement 7.4: anything that failed a gate is informational only.
    final_mode = candidate_mode if is_eligible else RecommendationMode.INFORMATIONAL

    return EligibilityResult(
        eligible=is_eligible,
        action=chosen_action,
        mode=final_mode,
        position_sizing=sizing,
        rejection_reasons=rejections,
        time_horizon=horizon,
        invalidation_conditions=conditions,
    )
+71
View File
@@ -0,0 +1,71 @@
"""Recommendation worker entrypoint - polls Redis for recommendation jobs."""
from __future__ import annotations
import asyncio
import json
import logging
import asyncpg
from minio import Minio
from services.recommendation.worker import generate_recommendation
from services.shared.config import load_config
from services.shared.logging import setup_logging
from services.shared.redis_keys import QUEUE_RECOMMENDATION, queue_key
logger = logging.getLogger("recommendation_main")
async def main() -> None:
    """Run the recommendation worker: poll the Redis queue and process jobs.

    Each queue entry is expected to be a JSON object with an optional
    ``ticker`` (default ``""``) and ``window`` (default ``"7d"``).
    Entries that are not valid JSON objects are logged and discarded so a
    single bad payload cannot take the worker down. Failures inside
    ``generate_recommendation`` are logged and the loop carries on.

    The PostgreSQL pool and the Redis client are both closed when the loop
    exits, whatever the reason.
    """
    config = load_config()
    setup_logging("recommendation", level=config.log_level, json_output=config.json_logs)

    pool = await asyncpg.create_pool(dsn=config.postgres.dsn, min_size=2, max_size=8)
    minio_client = Minio(
        config.minio.endpoint,
        access_key=config.minio.access_key,
        secret_key=config.minio.secret_key,
        secure=config.minio.secure,
    )

    import redis.asyncio as aioredis

    redis_client = aioredis.from_url(config.redis.url)
    queue = queue_key(QUEUE_RECOMMENDATION)
    logger.info("Recommendation worker started, polling %s", queue)

    try:
        while True:
            raw = await redis_client.lpop(queue)
            if raw is None:
                # Queue is empty: back off briefly before polling again.
                await asyncio.sleep(1)
                continue

            # Decode defensively. JSONDecodeError and UnicodeDecodeError are
            # both ValueError subclasses.
            try:
                job = json.loads(raw)
            except ValueError:
                logger.error("Discarding malformed recommendation job: %r", raw)
                continue
            if not isinstance(job, dict):
                logger.error("Discarding non-object recommendation job: %r", raw)
                continue

            ticker = job.get("ticker", "")
            window = job.get("window", "7d")
            logger.info("Processing recommendation job for %s/%s", ticker, window)
            try:
                rec = await generate_recommendation(
                    pool, ticker, window,
                    minio_client=minio_client,
                )
                if rec:
                    logger.info(
                        "Recommendation generated for %s: %s %s",
                        ticker, rec.action.value, rec.mode.value,
                    )
                else:
                    logger.info("No recommendation generated for %s (no trend data)", ticker)
            except Exception:
                # Top-level job boundary: log with traceback and keep polling.
                logger.exception("Recommendation failed for %s", ticker)
    finally:
        # Nested so the Redis client is closed even if closing the pool fails.
        try:
            await pool.close()
        finally:
            await redis_client.close()


if __name__ == "__main__":
    asyncio.run(main())
+241
View File
@@ -0,0 +1,241 @@
"""Suppression logic for low-quality data or low confidence.
Evaluates the quality of the underlying data feeding a trend summary
and suppresses automated trade eligibility when data quality is poor.
Suppressed recommendations are marked as informational only.
This layer runs *before* the eligibility engine and acts as a pre-filter
on data quality. The eligibility engine handles signal-level thresholds
(confidence, strength, contradiction); this module handles data-level
quality concerns (stale evidence, low extraction quality, poor source
diversity, insufficient valid documents).
Requirements: 7.4
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from services.shared.schemas import TrendSummary
logger = logging.getLogger(__name__)
class SuppressionReason(str, Enum):
    """Data-quality reasons for which a recommendation can be suppressed.

    Members are string-valued, so ``.value`` can be logged or serialised
    directly.
    """

    # Average extraction confidence, or the overall quality score, is too low.
    LOW_DATA_CONFIDENCE = "low_data_confidence"
    # Newest evidence is older than allowed, or has no timestamp at all.
    STALE_EVIDENCE = "stale_evidence"
    # Too few distinct source types back the evidence.
    LOW_SOURCE_DIVERSITY = "low_source_diversity"
    # Too large a share of the documents failed extraction.
    HIGH_EXTRACTION_FAILURE_RATE = "high_extraction_failure_rate"
    # Not enough valid documents contributed to the trend.
    INSUFFICIENT_VALID_DOCUMENTS = "insufficient_valid_documents"
@dataclass(frozen=True)
class SuppressionConfig:
    """Immutable thresholds used by the data-quality suppression checks.

    These describe the quality of the *input data*; thresholds on the
    trend signal itself belong to EligibilityConfig.
    """

    # Floor for the mean extraction confidence of the evidence documents.
    # Anything lower is considered too unreliable to trade on.
    min_avg_extraction_confidence: float = 0.40

    # Oldest acceptable age, in hours, of the newest evidence document.
    # Older than this and the trend counts as stale.
    max_evidence_staleness_hours: float = 7 * 24.0  # 7 days

    # Fewest distinct source types (e.g. news, filings, market) that must
    # appear in the evidence, so that a single source class cannot drive
    # the signal on its own.
    min_source_types: int = 1

    # Largest tolerated fraction (0-1) of documents whose extraction failed.
    max_extraction_failure_rate: float = 0.50

    # Fewest valid (non-failed) documents that must back the trend.
    min_valid_documents: int = 2

    # Floor for the overall data quality score; a lower score suppresses
    # the recommendation.
    min_data_quality_score: float = 0.30


# Shared default thresholds, used when a caller does not pass a config.
DEFAULT_SUPPRESSION_CONFIG = SuppressionConfig()
@dataclass
class DataQualityContext:
    """Quality metrics for the documents behind a trend summary.

    Normally filled from document and extraction metadata for a ticker and
    window. When that is unavailable, a coarse version can be derived from
    the trend summary itself.
    """

    # Document counts: everything seen, those usable, those that failed.
    total_documents: int = 0
    valid_documents: int = 0
    failed_documents: int = 0
    # Mean extraction confidence of the evidence documents.
    avg_extraction_confidence: float = 0.0
    # Timestamp of the freshest evidence document, if known.
    newest_evidence_at: datetime | None = None
    # Distinct source types represented in the evidence.
    source_types: set[str] = field(default_factory=set)
@dataclass
class SuppressionResult:
    """Outcome of a data-quality suppression evaluation."""

    # True when at least one suppression reason applies.
    suppressed: bool
    # Every reason that fired, in the order the checks ran.
    reasons: list[SuppressionReason] = field(default_factory=list)
    # Overall data quality score computed for the evaluated context.
    data_quality_score: float = 0.0
    # The quality context the decision was based on.
    context: DataQualityContext | None = None
def build_quality_context_from_summary(
    summary: TrendSummary,
) -> DataQualityContext:
    """Approximate a DataQualityContext using only the trend summary.

    Fallback for when document-level metrics are unavailable. Every cited
    evidence item is counted as a valid document, the summary confidence
    stands in for extraction confidence, and the summary's generation time
    stands in for the newest evidence timestamp.
    """
    evidence_total = sum(
        len(refs)
        for refs in (summary.top_supporting_evidence, summary.top_opposing_evidence)
    )
    # NOTE(review): source_types is left empty here, so the source-diversity
    # check in evaluate_suppression fires for any fallback context that has
    # documents when min_source_types >= 1. Confirm this is intended.
    return DataQualityContext(
        total_documents=evidence_total,
        valid_documents=evidence_total,
        failed_documents=0,
        avg_extraction_confidence=summary.confidence,
        newest_evidence_at=summary.generated_at,
        source_types=set(),
    )
def _compute_data_quality_score(
    ctx: DataQualityContext,
    config: SuppressionConfig,
    reference_time: datetime,
) -> float:
    """Compute an overall data quality score in ``[0, 1]`` (higher is better).

    Weighted components:
      - extraction confidence (40%): ``avg_extraction_confidence / 0.8``,
        capped at 1
      - evidence freshness (30%): falls linearly from 1 (brand new) to 0 at
        ``config.max_evidence_staleness_hours``
      - document coverage (30%): valid/total ratio multiplied by
        ``valid_documents / 10`` capped at 1

    Naive datetimes (evidence or reference) are interpreted as UTC.
    Evidence dated after ``reference_time`` counts as perfectly fresh and
    earns no extra credit. A non-positive staleness window treats any
    evidence older than ``reference_time`` as fully stale.

    Args:
        ctx: Quality metrics for the underlying documents.
        config: Supplies ``max_evidence_staleness_hours``.
        reference_time: The "now" used to age the newest evidence.

    Returns:
        The score, rounded to four decimals.
    """
    # Extraction confidence: 0.8 or better earns full credit.
    conf_component = min(ctx.avg_extraction_confidence / 0.8, 1.0)

    # Freshness component
    if ctx.newest_evidence_at is None:
        freshness_component = 0.0
    else:
        newest = ctx.newest_evidence_at
        if newest.tzinfo is None:
            newest = newest.replace(tzinfo=timezone.utc)
        # Normalise the reference too; mixing naive and aware datetimes
        # would otherwise raise TypeError on subtraction.
        if reference_time.tzinfo is None:
            reference_time = reference_time.replace(tzinfo=timezone.utc)

        age_hours = (reference_time - newest).total_seconds() / 3600.0
        max_hours = config.max_evidence_staleness_hours
        if max_hours > 0:
            # Clamp at both ends: future-dated evidence must not exceed 1.0.
            freshness_component = min(1.0, max(0.0, 1.0 - (age_hours / max_hours)))
        else:
            # Degenerate window: avoid dividing by zero.
            freshness_component = 1.0 if age_hours <= 0 else 0.0

    # Document coverage component
    if ctx.total_documents > 0:
        valid_ratio = ctx.valid_documents / ctx.total_documents
        count_factor = min(ctx.valid_documents / 10.0, 1.0)
        coverage_component = valid_ratio * count_factor
    else:
        coverage_component = 0.0

    score = (0.4 * conf_component) + (0.3 * freshness_component) + (0.3 * coverage_component)
    return round(max(0.0, min(1.0, score)), 4)
def evaluate_suppression(
    summary: TrendSummary,
    quality_ctx: DataQualityContext | None = None,
    config: SuppressionConfig = DEFAULT_SUPPRESSION_CONFIG,
    reference_time: datetime | None = None,
) -> SuppressionResult:
    """Evaluate whether a recommendation should be suppressed due to data quality.

    Runs every data quality check, collects all reasons that fire, and
    computes the overall quality score. The recommendation is suppressed
    when at least one reason applies.

    Args:
        summary: The trend summary to evaluate.
        quality_ctx: Data quality context. If None, a minimal context is
            built from the trend summary itself.
        config: Suppression thresholds.
        reference_time: Reference time for staleness checks. Defaults to
            the current UTC time; a naive value is interpreted as UTC, the
            same way naive evidence timestamps are.

    Returns:
        SuppressionResult with suppression decision and reasons.
    """
    if reference_time is None:
        reference_time = datetime.now(timezone.utc)
    elif reference_time.tzinfo is None:
        # Subtracting an aware evidence timestamp from a naive reference
        # would raise TypeError, so normalise up front.
        reference_time = reference_time.replace(tzinfo=timezone.utc)

    ctx = quality_ctx if quality_ctx is not None else build_quality_context_from_summary(summary)
    reasons: list[SuppressionReason] = []

    # Check average extraction confidence
    if ctx.avg_extraction_confidence < config.min_avg_extraction_confidence:
        reasons.append(SuppressionReason.LOW_DATA_CONFIDENCE)

    # Check evidence staleness
    if ctx.newest_evidence_at is not None:
        newest = ctx.newest_evidence_at
        if newest.tzinfo is None:
            newest = newest.replace(tzinfo=timezone.utc)
        age_hours = (reference_time - newest).total_seconds() / 3600.0
        if age_hours > config.max_evidence_staleness_hours:
            reasons.append(SuppressionReason.STALE_EVIDENCE)
    elif ctx.total_documents > 0:
        # Have documents but no timestamp — treat as stale
        reasons.append(SuppressionReason.STALE_EVIDENCE)

    # Check source diversity (only meaningful when there are documents)
    if len(ctx.source_types) < config.min_source_types and ctx.total_documents > 0:
        reasons.append(SuppressionReason.LOW_SOURCE_DIVERSITY)

    # Check extraction failure rate
    if ctx.total_documents > 0:
        failure_rate = ctx.failed_documents / ctx.total_documents
        if failure_rate > config.max_extraction_failure_rate:
            reasons.append(SuppressionReason.HIGH_EXTRACTION_FAILURE_RATE)

    # Check minimum valid documents
    if ctx.valid_documents < config.min_valid_documents:
        reasons.append(SuppressionReason.INSUFFICIENT_VALID_DOCUMENTS)

    # Compute overall data quality score
    quality_score = _compute_data_quality_score(ctx, config, reference_time)

    # A low overall score is reported as low data confidence, without
    # duplicating the reason if the confidence check already added it.
    if quality_score < config.min_data_quality_score and SuppressionReason.LOW_DATA_CONFIDENCE not in reasons:
        reasons.append(SuppressionReason.LOW_DATA_CONFIDENCE)

    suppressed = len(reasons) > 0
    if suppressed:
        logger.info(
            "Recommendation suppressed for %s/%s: reasons=%s quality_score=%.3f",
            summary.entity_id, summary.window.value,
            [r.value for r in reasons], quality_score,
        )

    return SuppressionResult(
        suppressed=suppressed,
        reasons=reasons,
        data_quality_score=quality_score,
        context=ctx,
    )
+175
View File
@@ -0,0 +1,175 @@
"""Optional LLM wording layer for thesis generation.
Takes a deterministic thesis string (built from trend data and eligibility
rules) and rewrites it into natural, analyst-quality prose using a local
Ollama model. The deterministic thesis is always preserved as the fallback
and audit reference.
This module is opt-in: callers must explicitly request LLM rewriting.
If the LLM call fails or is disabled, the original deterministic thesis
is returned unchanged.
Requirements: 7.1, 7.2
"""
from __future__ import annotations
import logging
import time
import httpx
from services.shared.config import OllamaConfig
from services.shared.schemas import TrendSummary
logger = logging.getLogger(__name__)
# Identifier for the current rewrite prompt wording. The recommendation
# worker imports it and stores it as the prompt_version in the model metadata
# of LLM-worded recommendations.
THESIS_PROMPT_VERSION = "thesis-rewrite-v1"
# System message sent with every thesis-rewrite request. It limits the model
# to rephrasing: no new facts, under 150 words, plain text only. The trailing
# backslashes join the first sentence into a single line of the prompt.
THESIS_SYSTEM_PROMPT = """\
You are a concise financial analyst. You rewrite structured trade thesis \
summaries into clear, professional prose suitable for an internal research note.
STRICT RULES:
1. Do NOT add any information that is not present in the input.
2. Do NOT fabricate numbers, dates, company names, or analyst opinions.
3. Keep the rewrite under 150 words.
4. Preserve all factual claims, risk notes, and evidence counts from the input.
5. Use a neutral, professional tone. Avoid hype or marketing language.
6. Return ONLY the rewritten thesis text. No JSON, no markdown, no commentary."""
def build_thesis_rewrite_prompt(
    deterministic_thesis: str,
    summary: TrendSummary,
) -> dict[str, str]:
    """Assemble the system and user messages for a thesis rewrite request.

    The user message contains the deterministic thesis plus a short context
    block (ticker, window, direction, strength, confidence, contradiction
    score, and at most three catalysts and two risks when present).

    Returns a dict with the keys ``"system"`` and ``"user"``.
    """
    context_lines = [
        f"Ticker: {summary.entity_id}",
        f"Window: {summary.window.value}",
        f"Direction: {summary.trend_direction.value}",
        f"Strength: {summary.trend_strength:.2f}",
        f"Confidence: {summary.confidence:.2f}",
        f"Contradiction score: {summary.contradiction_score:.2f}",
    ]
    # Optional lines are only emitted when the summary has something to show.
    catalysts = summary.dominant_catalysts
    if catalysts:
        context_lines.append(f"Catalysts: {', '.join(catalysts[:3])}")
    risks = summary.material_risks
    if risks:
        context_lines.append(f"Risks: {'; '.join(risks[:2])}")

    context_block = "\n".join(context_lines)
    user_prompt = f"""\
Rewrite the following structured thesis into clear, professional analyst prose.
--- STRUCTURED THESIS ---
{deterministic_thesis}
--- END STRUCTURED THESIS ---
--- CONTEXT ---
{context_block}
--- END CONTEXT ---
Return ONLY the rewritten thesis. No other text."""
    return dict(system=THESIS_SYSTEM_PROMPT, user=user_prompt)
async def rewrite_thesis_with_llm(
    deterministic_thesis: str,
    summary: TrendSummary,
    config: OllamaConfig,
    http_client: httpx.AsyncClient | None = None,
) -> str:
    """Ask a local Ollama model to reword the deterministic thesis.

    The rewrite is best-effort: if the model call raises or returns
    nothing, the deterministic thesis is returned unchanged so that
    recommendation generation is never blocked by the LLM layer.

    Args:
        deterministic_thesis: The rule-based thesis string.
        summary: The trend summary the thesis was built from.
        config: Ollama connection and model configuration.
        http_client: Optional shared client. When omitted, a temporary
            client is created with ``config.timeout`` and closed before
            returning.

    Returns:
        The rewritten thesis on success, otherwise ``deterministic_thesis``.
    """
    prompts = build_thesis_rewrite_prompt(deterministic_thesis, summary)

    # Only close the client if it was created here.
    should_close = http_client is None
    client = http_client if http_client else httpx.AsyncClient(timeout=config.timeout)

    try:
        candidate = await _call_ollama_thesis(client, config, prompts)
        if not candidate:
            logger.warning(
                "LLM thesis rewrite returned empty for %s — using deterministic thesis",
                summary.entity_id,
            )
            return deterministic_thesis

        logger.info(
            "LLM thesis rewrite succeeded for %s (%d chars → %d chars)",
            summary.entity_id,
            len(deterministic_thesis),
            len(candidate),
        )
        return candidate
    except Exception:
        # Deliberate catch-all: any failure falls back to the rule-based text.
        logger.exception(
            "LLM thesis rewrite failed for %s — using deterministic thesis",
            summary.entity_id,
        )
        return deterministic_thesis
    finally:
        if should_close:
            await client.aclose()
async def _call_ollama_thesis(
    client: httpx.AsyncClient,
    config: OllamaConfig,
    prompts: dict[str, str],
) -> str:
    """Make a single non-streaming Ollama chat call for thesis rewriting.

    Posts the system and user prompts to ``{config.base_url}/api/chat``.

    Returns:
        The stripped text found at ``message.content`` in the response, or
        an empty string when the response has no usable text there (missing
        message, missing content, or a non-string value such as JSON null).

    Raises:
        Whatever ``client.post``, ``resp.raise_for_status()`` or
        ``resp.json()`` raise. Transport errors and non-success HTTP
        statuses are not swallowed here; the caller handles them.
    """
    start = time.monotonic()
    payload = {
        "model": config.model,
        "messages": [
            {"role": "system", "content": prompts["system"]},
            {"role": "user", "content": prompts["user"]},
        ],
        "stream": False,
    }
    resp = await client.post(
        f"{config.base_url}/api/chat",
        json=payload,
    )
    resp.raise_for_status()
    duration_ms = int((time.monotonic() - start) * 1000)

    body = resp.json()
    msg = body.get("message") if isinstance(body, dict) else None
    raw_content = msg.get("content") if isinstance(msg, dict) else None
    # Guard against null / non-text content so len() and strip() cannot fail.
    content = raw_content if isinstance(raw_content, str) else ""

    logger.debug(
        "Ollama thesis call completed in %dms, response length=%d",
        duration_ms,
        len(content),
    )
    return content.strip()
+721 -1
View File
@@ -1 +1,721 @@
"""Recommendation worker - generates explainable trade recommendations from trend data."""
"""Recommendation worker - generates explainable trade recommendations from trend data.
Fetches the latest trend summaries for a ticker, evaluates eligibility
using deterministic rules, builds Recommendation objects with thesis
and evidence citations, and persists them to PostgreSQL.
Requirements: 7.1, 7.2, 7.3, 7.4
"""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
import asyncpg
from services.recommendation.eligibility import (
EligibilityConfig,
EligibilityResult,
evaluate_eligibility,
)
from services.recommendation.suppression import (
DataQualityContext,
SuppressionConfig,
SuppressionResult,
evaluate_suppression,
)
from services.recommendation.thesis_llm import (
THESIS_PROMPT_VERSION,
rewrite_thesis_with_llm,
)
from minio import Minio
from services.lake_publisher.worker import publish_recommendation_facts
from services.shared.config import OllamaConfig
from services.shared.schemas import (
ModelMetadata,
PositionSizing,
Recommendation,
RecommendationMode,
TrendDirection,
TrendSummary,
TrendWindow,
)
from services.shared.metrics import (
RECOMMENDATION_CONFIDENCE,
RECOMMENDATION_GENERATED,
RECOMMENDATION_SUPPRESSED,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Fetch latest trend summary for a ticker + window
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Fetch data quality context for suppression checks
# ---------------------------------------------------------------------------
# Aggregates quality metrics over the documents cited as evidence by the most
# recent trend_windows row for ($1 = entity id, $2 = window).
#
# The inner sub-select picks that single latest row first. Its two evidence
# arrays are then concatenated and expanded with jsonb_array_elements_text,
# because UNNEST only accepts SQL arrays and rejects jsonb. Applying LIMIT 1
# before the expansion keeps every cited document id, not just the first.
# NOTE(review): assumes both evidence columns are jsonb arrays of UUID
# strings, as the '[]'::jsonb defaults and the ::uuid cast imply. Confirm
# against the schema.
_DATA_QUALITY_QUERY = """
SELECT
    COUNT(*) AS total_documents,
    COUNT(*) FILTER (WHERE di.validation_status = 'valid') AS valid_documents,
    COUNT(*) FILTER (WHERE di.validation_status = 'failed') AS failed_documents,
    AVG(di.confidence) FILTER (WHERE di.validation_status = 'valid') AS avg_extraction_confidence,
    MAX(d.published_at) AS newest_evidence_at,
    ARRAY_AGG(DISTINCT s.source_class) FILTER (WHERE s.source_class IS NOT NULL) AS source_types
FROM documents d
JOIN document_intelligence di ON di.document_id = d.id
LEFT JOIN sources s ON d.source_id = s.id
WHERE d.id IN (
    SELECT ev.doc_id::uuid
    FROM (
        SELECT tw.top_supporting_evidence, tw.top_opposing_evidence
        FROM trend_windows tw
        WHERE tw.entity_id = $1 AND tw."window" = $2
        ORDER BY tw.generated_at DESC
        LIMIT 1
    ) AS latest
    CROSS JOIN LATERAL jsonb_array_elements_text(
        COALESCE(latest.top_supporting_evidence, '[]'::jsonb)
        || COALESCE(latest.top_opposing_evidence, '[]'::jsonb)
    ) AS ev(doc_id)
)
"""
async def fetch_data_quality_context(
    pool: asyncpg.Pool,
    ticker: str,
    window: str,
) -> DataQualityContext | None:
    """Load data quality metrics for the documents behind a trend.

    Returns None when the query raises or reports zero documents. The
    suppression module then falls back to estimating quality from the
    trend summary.
    """
    try:
        record = await pool.fetchrow(_DATA_QUALITY_QUERY, ticker, window)
        if record is None or record["total_documents"] == 0:
            return None

        # Aggregates over no matching rows come back as NULL, so default them.
        return DataQualityContext(
            total_documents=int(record["total_documents"]),
            valid_documents=int(record["valid_documents"] or 0),
            failed_documents=int(record["failed_documents"] or 0),
            avg_extraction_confidence=float(record["avg_extraction_confidence"] or 0.0),
            newest_evidence_at=record["newest_evidence_at"],
            source_types=set(record["source_types"] or ()),
        )
    except Exception:
        # Best-effort lookup: log with traceback and let the caller fall back.
        logger.warning(
            "Failed to fetch data quality context for %s/%s — will use summary fallback",
            ticker, window, exc_info=True,
        )
        return None
# Most recent trend_windows row for ($1 = entity id, $2 = window).
# WINDOW is a reserved word in PostgreSQL, so the bare column reference has
# to be double-quoted. The result column is still named "window".
_LATEST_TREND_QUERY = """
SELECT
entity_type, entity_id, "window", trend_direction, trend_strength,
confidence, top_supporting_evidence, top_opposing_evidence,
dominant_catalysts, material_risks, contradiction_score,
disagreement_details, market_context, generated_at
FROM trend_windows
WHERE entity_id = $1 AND "window" = $2
ORDER BY generated_at DESC
LIMIT 1
"""
def _parse_trend_row(row: asyncpg.Record) -> TrendSummary:
    """Build a TrendSummary from one trend_windows row.

    The four list columns may arrive as JSON text rather than decoded
    lists; text is decoded, and empty or NULL values become ``[]``.
    """

    def _as_list(column: str):
        """Return the column as a list, decoding JSON text when needed."""
        value = row[column]
        if isinstance(value, str):
            value = json.loads(value)
        return value or []

    supporting, opposing, catalysts, risks = (
        _as_list(column)
        for column in (
            "top_supporting_evidence",
            "top_opposing_evidence",
            "dominant_catalysts",
            "material_risks",
        )
    )

    return TrendSummary(
        entity_type=row["entity_type"],
        entity_id=row["entity_id"],
        window=TrendWindow(row["window"]),
        trend_direction=TrendDirection(row["trend_direction"]),
        trend_strength=float(row["trend_strength"]),
        confidence=float(row["confidence"]),
        top_supporting_evidence=supporting,
        top_opposing_evidence=opposing,
        dominant_catalysts=catalysts,
        material_risks=risks,
        # A NULL contradiction score is treated as no contradiction.
        contradiction_score=float(row["contradiction_score"] or 0.0),
        generated_at=row["generated_at"],
    )
async def fetch_latest_trend(
    pool: asyncpg.Pool,
    ticker: str,
    window: str,
) -> TrendSummary | None:
    """Return the newest trend summary for *ticker*/*window*, or None if absent."""
    record = await pool.fetchrow(_LATEST_TREND_QUERY, ticker, window)
    return _parse_trend_row(record) if record is not None else None
# ---------------------------------------------------------------------------
# Build thesis from trend summary (deterministic, no LLM)
# ---------------------------------------------------------------------------
def build_thesis(
    summary: TrendSummary,
    result: EligibilityResult,
) -> str:
    """Compose the rule-based thesis sentence by sentence.

    The descriptive sentences (trend, catalysts, contradiction, risks,
    evidence counts) come first and the prescriptive recommendation comes
    last, keeping the two apart as Requirement 7.2 asks. No LLM is involved.
    """
    sentences: list[str] = [
        f"{summary.entity_id} shows a {summary.trend_direction.value} trend "
        f"over the {summary.window.value} window "
        f"with strength {summary.trend_strength:.2f} "
        f"and confidence {summary.confidence:.2f}."
    ]

    # At most three catalysts are listed.
    catalysts = summary.dominant_catalysts
    if catalysts:
        sentences.append("Dominant catalysts: " + ", ".join(catalysts[:3]) + ".")

    # Disagreement is only mentioned once it is above 0.15.
    contradiction = summary.contradiction_score
    if contradiction > 0.15:
        sentences.append(
            f"Notable signal disagreement detected (contradiction score: {contradiction:.2f})."
        )

    # At most two risks are listed.
    risks = summary.material_risks
    if risks:
        sentences.append("Key risks: " + "; ".join(risks[:2]) + ".")

    n_supporting = len(summary.top_supporting_evidence)
    n_opposing = len(summary.top_opposing_evidence)
    sentences.append(
        f"Based on {n_supporting} supporting and {n_opposing} opposing evidence documents."
    )

    # Prescriptive part, kept as the final sentence.
    action_label = result.action.value.upper()
    mode_label = result.mode.value.replace("_", " ")
    sentences.append(f"Recommendation: {action_label} ({mode_label}).")

    return " ".join(sentences)
# ---------------------------------------------------------------------------
# Build risk classification (Requirement 7.2)
# ---------------------------------------------------------------------------
def classify_risk(
    summary: TrendSummary,
    result: EligibilityResult,
) -> str:
    """Label the recommendation's risk as low, moderate, high or very_high.

    Risk points accumulate from four sources: contradiction (x2.0), lack
    of confidence (x1.5), thin evidence (+1.0 under 3 documents, +0.5
    under 5) and each eligibility rejection reason (+0.5). The label is
    the highest band whose floor (3.0 / 2.0 / 1.0) the total reaches.
    """
    n_evidence = len(summary.top_supporting_evidence) + len(summary.top_opposing_evidence)
    if n_evidence < 3:
        thin_evidence_points = 1.0
    elif n_evidence < 5:
        thin_evidence_points = 0.5
    else:
        thin_evidence_points = 0.0

    risk_points = (
        summary.contradiction_score * 2.0
        + (1.0 - summary.confidence) * 1.5
        + thin_evidence_points
        + len(result.rejection_reasons) * 0.5
    )

    # Bands are ordered from most to least severe; first match wins.
    for floor, label in ((3.0, "very_high"), (2.0, "high"), (1.0, "moderate")):
        if risk_points >= floor:
            return label
    return "low"
# ---------------------------------------------------------------------------
# Build Recommendation from eligibility result
# ---------------------------------------------------------------------------
def build_recommendation(
    summary: TrendSummary,
    result: EligibilityResult,
    reference_time: datetime | None = None,
    llm_thesis: str | None = None,
    suppression_result: SuppressionResult | None = None,
) -> Recommendation:
    """Turn a trend summary and its eligibility result into a Recommendation.

    Supporting evidence refs are listed first, then opposing ones, so the
    complete decision trace travels with the recommendation
    (Requirement 8.3).

    A non-empty ``llm_thesis`` replaces the deterministic wording; the
    ``[risk:<class>]`` prefix is applied either way. When
    ``suppression_result`` reports suppression, a ``[SUPPRESSED: ...]``
    note is appended to the thesis for audit visibility (Requirement 7.4).
    ``reference_time`` defaults to the current UTC time and becomes
    ``generated_at``.
    """
    generated_at = datetime.now(timezone.utc) if reference_time is None else reference_time

    evidence_refs = [*summary.top_supporting_evidence, *summary.top_opposing_evidence]

    rule_based_thesis = build_thesis(summary, result)
    risk_class = classify_risk(summary, result)

    thesis_body = llm_thesis or rule_based_thesis

    if suppression_result and suppression_result.suppressed:
        joined_reasons = ", ".join(r.value for r in suppression_result.reasons)
        thesis_body += (
            f" [SUPPRESSED: data quality below threshold "
            f"(score={suppression_result.data_quality_score:.2f}, "
            f"reasons={joined_reasons})]"
        )

    # Record for audit whether the wording came from the LLM or the rules.
    if llm_thesis:
        provider, model_name, prompt_version = "ollama", "thesis-rewrite", THESIS_PROMPT_VERSION
    else:
        provider, model_name, prompt_version = "deterministic", "eligibility-v1", ""

    metadata = ModelMetadata(
        provider=provider,
        model_name=model_name,
        prompt_version=prompt_version,
        schema_version="1.0.0",
    )
    sizing = PositionSizing(
        portfolio_pct=result.position_sizing.portfolio_pct,
        max_loss_pct=result.position_sizing.max_loss_pct,
    )

    return Recommendation(
        ticker=summary.entity_id,
        action=result.action,
        mode=result.mode,
        confidence=summary.confidence,
        time_horizon=result.time_horizon,
        thesis=f"[risk:{risk_class}] {thesis_body}",
        invalidation_conditions=result.invalidation_conditions,
        position_sizing=sizing,
        evidence_refs=evidence_refs,
        model_metadata=metadata,
        generated_at=generated_at,
    )
# ---------------------------------------------------------------------------
# Persist recommendation to PostgreSQL
# ---------------------------------------------------------------------------
# Inserts one recommendation row and returns its generated id.
# persist_recommendation binds the 15 parameters in column order; the
# model_version column ($10) receives model_metadata.model_name, and
# invalidation_conditions ($7) is passed as JSON text and cast to jsonb.
_INSERT_RECOMMENDATION = """
INSERT INTO recommendations (
ticker, action, mode, confidence, time_horizon,
thesis, invalidation_conditions, portfolio_pct, max_loss_pct,
model_version, model_provider, prompt_version, schema_version,
risk_classification, generated_at
) VALUES (
$1, $2, $3, $4, $5,
$6, $7::jsonb, $8, $9,
$10, $11, $12, $13,
$14, $15
)
RETURNING id
"""
# Inserts one evidence citation: (recommendation id, document id,
# "supporting" / "opposing", rank-decayed weight).
_INSERT_REC_EVIDENCE = """
INSERT INTO recommendation_evidence (
recommendation_id, document_id, evidence_type, weight
) VALUES ($1, $2::uuid, $3, $4)
"""
# Records the eligibility decision for a recommendation; rejection_reasons
# and risk_checks are passed as JSON text and cast to jsonb.
_INSERT_RISK_EVALUATION = """
INSERT INTO risk_evaluations (
recommendation_id, eligible, allowed_mode, rejection_reasons, risk_checks, evaluated_at
) VALUES ($1::uuid, $2, $3, $4::jsonb, $5::jsonb, $6)
"""
# Loads a single recommendation by id ($1).
_FETCH_RECOMMENDATION = """
SELECT
id, ticker, action, mode, confidence, time_horizon,
thesis, invalidation_conditions, portfolio_pct, max_loss_pct,
model_version, model_provider, prompt_version, schema_version,
risk_classification, generated_at
FROM recommendations
WHERE id = $1::uuid
"""
# Loads the evidence citations of one recommendation ($1), grouped by
# evidence type with the heaviest weights first.
_FETCH_REC_EVIDENCE = """
SELECT document_id, evidence_type, weight
FROM recommendation_evidence
WHERE recommendation_id = $1::uuid
ORDER BY evidence_type, weight DESC
"""
# Loads the newest recommendations for a ticker ($1), at most $2 rows.
_FETCH_LATEST_RECS_FOR_TICKER = """
SELECT
id, ticker, action, mode, confidence, time_horizon,
thesis, invalidation_conditions, portfolio_pct, max_loss_pct,
model_version, model_provider, prompt_version, schema_version,
risk_classification, generated_at
FROM recommendations
WHERE ticker = $1
ORDER BY generated_at DESC
LIMIT $2
"""
def _extract_risk_classification(thesis: str) -> str:
    """Read the label out of a leading ``[risk:<label>]`` thesis prefix.

    Falls back to ``"moderate"`` when the prefix is missing, unterminated,
    or carries an empty label.
    """
    prefix = "[risk:"
    if not thesis.startswith(prefix):
        return "moderate"

    # Everything between the prefix and the first closing bracket.
    label, closing, _ = thesis[len(prefix):].partition("]")
    if closing and label:
        return label
    return "moderate"
async def persist_recommendation(
    pool: asyncpg.Pool,
    rec: Recommendation,
    supporting_ids: list[str],
    opposing_ids: list[str],
    eligibility_result: EligibilityResult | None = None,
) -> str:
    """Insert a recommendation, its evidence citations, and its risk evaluation.

    Persists the full model metadata and risk classification for the audit
    trail (Requirement 8.3). When ``eligibility_result`` is provided, the
    eligibility decision is also written to the risk_evaluations table.

    All statements run on one connection inside a single transaction, so a
    failure while writing evidence or the risk evaluation rolls back the
    recommendation row instead of leaving a partial audit record.

    Args:
        pool: asyncpg pool used to acquire the connection.
        rec: Recommendation to store.
        supporting_ids: Document ids cited in favour, best first.
        opposing_ids: Document ids cited against, best first.
        eligibility_result: Optional eligibility decision to persist.

    Returns:
        The recommendation UUID as a string.
    """
    risk_class = _extract_risk_classification(rec.thesis)

    async with pool.acquire() as conn:
        async with conn.transaction():
            row = await conn.fetchrow(
                _INSERT_RECOMMENDATION,
                rec.ticker,
                rec.action.value,
                rec.mode.value,
                rec.confidence,
                rec.time_horizon,
                rec.thesis,
                json.dumps(rec.invalidation_conditions),
                rec.position_sizing.portfolio_pct,
                rec.position_sizing.max_loss_pct,
                rec.model_metadata.model_name,
                rec.model_metadata.provider,
                rec.model_metadata.prompt_version,
                rec.model_metadata.schema_version,
                risk_class,
                rec.generated_at,
            )
            rec_id = str(row["id"])

            # Position-based weighting: the first citation of each type gets
            # 1.0 and later ones decay by rank (1 / (1 + 0.1 * idx)).
            evidence_rows: list[tuple[str, str, str, float]] = [
                (rec_id, doc_id, evidence_type, round(1.0 / (1.0 + idx * 0.1), 4))
                for evidence_type, doc_ids in (
                    ("supporting", supporting_ids),
                    ("opposing", opposing_ids),
                )
                for idx, doc_id in enumerate(doc_ids)
            ]
            if evidence_rows:
                await conn.executemany(_INSERT_REC_EVIDENCE, evidence_rows)

            # Persist the eligibility/risk evaluation for the audit trail.
            if eligibility_result is not None:
                rejection_reasons_json = json.dumps(
                    [r.value for r in eligibility_result.rejection_reasons]
                )
                risk_checks = {
                    "time_horizon": eligibility_result.time_horizon,
                    "position_sizing": {
                        "portfolio_pct": eligibility_result.position_sizing.portfolio_pct,
                        "max_loss_pct": eligibility_result.position_sizing.max_loss_pct,
                    },
                    "invalidation_conditions": eligibility_result.invalidation_conditions,
                    "risk_classification": risk_class,
                }
                await conn.execute(
                    _INSERT_RISK_EVALUATION,
                    rec_id,
                    eligibility_result.eligible,
                    eligibility_result.mode.value,
                    rejection_reasons_json,
                    json.dumps(risk_checks),
                    rec.generated_at,
                )

    return rec_id
async def fetch_recommendation_by_id(
    pool: asyncpg.Pool,
    recommendation_id: str,
) -> dict[str, object] | None:
    """Load one stored recommendation together with its evidence rows.

    The result is the recommendation row as a dict, with
    ``invalidation_conditions`` decoded from JSON text when it arrives as a
    string, plus an ``"evidence"`` list of
    ``{"document_id", "evidence_type", "weight"}`` dicts.
    Returns None when no row matches ``recommendation_id``.
    """
    record = await pool.fetchrow(_FETCH_RECOMMENDATION, recommendation_id)
    if record is None:
        return None

    payload = dict(record)

    # JSONB may come back as text; decode only in that case.
    raw_conditions = payload.get("invalidation_conditions")
    if isinstance(raw_conditions, str):
        payload["invalidation_conditions"] = json.loads(raw_conditions)

    citations = []
    for item in await pool.fetch(_FETCH_REC_EVIDENCE, recommendation_id):
        citations.append(
            {
                "document_id": str(item["document_id"]),
                "evidence_type": item["evidence_type"],
                "weight": float(item["weight"]),
            }
        )
    payload["evidence"] = citations
    return payload
async def fetch_latest_recommendations(
    pool: asyncpg.Pool,
    ticker: str,
    limit: int = 10,
) -> list[dict[str, object]]:
    """Return up to ``limit`` of the newest recommendations for ``ticker``.

    Each entry is the row as a dict with ``invalidation_conditions`` decoded
    from JSON text when needed. Evidence is not included; use
    fetch_recommendation_by_id for full detail.
    """

    def _decode(record) -> dict[str, object]:
        # Copy the row and turn a text-encoded JSONB column into Python data.
        item = dict(record)
        conditions = item.get("invalidation_conditions")
        if isinstance(conditions, str):
            item["invalidation_conditions"] = json.loads(conditions)
        return item

    records = await pool.fetch(_FETCH_LATEST_RECS_FOR_TICKER, ticker, limit)
    return [_decode(record) for record in records]
# ---------------------------------------------------------------------------
# Main entry point: generate recommendation for a ticker
# ---------------------------------------------------------------------------
async def generate_recommendation(
    pool: asyncpg.Pool,
    ticker: str,
    window: str = TrendWindow.SEVEN_DAY.value,
    config: EligibilityConfig | None = None,
    reference_time: datetime | None = None,
    ollama_config: OllamaConfig | None = None,
    suppression_config: SuppressionConfig | None = None,
    minio_client: Minio | None = None,
) -> Recommendation | None:
    """Produce, store, and optionally publish a recommendation for one ticker.

    Pipeline:
      1. Load the latest trend summary for ``ticker``/``window``; return
         ``None`` when there is none.
      2. Evaluate data quality suppression (Requirement 7.4).
      3. Run the deterministic eligibility rules; when suppressed, the result
         is rebuilt as ineligible with informational mode.
      4. When ``ollama_config`` is given, ask the LLM wording layer to rewrite
         the deterministic thesis; a reply equal to the input is discarded.
      5. Build the Recommendation and persist it with its evidence citations
         and risk evaluation.
      6. When ``minio_client`` is given, publish analytical facts
         (Requirement 9.4); a failure there is logged and does not abort.
      7. Log the outcome and update the Prometheus metrics.

    Returns the Recommendation, or None if no trend data exists.
    """
    as_of = datetime.now(timezone.utc) if reference_time is None else reference_time
    eligibility_cfg = config or EligibilityConfig()
    suppression_cfg = suppression_config or SuppressionConfig()

    # 1. Latest trend
    summary = await fetch_latest_trend(pool, ticker, window)
    if summary is None:
        logger.info("No trend data for %s/%s — skipping recommendation", ticker, window)
        return None

    # 2. Data quality suppression (Requirement 7.4)
    quality_ctx = await fetch_data_quality_context(pool, ticker, window)
    suppression = evaluate_suppression(
        summary,
        quality_ctx=quality_ctx,
        config=suppression_cfg,
        reference_time=as_of,
    )

    # 3. Eligibility, downgraded to informational when suppressed
    eligibility = evaluate_eligibility(summary, eligibility_cfg)
    if suppression.suppressed:
        eligibility = EligibilityResult(
            eligible=False,
            action=eligibility.action,
            mode=RecommendationMode.INFORMATIONAL,
            position_sizing=eligibility.position_sizing,
            rejection_reasons=eligibility.rejection_reasons,
            time_horizon=eligibility.time_horizon,
            invalidation_conditions=eligibility.invalidation_conditions,
        )

    # 4. Optional LLM thesis rewrite
    llm_thesis: str | None = None
    if ollama_config is not None:
        baseline_thesis = build_thesis(summary, eligibility)
        rewritten = await rewrite_thesis_with_llm(
            deterministic_thesis=baseline_thesis,
            summary=summary,
            config=ollama_config,
        )
        # An unchanged reply means the fallback was used; treat it as a no-op.
        llm_thesis = None if rewritten == baseline_thesis else rewritten

    # 5. Build and persist
    rec = build_recommendation(
        summary,
        eligibility,
        as_of,
        llm_thesis=llm_thesis,
        suppression_result=suppression,
    )
    rec_id = await persist_recommendation(
        pool,
        rec,
        supporting_ids=list(summary.top_supporting_evidence),
        opposing_ids=list(summary.top_opposing_evidence),
        eligibility_result=eligibility,
    )

    # 6. Publish prediction facts to analytical tables (Requirement 9.4)
    if minio_client is not None:
        try:
            lake_refs = publish_recommendation_facts(
                minio_client,
                rec,
                trend_direction=summary.trend_direction.value,
                trend_strength=summary.trend_strength,
            )
            logger.info(
                "Published analytical facts for %s: %s",
                ticker, lake_refs,
            )
        except Exception:
            # Best effort: the recommendation is already persisted.
            logger.warning(
                "Failed to publish analytical facts for %s/%s — recommendation "
                "persisted but lake publication failed",
                ticker, rec_id, exc_info=True,
            )

    # 7. Outcome log and Prometheus metrics
    used_llm = llm_thesis is not None
    logger.info(
        "Generated recommendation %s for %s: action=%s mode=%s confidence=%.3f "
        "eligible=%s suppressed=%s quality_score=%.3f llm_thesis=%s",
        rec_id, ticker, rec.action.value, rec.mode.value, rec.confidence,
        eligibility.eligible, suppression.suppressed, suppression.data_quality_score,
        used_llm,
    )
    RECOMMENDATION_GENERATED.labels(action=rec.action.value, mode=rec.mode.value).inc()
    RECOMMENDATION_CONFIDENCE.observe(rec.confidence)
    if suppression.suppressed:
        RECOMMENDATION_SUPPRESSED.inc()

    return rec
# ---------------------------------------------------------------------------
# Batch: generate recommendations for multiple tickers
# ---------------------------------------------------------------------------
async def generate_recommendations_batch(
    pool: asyncpg.Pool,
    tickers: list[str],
    window: str = TrendWindow.SEVEN_DAY.value,
    config: EligibilityConfig | None = None,
    ollama_config: OllamaConfig | None = None,
    suppression_config: SuppressionConfig | None = None,
    minio_client: Minio | None = None,
) -> list[Recommendation]:
    """Run generate_recommendation for each ticker, one after another.

    Every ticker in the batch shares a single reference time taken at the
    start of the call. Tickers for which generate_recommendation returns
    None (no trend data) are left out of the result. When ``ollama_config``
    is provided it is forwarded so each thesis goes through the LLM wording
    layer.
    """
    batch_time = datetime.now(timezone.utc)
    produced: list[Recommendation] = []

    for symbol in tickers:
        outcome = await generate_recommendation(
            pool,
            symbol,
            window,
            config=config,
            reference_time=batch_time,
            ollama_config=ollama_config,
            suppression_config=suppression_config,
            minio_client=minio_client,
        )
        if outcome is None:
            continue
        produced.append(outcome)

    logger.info(
        "Batch recommendation: %d/%d tickers produced recommendations",
        len(produced), len(tickers),
    )
    return produced
+101
View File
@@ -0,0 +1,101 @@
"""Risk Engine API - FastAPI application for order risk evaluation and approval workflow."""
from __future__ import annotations
from contextlib import asynccontextmanager
import asyncpg
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from services.risk.approval import (
expire_stale_approvals,
get_approval_by_id,
get_pending_approvals,
review_approval,
)
from services.risk.engine import (
AccountRiskState,
PortfolioRiskConfig,
ProposedOrder,
RiskEvaluation,
evaluate_order,
)
from services.shared.config import load_config
from services.shared.logging import setup_logging
# Loaded once at import time; lifespan() reads the logging and Postgres
# settings from it.
config = load_config()
# Shared asyncpg pool. Stays None until lifespan() creates it at startup;
# the approval endpoints answer 503 while it is unset.
pool: asyncpg.Pool | None = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Set up logging and the shared asyncpg pool for the app's lifetime.

    Startup configures logging and creates the module-level ``pool``.
    Shutdown closes the pool inside ``finally`` so it is released even when
    an exception or cancellation is raised into the context at the yield
    point, then resets ``pool`` to None so handlers report 503 rather than
    using a closed pool.
    """
    global pool
    setup_logging("risk_engine", level=config.log_level, json_output=config.json_logs)
    pool = await asyncpg.create_pool(dsn=config.postgres.dsn, min_size=2, max_size=8)
    try:
        yield
    finally:
        if pool:
            await pool.close()
            pool = None
# ASGI application; lifespan() above manages logging setup and the DB pool.
app = FastAPI(title="Stonks Oracle - Risk Engine", lifespan=lifespan)
# Request body for POST /evaluate.
class EvaluateRequest(BaseModel):
    # The order to run through the risk checks.
    order: ProposedOrder
    # Optional risk configuration; the handler substitutes a default
    # PortfolioRiskConfig() when this is omitted.
    config: PortfolioRiskConfig | None = None
    # Optional account snapshot, passed to evaluate_order as-is (may be None).
    state: AccountRiskState | None = None
@app.post("/evaluate", response_model=RiskEvaluation)
async def evaluate(req: EvaluateRequest) -> RiskEvaluation:
    # Use the caller's risk config when supplied, otherwise library defaults,
    # then hand the order and optional account state to the risk engine.
    effective_config = req.config
    if not effective_config:
        effective_config = PortfolioRiskConfig()
    evaluation = evaluate_order(req.order, effective_config, req.state)
    return evaluation
@app.get("/health")
async def health():
    # Liveness probe: always answers with a static payload, no DB access.
    payload = {"status": "ok"}
    return payload
# Request body for POST /approvals/{approval_id}/review.
class ReviewRequest(BaseModel):
    # True approves the pending request, False rejects it.
    approved: bool
    # Identifier recorded as the reviewer.
    reviewed_by: str = "operator"
    # Free-form note stored with the review.
    review_note: str = ""
@app.get("/approvals/pending")
async def list_pending():
    # 503 until the lifespan hook has created the pool.
    if not pool:
        raise HTTPException(status_code=503, detail="Database not ready")

    # Pending requests come back oldest first; serialize each for the response.
    pending = await get_pending_approvals(pool)
    serialized = []
    for item in pending:
        serialized.append(item.to_dict())
    return serialized
@app.get("/approvals/{approval_id}")
async def get_approval(approval_id: str):
    # 503 until the lifespan hook has created the pool.
    if not pool:
        raise HTTPException(status_code=503, detail="Database not ready")

    found = await get_approval_by_id(pool, approval_id)
    if found is None:
        raise HTTPException(status_code=404, detail="Approval not found")
    return found.to_dict()
@app.post("/approvals/{approval_id}/review")
async def review(approval_id: str, body: ReviewRequest):
    # 503 until the lifespan hook has created the pool.
    if not pool:
        raise HTTPException(status_code=503, detail="Database not ready")

    outcome = await review_approval(
        pool,
        approval_id,
        approved=body.approved,
        reviewed_by=body.reviewed_by,
        review_note=body.review_note,
    )
    # None means nothing was updated: unknown id or no longer pending.
    if outcome is None:
        raise HTTPException(
            status_code=404,
            detail="Approval not found or no longer pending",
        )
    return {"approval_id": approval_id, "status": outcome.value}
@app.post("/approvals/expire")
async def expire():
    # 503 until the lifespan hook has created the pool.
    if not pool:
        raise HTTPException(status_code=503, detail="Database not ready")

    # Sweep pending approvals whose expiry has passed and report them.
    swept = await expire_stale_approvals(pool)
    return {"expired": swept}
+300
View File
@@ -0,0 +1,300 @@
"""Operator approval workflow for live trading mode.
When live trading is enabled and operator approval is required,
orders are held in a pending state until an operator explicitly
approves or rejects them. Expired approvals are treated as rejections.
Requirements: 8.2
Design: Section 4.8 - Risk Engine (operator approval rules)
"""
from __future__ import annotations
import json
import logging
import uuid
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any
import asyncpg
from services.risk.engine import (
OperatorApproval,
PortfolioRiskConfig,
TradingMode,
)
# Module logger; uses the fixed name "operator_approval" rather than __name__.
logger = logging.getLogger("operator_approval")
# ---------------------------------------------------------------------------
# Enums
# ---------------------------------------------------------------------------
class ApprovalStatus(str, Enum):
    """Lifecycle states of an operator approval request.

    The string values are what the approval SQL statements in this module
    read and write in the ``status`` column.
    """

    # Waiting for an operator decision.
    PENDING = "pending"
    # Set by review_approval when approved=True.
    APPROVED = "approved"
    # Set by review_approval when approved=False.
    REJECTED = "rejected"
    # Set by the expire_stale_approvals sweep once expires_at has passed.
    EXPIRED = "expired"
# ---------------------------------------------------------------------------
# Core logic: does this order need approval?
# ---------------------------------------------------------------------------
def requires_approval(
    config: PortfolioRiskConfig,
    trading_mode: TradingMode | None = None,
) -> bool:
    """Tell whether an order must wait for an operator decision.

    The mode is ``trading_mode`` when given, otherwise ``config.trading_mode``.

    * Disabled: never (orders are blocked upstream).
    * Paper: only when ``auto_approve_paper`` is False.
    * Any other mode (live): follows ``require_approval_for_live``.
    """
    effective_mode = trading_mode or config.trading_mode
    if effective_mode == TradingMode.DISABLED:
        return False

    rules = config.operator_approval
    return (
        not rules.auto_approve_paper
        if effective_mode == TradingMode.PAPER
        else rules.require_approval_for_live
    )
def compute_expiry(
    config: PortfolioRiskConfig,
    now: datetime | None = None,
) -> datetime:
    """Return when a new approval request created at ``now`` should expire.

    ``now`` defaults to the current UTC time. The lifetime is
    ``config.operator_approval.approval_timeout_minutes`` minutes.
    """
    if not now:
        now = datetime.now(timezone.utc)
    lifetime = timedelta(minutes=config.operator_approval.approval_timeout_minutes)
    return now + lifetime
# ---------------------------------------------------------------------------
# Approval request model (in-memory representation)
# ---------------------------------------------------------------------------
class ApprovalRequest:
    """In-memory representation of one operator approval request.

    Missing values are filled in on construction: a fresh UUID for
    ``approval_id``, an empty dict for ``order_job``, an expiry 30 minutes
    from now for ``expires_at``, and the current UTC time for
    ``requested_at``.
    """

    def __init__(
        self,
        approval_id: str | None = None,
        order_job: dict[str, Any] | None = None,
        recommendation_id: str | None = None,
        ticker: str = "",
        side: str = "buy",
        quantity: float = 0.0,
        estimated_value: float = 0.0,
        risk_evaluation_id: str | None = None,
        status: ApprovalStatus = ApprovalStatus.PENDING,
        requested_by: str = "system",
        reviewed_by: str | None = None,
        review_note: str | None = None,
        expires_at: datetime | None = None,
        requested_at: datetime | None = None,
        reviewed_at: datetime | None = None,
    ) -> None:
        # Fill in defaults for anything the caller left empty.
        if not approval_id:
            approval_id = str(uuid.uuid4())
        if not order_job:
            order_job = {}
        if not expires_at:
            expires_at = datetime.now(timezone.utc) + timedelta(minutes=30)
        if not requested_at:
            requested_at = datetime.now(timezone.utc)

        # Identity and order details
        self.approval_id = approval_id
        self.order_job = order_job
        self.recommendation_id = recommendation_id
        self.ticker = ticker
        self.side = side
        self.quantity = quantity
        self.estimated_value = estimated_value
        self.risk_evaluation_id = risk_evaluation_id

        # Workflow state
        self.status = status
        self.requested_by = requested_by
        self.reviewed_by = reviewed_by
        self.review_note = review_note

        # Timestamps
        self.expires_at = expires_at
        self.requested_at = requested_at
        self.reviewed_at = reviewed_at

    @property
    def is_pending(self) -> bool:
        """True while the request still awaits a decision."""
        return self.status == ApprovalStatus.PENDING

    @property
    def is_expired(self) -> bool:
        """True when marked expired, or pending past ``expires_at``."""
        if self.status == ApprovalStatus.PENDING:
            return datetime.now(timezone.utc) >= self.expires_at
        return self.status == ApprovalStatus.EXPIRED

    @staticmethod
    def _iso(moment: datetime | None) -> str | None:
        """ISO-8601 text for a timestamp, or None when it is unset."""
        return moment.isoformat() if moment else None

    def to_dict(self) -> dict[str, Any]:
        """Serialize to plain values; ``order_job`` is not included."""
        return {
            "approval_id": self.approval_id,
            "recommendation_id": self.recommendation_id,
            "ticker": self.ticker,
            "side": self.side,
            "quantity": self.quantity,
            "estimated_value": self.estimated_value,
            "risk_evaluation_id": self.risk_evaluation_id,
            "status": self.status.value,
            "requested_by": self.requested_by,
            "reviewed_by": self.reviewed_by,
            "review_note": self.review_note,
            "expires_at": self._iso(self.expires_at),
            "requested_at": self._iso(self.requested_at),
            "reviewed_at": self._iso(self.reviewed_at),
        }
# ---------------------------------------------------------------------------
# DB persistence
# ---------------------------------------------------------------------------
# Inserts a new approval request; $1 is cast to uuid and $2 (order_job) is
# sent as JSON text and cast to jsonb.
_INSERT_APPROVAL = """
INSERT INTO operator_approvals (
id, order_job, recommendation_id, ticker, side, quantity,
estimated_value, status, risk_evaluation_id, requested_by,
expires_at, requested_at
) VALUES (
$1::uuid, $2::jsonb, $3, $4, $5, $6,
$7, $8, $9, $10,
$11, $12
)
"""
# Records a review decision. Only rows still in 'pending' are updated, so a
# second review of the same request returns no row.
_UPDATE_APPROVAL_STATUS = """
UPDATE operator_approvals
SET status = $2, reviewed_by = $3, review_note = $4, reviewed_at = $5, updated_at = NOW()
WHERE id = $1::uuid AND status = 'pending'
RETURNING id, status
"""
# Flips every pending request whose expires_at is at or before $1 to
# 'expired' and returns the affected ids and tickers.
_EXPIRE_STALE_APPROVALS = """
UPDATE operator_approvals
SET status = 'expired', updated_at = NOW()
WHERE status = 'pending' AND expires_at <= $1
RETURNING id, ticker
"""
# Lists all pending requests, oldest request first.
_FETCH_PENDING_APPROVALS = """
SELECT id, order_job, recommendation_id, ticker, side, quantity,
estimated_value, status, risk_evaluation_id, requested_by,
reviewed_by, review_note, expires_at, requested_at, reviewed_at
FROM operator_approvals
WHERE status = 'pending'
ORDER BY requested_at ASC
"""
# Loads one approval request by its uuid, whatever its status.
_FETCH_APPROVAL_BY_ID = """
SELECT id, order_job, recommendation_id, ticker, side, quantity,
estimated_value, status, risk_evaluation_id, requested_by,
reviewed_by, review_note, expires_at, requested_at, reviewed_at
FROM operator_approvals
WHERE id = $1::uuid
"""
def _row_to_request(row: Any) -> ApprovalRequest:
    """Build an ApprovalRequest from an operator_approvals row.

    ``order_job`` is decoded when it arrives as JSON text, ids are turned
    into strings (or None when empty), and numeric columns become floats.
    """
    raw_job = row["order_job"]
    order_job = json.loads(raw_job) if isinstance(raw_job, str) else raw_job

    rec_id = row["recommendation_id"]
    risk_eval_id = row.get("risk_evaluation_id")

    return ApprovalRequest(
        approval_id=str(row["id"]),
        order_job=order_job,
        recommendation_id=str(rec_id) if rec_id else None,
        ticker=row["ticker"],
        side=row["side"],
        quantity=float(row["quantity"]),
        estimated_value=float(row["estimated_value"]),
        risk_evaluation_id=str(risk_eval_id) if risk_eval_id else None,
        status=ApprovalStatus(row["status"]),
        requested_by=row["requested_by"],
        reviewed_by=row["reviewed_by"],
        review_note=row["review_note"],
        expires_at=row["expires_at"],
        requested_at=row["requested_at"],
        reviewed_at=row["reviewed_at"],
    )
async def create_approval_request(
    pool: asyncpg.Pool,
    request: ApprovalRequest,
) -> str:
    """Store a new approval request and return its approval ID.

    ``order_job`` is serialized to JSON, with values the encoder cannot
    handle natively converted through ``str``.
    """
    order_job_json = json.dumps(request.order_job, default=str)
    # Positional parameters in the order _INSERT_APPROVAL expects ($1..$12).
    params = (
        request.approval_id,
        order_job_json,
        request.recommendation_id,
        request.ticker,
        request.side,
        request.quantity,
        request.estimated_value,
        request.status.value,
        request.risk_evaluation_id,
        request.requested_by,
        request.expires_at,
        request.requested_at,
    )
    await pool.execute(_INSERT_APPROVAL, *params)
    return request.approval_id
# Review statement with an expiry guard: a request that is still 'pending'
# only because the expiry sweep has not run yet must not be approvable.
_REVIEW_PENDING_APPROVAL = """
UPDATE operator_approvals
SET status = $2, reviewed_by = $3, review_note = $4, reviewed_at = $5, updated_at = NOW()
WHERE id = $1::uuid
AND status = 'pending'
AND (expires_at IS NULL OR expires_at > $5)
RETURNING id, status
"""


async def review_approval(
    pool: asyncpg.Pool,
    approval_id: str,
    approved: bool,
    reviewed_by: str = "operator",
    review_note: str = "",
) -> ApprovalStatus | None:
    """Approve or reject a pending approval request.

    The update only applies to a request that is still pending and whose
    ``expires_at`` lies after the review time. Expired approvals are treated
    as rejections, so a stale request cannot be approved even if
    expire_stale_approvals has not marked it yet.

    Args:
        pool: asyncpg pool used to run the update.
        approval_id: UUID of the approval request.
        approved: True to approve, False to reject.
        reviewed_by: Identifier recorded as the reviewer.
        review_note: Free-form note stored with the review.

    Returns:
        The new status, or None if the approval was not found or was no
        longer pending (already expired/reviewed).
    """
    now = datetime.now(timezone.utc)
    new_status = ApprovalStatus.APPROVED if approved else ApprovalStatus.REJECTED
    row = await pool.fetchrow(
        _REVIEW_PENDING_APPROVAL,
        approval_id,
        new_status.value,
        reviewed_by,
        review_note,
        now,
    )
    if row:
        return ApprovalStatus(row["status"])
    return None
async def expire_stale_approvals(
    pool: asyncpg.Pool,
    now: datetime | None = None,
) -> list[dict[str, str]]:
    """Mark pending approvals whose expiry has passed as expired.

    ``now`` defaults to the current UTC time. Returns one
    ``{"id", "ticker"}`` dict per request that was expired by this call.
    """
    cutoff = now if now else datetime.now(timezone.utc)
    swept = await pool.fetch(_EXPIRE_STALE_APPROVALS, cutoff)

    expired: list[dict[str, str]] = []
    for record in swept:
        expired.append({"id": str(record["id"]), "ticker": record["ticker"]})
    return expired
async def get_pending_approvals(
    pool: asyncpg.Pool,
) -> list[ApprovalRequest]:
    """Return every pending approval request, oldest first."""
    records = await pool.fetch(_FETCH_PENDING_APPROVALS)
    pending: list[ApprovalRequest] = []
    for record in records:
        pending.append(_row_to_request(record))
    return pending
async def get_approval_by_id(
    pool: asyncpg.Pool,
    approval_id: str,
) -> ApprovalRequest | None:
    """Look up one approval request by ID; None when it does not exist."""
    record = await pool.fetchrow(_FETCH_APPROVAL_BY_ID, approval_id)
    return _row_to_request(record) if record else None
+616 -1
View File
@@ -1 +1,616 @@
"""Risk engine - enforces guardrails, position limits, and trade eligibility checks."""
"""Risk engine - portfolio and account risk configuration and enforcement.
Defines the configuration and state models used to enforce guardrails
on trade execution: max position size, sector exposure, daily loss limits,
news-shock lockouts, and operator approval rules.
Also implements the hard-block evaluation logic that decides whether a
proposed order is allowed before it reaches the broker adapter.
Requirements: 8.1, 8.2, 8.3, 8.4, 8.5
Design: Section 4.8 - Risk Engine
"""
from __future__ import annotations
import uuid
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
# ---------------------------------------------------------------------------
# Enums
# ---------------------------------------------------------------------------
class TradingMode(str, Enum):
    """Execution environment separation (Requirement 8.1)."""

    PAPER = "paper"
    LIVE = "live"
    # _check_trading_mode fails every order while the config is in this mode.
    DISABLED = "disabled"
class RiskCheckResult(str, Enum):
    """Outcome of a single risk check."""

    PASS = "pass"
    FAIL = "fail"
    # Used by _check_sector_exposure when the order carries no sector and the
    # check is skipped.
    WARN = "warn"
# ---------------------------------------------------------------------------
# Portfolio-level risk configuration (Requirement 8.2, 8.4)
# ---------------------------------------------------------------------------
class PositionLimits(BaseModel):
    """Per-position size constraints."""

    # Fraction in [0, 1]; compared against projected position value divided
    # by portfolio value.
    max_position_pct: float = Field(
        default=0.05, ge=0, le=1,
        description="Maximum portfolio percentage for a single position",
    )
    # Compared against existing position value plus the order's estimated value.
    max_position_value: float = Field(
        default=10_000.0, ge=0,
        description="Maximum dollar value for a single position",
    )
    # Compared against the order quantity alone.
    max_shares_per_order: float = Field(
        default=1000.0, ge=0,
        description="Maximum shares in a single order",
    )
class SectorExposureLimits(BaseModel):
    """Sector-level concentration limits."""

    # Fraction in [0, 1]; compared against projected sector value divided by
    # portfolio value in _check_sector_exposure.
    max_sector_pct: float = Field(
        default=0.25, ge=0, le=1,
        description="Maximum portfolio percentage exposed to one sector",
    )
    # NOTE(review): not read by _check_sector_exposure — confirm where this
    # limit is enforced.
    max_sectors: int = Field(
        default=10, ge=1,
        description="Maximum number of sectors with open positions",
    )
class DailyLossLimits(BaseModel):
    """Daily drawdown controls."""

    # Fraction in [0, 1]; compared against today's loss divided by portfolio value.
    max_daily_loss_pct: float = Field(
        default=0.02, ge=0, le=1,
        description="Maximum portfolio loss percentage in a single day before halting",
    )
    # Compared against the absolute value of a negative daily P&L.
    max_daily_loss_value: float = Field(
        default=1_000.0, ge=0,
        description="Maximum dollar loss in a single day before halting",
    )
    # The trade-count check passes only while the count is strictly below this.
    max_daily_trades: int = Field(
        default=20, ge=0,
        description="Maximum number of trades per day",
    )
class NewsShockLockout(BaseModel):
    """News-shock lockout configuration.

    When a symbol has a high-impact news event, trading is paused
    for a configurable cooldown period.
    """

    # Master switch; when False the lockout check passes without inspecting state.
    enabled: bool = True
    lockout_minutes: int = Field(
        default=60, ge=0,
        description="Minutes to lock out trading after a high-impact news event",
    )
    # Score in [0, 1].
    impact_threshold: float = Field(
        default=0.80, ge=0, le=1,
        description="Minimum impact_score from document intelligence to trigger lockout",
    )
    # default_factory gives each instance its own list.
    catalyst_types: list[str] = Field(
        default_factory=lambda: ["earnings", "legal", "m_and_a"],
        description="Catalyst types that trigger lockout when above threshold",
    )
class OperatorApproval(BaseModel):
    """Operator approval workflow for live trading (Requirement 8.2)."""

    require_approval_for_live: bool = Field(
        default=True,
        description="Whether live orders require operator approval",
    )
    auto_approve_paper: bool = Field(
        default=True,
        description="Whether paper orders are auto-approved",
    )
    # Must be at least 1 minute (ge=1).
    approval_timeout_minutes: int = Field(
        default=30, ge=1,
        description="Minutes before a pending approval expires",
    )
class SymbolCooldown(BaseModel):
    """Per-symbol cooldown after a trade."""

    # 0 is allowed (ge=0).
    cooldown_minutes: int = Field(
        default=15, ge=0,
        description="Minutes to wait before trading the same symbol again",
    )
    # Must be at least 1 (ge=1).
    max_open_positions_per_symbol: int = Field(
        default=1, ge=1,
        description="Maximum concurrent open positions for a single symbol",
    )
class PortfolioRiskConfig(BaseModel):
    """Complete portfolio-level risk configuration.

    This is the top-level config that governs all risk checks.
    Persisted in PostgreSQL and loaded at engine startup.
    """

    # A fresh UUID string is generated per instance unless one is supplied.
    config_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    name: str = "default"
    # Defaults to paper trading.
    trading_mode: TradingMode = TradingMode.PAPER
    # Each sub-config defaults to a new instance of its model.
    position_limits: PositionLimits = Field(default_factory=PositionLimits)
    sector_exposure: SectorExposureLimits = Field(default_factory=SectorExposureLimits)
    daily_loss: DailyLossLimits = Field(default_factory=DailyLossLimits)
    news_shock: NewsShockLockout = Field(default_factory=NewsShockLockout)
    operator_approval: OperatorApproval = Field(default_factory=OperatorApproval)
    symbol_cooldown: SymbolCooldown = Field(default_factory=SymbolCooldown)
    active: bool = True
    # Both timestamps default to the current UTC time at construction.
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))

    def to_db_json(self) -> dict[str, Any]:
        """Serialize the full config to a JSON-compatible dict for DB storage."""
        return self.model_dump(mode="json")

    @classmethod
    def from_db_json(cls, data: dict[str, Any]) -> PortfolioRiskConfig:
        """Deserialize from a DB JSON column."""
        return cls.model_validate(data)
# ---------------------------------------------------------------------------
# Account risk state (runtime snapshot)
# ---------------------------------------------------------------------------
class AccountRiskState(BaseModel):
    """Runtime snapshot of an account's risk posture.

    Computed from broker positions, today's trades, and current P&L.
    Used by risk checks to evaluate whether a new order is allowed.
    """

    account_id: str = ""
    # Denominator for the percentage checks; they special-case a value <= 0.
    portfolio_value: float = 0.0
    cash: float = 0.0
    buying_power: float = 0.0
    # Only a negative value counts as a loss in the daily-loss checks.
    daily_pnl: float = 0.0
    daily_trade_count: int = 0
    open_position_count: int = 0
    positions_by_symbol: dict[str, float] = Field(
        default_factory=dict,
        description="Map of ticker → current market value",
    )
    positions_by_sector: dict[str, float] = Field(
        default_factory=dict,
        description="Map of sector → total market value",
    )
    last_trade_times: dict[str, datetime] = Field(
        default_factory=dict,
        description="Map of ticker → last trade timestamp for cooldown checks",
    )
    locked_symbols: dict[str, datetime] = Field(
        default_factory=dict,
        description="Map of ticker → lockout expiry for news-shock lockouts",
    )
    # Defaults to the current UTC time at construction.
    snapshot_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
# ---------------------------------------------------------------------------
# Risk check output (Requirement 8.3 - full decision trace)
# ---------------------------------------------------------------------------
class RiskCheckDetail(BaseModel):
    """Result of a single risk check."""

    # Identifier of the check, e.g. "max_position_value" or "daily_loss_pct".
    check_name: str
    result: RiskCheckResult
    # Human-readable explanation of the outcome.
    message: str = ""
    # Configured limit and observed value; left as None by checks that have
    # no numeric comparison (e.g. trading_mode).
    threshold: float | None = None
    actual: float | None = None
class RiskEvaluation(BaseModel):
    """Complete risk evaluation for a proposed order.

    Captures every check performed so the full decision trace
    is reproducible (Requirement 8.3).
    """

    # A fresh UUID string is generated per evaluation unless one is supplied.
    evaluation_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    recommendation_id: str | None = None
    ticker: str = ""
    # Defaults are the restrictive ones: not eligible, trading disabled.
    eligible: bool = False
    allowed_mode: TradingMode = TradingMode.DISABLED
    checks: list[RiskCheckDetail] = Field(default_factory=list)
    rejection_reasons: list[str] = Field(default_factory=list)
    # Optional copies of the inputs used for this evaluation.
    config_snapshot: PortfolioRiskConfig | None = None
    state_snapshot: AccountRiskState | None = None
    evaluated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))

    @property
    def passed(self) -> bool:
        # True only when marked eligible AND no rejection reason was recorded.
        return self.eligible and len(self.rejection_reasons) == 0
# ---------------------------------------------------------------------------
# Default configuration
# ---------------------------------------------------------------------------
# Built once at import time, so its config_id and timestamps are fixed for
# the life of the process.
DEFAULT_RISK_CONFIG = PortfolioRiskConfig()
# ---------------------------------------------------------------------------
# Proposed order (input to risk evaluation)
# ---------------------------------------------------------------------------
class ProposedOrder(BaseModel):
    """A proposed order to be evaluated by the risk engine before submission.

    This is the input to evaluate_order(). It carries enough context
    for every risk check to run without external lookups.
    """

    recommendation_id: str | None = None
    # The only required field.
    ticker: str
    # An empty sector makes _check_sector_exposure return WARN and skip the check.
    sector: str = ""
    action: str = "buy"  # buy | sell
    # Checked against max_shares_per_order.
    quantity: float = 0.0
    # Added to existing position and sector values in the exposure checks.
    estimated_value: float = 0.0
    confidence: float = 0.0
# ---------------------------------------------------------------------------
# Individual risk checks (Requirement 8.4)
# ---------------------------------------------------------------------------
def _check_trading_mode(
    config: PortfolioRiskConfig,
) -> RiskCheckDetail:
    """Fail the check when the configured trading mode is DISABLED.

    Any other mode passes, and the message reports which mode is active.
    """
    if config.trading_mode == TradingMode.DISABLED:
        outcome = RiskCheckResult.FAIL
        note = "Trading is disabled"
    else:
        outcome = RiskCheckResult.PASS
        note = f"Trading mode: {config.trading_mode.value}"

    return RiskCheckDetail(check_name="trading_mode", result=outcome, message=note)
def _check_max_position_size(
    order: ProposedOrder,
    config: PortfolioRiskConfig,
    state: AccountRiskState,
) -> list[RiskCheckDetail]:
    """Run the three per-position limits and return one detail per limit.

    In order: projected position value, projected share of the portfolio,
    and order quantity. The projected value is the symbol's current value
    plus the order's estimated value, regardless of ``order.action``.
    """
    limits = config.position_limits

    def _detail(
        name: str, ok: bool, message: str, threshold: float, actual: float,
    ) -> RiskCheckDetail:
        # Map a boolean verdict onto PASS/FAIL and attach the numbers.
        return RiskCheckDetail(
            check_name=name,
            result=RiskCheckResult.PASS if ok else RiskCheckResult.FAIL,
            message=message,
            threshold=threshold,
            actual=actual,
        )

    def _wording(ok: bool) -> str:
        return "within" if ok else "exceeds"

    # Position value after the order would fill.
    projected_value = (
        state.positions_by_symbol.get(order.ticker, 0.0) + order.estimated_value
    )
    value_ok = projected_value <= limits.max_position_value

    # Share of portfolio; with no portfolio value any exposure counts as 100%.
    if state.portfolio_value > 0:
        projected_pct = projected_value / state.portfolio_value
    elif projected_value > 0:
        projected_pct = 1.0
    else:
        projected_pct = 0.0
    pct_ok = projected_pct <= limits.max_position_pct

    shares_ok = order.quantity <= limits.max_shares_per_order

    return [
        _detail(
            "max_position_value",
            value_ok,
            f"Position value {projected_value:.2f} {_wording(value_ok)} "
            f"limit {limits.max_position_value:.2f}",
            limits.max_position_value,
            projected_value,
        ),
        _detail(
            "max_position_pct",
            pct_ok,
            f"Position {projected_pct:.4f} of portfolio {_wording(pct_ok)} "
            f"limit {limits.max_position_pct:.4f}",
            limits.max_position_pct,
            projected_pct,
        ),
        _detail(
            "max_shares_per_order",
            shares_ok,
            f"Order quantity {order.quantity:.0f} {_wording(shares_ok)} "
            f"limit {limits.max_shares_per_order:.0f}",
            limits.max_shares_per_order,
            order.quantity,
        ),
    ]
def _check_sector_exposure(
    order: ProposedOrder,
    config: PortfolioRiskConfig,
    state: AccountRiskState,
) -> RiskCheckDetail:
    """Check the order's sector against the per-sector concentration cap.

    Returns WARN without evaluating anything when the order has no sector.
    Otherwise the sector's current value plus the order's estimated value is
    expressed as a fraction of the portfolio and compared to
    ``max_sector_pct``.
    """
    sector = order.sector
    if not sector:
        return RiskCheckDetail(
            check_name="sector_exposure",
            result=RiskCheckResult.WARN,
            message="No sector provided on order; skipping sector check",
        )

    cap = config.sector_exposure.max_sector_pct
    projected_value = state.positions_by_sector.get(sector, 0.0) + order.estimated_value

    # With no portfolio value any exposure counts as 100%.
    if state.portfolio_value > 0:
        exposure = projected_value / state.portfolio_value
    elif projected_value > 0:
        exposure = 1.0
    else:
        exposure = 0.0

    within_cap = exposure <= cap
    if within_cap:
        verdict, wording = RiskCheckResult.PASS, "within"
    else:
        verdict, wording = RiskCheckResult.FAIL, "exceeds"

    return RiskCheckDetail(
        check_name="sector_exposure",
        result=verdict,
        message=f"Sector '{sector}' exposure {exposure:.4f} {wording} limit {cap:.4f}",
        threshold=cap,
        actual=exposure,
    )
def _check_daily_loss(
    config: PortfolioRiskConfig,
    state: AccountRiskState,
) -> list[RiskCheckDetail]:
    """Evaluate the daily loss and daily trade-count limits.

    Returns three details, in this order: loss as a fraction of portfolio
    value, loss as an absolute amount, and the number of trades made today.
    """
    limits = config.daily_loss

    # Only losses count: a non-negative P&L is treated as a loss of zero.
    loss_value = abs(min(state.daily_pnl, 0.0))
    if state.portfolio_value > 0:
        loss_fraction = loss_value / state.portfolio_value
    else:
        loss_fraction = 0.0
    trades_today = state.daily_trade_count

    fraction_ok = loss_fraction <= limits.max_daily_loss_pct
    value_ok = loss_value <= limits.max_daily_loss_value
    # Strict comparison: reaching the trade limit already fails the check.
    trades_ok = trades_today < limits.max_daily_trades

    return [
        RiskCheckDetail(
            check_name="daily_loss_pct",
            result=RiskCheckResult.PASS if fraction_ok else RiskCheckResult.FAIL,
            message=(
                f"Daily loss {loss_fraction:.4f} "
                f"{'within' if fraction_ok else 'exceeds'} "
                f"limit {limits.max_daily_loss_pct:.4f}"
            ),
            threshold=limits.max_daily_loss_pct,
            actual=loss_fraction,
        ),
        RiskCheckDetail(
            check_name="daily_loss_value",
            result=RiskCheckResult.PASS if value_ok else RiskCheckResult.FAIL,
            message=(
                f"Daily loss ${loss_value:.2f} "
                f"{'within' if value_ok else 'exceeds'} "
                f"limit ${limits.max_daily_loss_value:.2f}"
            ),
            threshold=limits.max_daily_loss_value,
            actual=loss_value,
        ),
        RiskCheckDetail(
            check_name="daily_trade_count",
            result=RiskCheckResult.PASS if trades_ok else RiskCheckResult.FAIL,
            message=(
                f"Daily trades {trades_today} "
                f"{'within' if trades_ok else 'at/exceeds'} "
                f"limit {limits.max_daily_trades}"
            ),
            threshold=float(limits.max_daily_trades),
            actual=float(trades_today),
        ),
    ]
def _check_news_shock_lockout(
    order: ProposedOrder,
    config: PortfolioRiskConfig,
    state: AccountRiskState,
    now: datetime | None = None,
) -> RiskCheckDetail:
    """Fail the order while its ticker is under a news-shock lockout.

    ``state.locked_symbols`` maps a ticker to the time its lockout expires.
    The check passes when the feature is disabled, when the ticker has no
    entry, or when the expiry is not after ``now`` (current UTC time when
    ``now`` is omitted).
    """
    check_name = "news_shock_lockout"
    if not config.news_shock.enabled:
        return RiskCheckDetail(
            check_name=check_name,
            result=RiskCheckResult.PASS,
            message="News-shock lockout is disabled",
        )

    current = now or datetime.now(timezone.utc)
    expiry = state.locked_symbols.get(order.ticker)
    if expiry is None or not (current < expiry):
        return RiskCheckDetail(
            check_name=check_name,
            result=RiskCheckResult.PASS,
            message=f"No active lockout for {order.ticker}",
        )

    time_left = expiry - current
    return RiskCheckDetail(
        check_name=check_name,
        result=RiskCheckResult.FAIL,
        message=(
            f"Symbol {order.ticker} locked out until "
            f"{expiry.isoformat()} "
            f"({time_left.total_seconds():.0f}s remaining)"
        ),
    )
def _check_symbol_cooldown(
    order: ProposedOrder,
    config: PortfolioRiskConfig,
    state: AccountRiskState,
    now: datetime | None = None,
) -> RiskCheckDetail:
    """Fail the order if its ticker traded too recently.

    The cooldown ends ``config.symbol_cooldown.cooldown_minutes`` after the
    ticker's entry in ``state.last_trade_times``. Tickers with no recorded
    trade always pass. ``now`` defaults to the current UTC time.
    """
    cooldown_cfg = config.symbol_cooldown
    current = now or datetime.now(timezone.utc)
    previous_trade = state.last_trade_times.get(order.ticker)

    if previous_trade is not None:
        window_end = previous_trade + timedelta(minutes=cooldown_cfg.cooldown_minutes)
        if current < window_end:
            time_left = window_end - current
            return RiskCheckDetail(
                check_name="symbol_cooldown",
                result=RiskCheckResult.FAIL,
                message=(
                    f"Symbol {order.ticker} in cooldown until "
                    f"{window_end.isoformat()} "
                    f"({time_left.total_seconds():.0f}s remaining)"
                ),
            )

    return RiskCheckDetail(
        check_name="symbol_cooldown",
        result=RiskCheckResult.PASS,
        message=f"No active cooldown for {order.ticker}",
    )
# ---------------------------------------------------------------------------
# Main evaluation entry point (Requirements 8.3, 8.4, 8.5)
# ---------------------------------------------------------------------------
def evaluate_order(
    order: ProposedOrder,
    config: PortfolioRiskConfig = DEFAULT_RISK_CONFIG,
    state: AccountRiskState | None = None,
    now: datetime | None = None,
) -> RiskEvaluation:
    """Run every risk control against a proposed order and report the outcome.

    All checks are always executed (no short-circuit) so the returned
    RiskEvaluation carries the complete decision trace (Requirement 8.3).
    The order is eligible only if no check produced a FAIL result
    (Requirement 8.4); WARN results do not reject.

    When ``state`` is omitted, a default-constructed AccountRiskState is
    evaluated instead. NOTE(review): Requirement 8.5 (fail closed) then
    depends on that default state tripping a check - confirm.

    ``now`` defaults to the current UTC time and is shared by the
    time-based checks and the ``evaluated_at`` stamp.
    """
    state = state or AccountRiskState()
    now = now or datetime.now(timezone.utc)

    # Fixed evaluation order: trading mode, position size, sector exposure,
    # daily loss, news-shock lockout, symbol cooldown.
    all_checks: list[RiskCheckDetail] = [_check_trading_mode(config)]
    all_checks.extend(_check_max_position_size(order, config, state))
    all_checks.append(_check_sector_exposure(order, config, state))
    all_checks.extend(_check_daily_loss(config, state))
    all_checks.append(_check_news_shock_lockout(order, config, state, now))
    all_checks.append(_check_symbol_cooldown(order, config, state, now))

    # Rejection reasons are the messages of the failed checks, in check order.
    rejection_reasons: list[str] = [
        detail.message
        for detail in all_checks
        if detail.result == RiskCheckResult.FAIL
    ]

    eligible = not rejection_reasons
    if eligible:
        allowed_mode = config.trading_mode
    else:
        allowed_mode = TradingMode.DISABLED

    return RiskEvaluation(
        recommendation_id=order.recommendation_id,
        ticker=order.ticker,
        eligible=eligible,
        allowed_mode=allowed_mode,
        checks=all_checks,
        rejection_reasons=rejection_reasons,
        config_snapshot=config,
        state_snapshot=state,
        evaluated_at=now,
    )
+238 -43
View File
@@ -1,14 +1,23 @@
"""Scheduler - triggers ingestion cycles for tracked symbols and sources."""
"""Scheduler - triggers ingestion cycles for tracked symbols and sources.
Polls the symbol registry for active companies and their configured sources,
respects per-source polling cadences and backoff windows, coordinates rate
limits across source types, and enqueues ingestion jobs for downstream workers.
Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
"""
import asyncio
import json
import logging
from datetime import datetime, timedelta
from datetime import datetime
from typing import Any, Optional
import asyncpg
import redis.asyncio as aioredis
from services.shared.config import load_config
from services.shared.db import get_pg_pool, get_redis
from services.shared.logging import setup_logging
from services.shared.redis_keys import (
QUEUE_INGESTION,
lock_key,
@@ -16,11 +25,11 @@ from services.shared.redis_keys import (
rate_limit_key,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("scheduler")
# Polling cadences by source class (seconds)
CADENCES = {
# Default polling cadences by source class (seconds).
# Individual sources can override via config.polling_interval_seconds.
DEFAULT_CADENCES: dict[str, int] = {
"market_api": 60,
"news_api": 300,
"filings_api": 3600,
@@ -28,81 +37,267 @@ CADENCES = {
"broker": 30,
}
# Default rate limits per source type (requests per minute)
# check_rate_limit falls back to 30 for source types that are not listed here.
DEFAULT_RATE_LIMITS: dict[str, int] = {
    "market_api": 30,
    "news_api": 20,
    "filings_api": 10,
    "web_scrape": 10,
    "broker": 60,
}
# How long to wait before retrying a failed source (seconds)
# compute_backoff doubles this base per retry and caps the result at MAX_BACKOFF.
DEFAULT_BACKOFF_BASE: int = 60
MAX_BACKOFF: int = 3600
# is_source_due stops rescheduling a failed source once its retry count
# reaches this value.
MAX_RETRY_COUNT: int = 10
# Main loop interval (seconds)
SCHEDULER_TICK: int = 15
def get_cadence_for_source(source_type: str, config: Optional[dict[str, Any]]) -> int:
    """Return the polling interval, in seconds, for a source.

    A ``polling_interval_seconds`` entry in the source config takes
    precedence and is floored at 10 seconds. If the entry is absent or
    cannot be converted to an int, the default cadence for ``source_type``
    is used (600 seconds for unknown source types).
    """
    if not config or "polling_interval_seconds" not in config:
        return DEFAULT_CADENCES.get(source_type, 600)

    try:
        requested = int(config["polling_interval_seconds"])
    except (ValueError, TypeError):
        # Unusable override: fall back to the per-type default.
        return DEFAULT_CADENCES.get(source_type, 600)
    return max(10, requested)
def compute_backoff(retry_count: int) -> int:
    """Return the number of seconds to wait before retrying a failed source.

    The delay is ``DEFAULT_BACKOFF_BASE * 2 ** retry_count``, capped at
    ``MAX_BACKOFF``. The exponent is clamped to the range 0..8, so a
    negative ``retry_count`` yields the base delay.
    """
    # Without the lower clamp a negative count makes ``2 ** n`` a float
    # below 1, returning a sub-base float from a function annotated ``-> int``.
    exponent = min(max(retry_count, 0), 8)
    return min(DEFAULT_BACKOFF_BASE * (2 ** exponent), MAX_BACKOFF)
def _as_naive_utc(value: datetime) -> datetime:
    """Return ``value`` as a naive datetime expressed in UTC.

    Naive inputs are returned unchanged; they are assumed to be UTC already.
    """
    offset = value.utcoffset()
    if offset is None:
        return value
    # Shift the wall-clock time by the offset so a non-UTC zone is converted,
    # not merely stripped of its tzinfo.
    return value.replace(tzinfo=None) - offset


def is_source_due(
    source_type: str,
    source_config: Optional[dict[str, Any]],
    last_completed_at: Optional[datetime],
    last_status: Optional[str],
    retry_count: int,
    next_retry_at: Optional[datetime],
    now: datetime,
) -> bool:
    """Determine whether a source is due for its next polling cycle.

    Rules, in order:
    - A source with no previous run (no timestamp and no status) is due.
    - After a failed run the source is not due once ``retry_count`` has
      reached ``MAX_RETRY_COUNT``, nor while ``next_retry_at`` is still in
      the future; otherwise it is retried.
    - A source whose last run is still "running" is never due.
    - Otherwise it is due when at least one cadence interval has elapsed
      since ``last_completed_at`` (or immediately if that is missing).

    ``now``, ``next_retry_at`` and ``last_completed_at`` may each be naive
    (taken as UTC) or timezone-aware; all are normalised to UTC before
    being compared.
    """
    # Never run before: always due.
    if last_completed_at is None and last_status is None:
        return True

    now_utc = _as_naive_utc(now)

    # If the last run failed, respect the backoff window.
    if last_status == "failed":
        if retry_count >= MAX_RETRY_COUNT:
            return False
        if next_retry_at and now_utc < _as_naive_utc(next_retry_at):
            return False
        # Backoff elapsed or no next_retry_at set: allow the retry.
        return True

    # If currently running, don't double-schedule.
    if last_status == "running":
        return False

    # Normal cadence check.
    if last_completed_at is None:
        return True
    cadence = get_cadence_for_source(source_type, source_config)
    elapsed = (now_utc - _as_naive_utc(last_completed_at)).total_seconds()
    return elapsed >= cadence
def build_job_payload(
    source: Any,
    aliases: list[str],
    now: datetime,
) -> dict[str, Any]:
    """Build the JSON-serialisable ingestion job payload for a source.

    ``source`` is any mapping-like row exposing the keys read below.
    Identifiers are stringified, a missing config becomes ``{}``, and a
    missing (``None``) credibility score defaults to 0.5. ``now`` is
    recorded as the ISO-8601 ``scheduled_at`` value.
    """
    score = source["credibility_score"]
    return {
        "source_id": str(source["source_id"]),
        "company_id": str(source["company_id"]),
        "ticker": source["ticker"],
        "legal_name": source["legal_name"],
        "aliases": aliases,
        "source_type": source["source_type"],
        "source_name": source["source_name"],
        "config": dict(source["config"]) if source["config"] else {},
        # Test for None explicitly: a stored score of 0 is a real value and
        # must not be replaced by the default.
        "credibility_score": 0.5 if score is None else float(score),
        "scheduled_at": now.isoformat(),
    }
async def acquire_lock(rds: aioredis.Redis, name: str, ttl: int = 60) -> bool:
    """Try to take the named lock. Returns True if acquired.

    The lock is a key set only if it does not already exist (``nx``) and
    expiring after ``ttl`` seconds, so a holder that never releases it
    cannot block others for longer than ``ttl``.
    """
    # SET ... NX yields None rather than False when the key already exists;
    # coerce so the result matches the declared ``bool`` return type.
    return bool(await rds.set(lock_key(name), "1", nx=True, ex=ttl))
async def release_lock(rds: aioredis.Redis, name: str) -> None:
    """Release a distributed lock.

    The lock key is deleted unconditionally; there is no check that the
    caller is the one who acquired it.
    """
    await rds.delete(lock_key(name))
async def check_rate_limit(
    rds: aioredis.Redis,
    source_type: str,
    now: datetime,
    max_per_minute: Optional[int] = None,
) -> bool:
    """Check whether the source type is within its rate limit window.

    Requests are counted per source type in a one-minute window derived
    from ``now``. Every call increments the counter, including calls that
    end up rate-limited. The limit is ``max_per_minute`` when given (and
    non-zero), else the default for the source type, else 30.

    Returns True if the request is allowed, False if rate-limited.
    """
    limit = max_per_minute or DEFAULT_RATE_LIMITS.get(source_type, 30)
    window = now.strftime("%Y%m%d%H%M")
    key = rate_limit_key(source_type, window)
    count = await rds.incr(key)
    if count == 1:
        # First hit in this window: give the counter a TTL so it is cleaned up.
        await rds.expire(key, 120)
    return count <= limit
async def fetch_active_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
    """Fetch all active sources joined with their active companies.

    Rows are ordered by source type, then ticker. Each row carries the
    source fields plus the owning company's ticker and legal name.
    """
    return await pool.fetch(
        """SELECT s.id AS source_id,
                  s.company_id,
                  s.source_type,
                  s.source_name,
                  s.config,
                  s.credibility_score,
                  c.ticker,
                  c.legal_name
           FROM sources s
           JOIN companies c ON s.company_id = c.id
           WHERE s.active = TRUE AND c.active = TRUE
           ORDER BY s.source_type, c.ticker"""
    )
async def fetch_aliases_for_company(pool: asyncpg.Pool, company_id: str) -> list[str]:
    """Return every alias recorded for the given company."""
    alias_query = "SELECT alias FROM company_aliases WHERE company_id = $1"
    alias_rows = await pool.fetch(alias_query, company_id)
    return [record["alias"] for record in alias_rows]
async def fetch_last_run(
    pool: asyncpg.Pool, source_id: str
) -> Optional[asyncpg.Record]:
    """Return the latest ingestion run for a source, or None if there is none.

    "Latest" means the run with the greatest ``started_at``.
    """
    latest_run_query = """SELECT status, started_at, completed_at, retry_count, next_retry_at
           FROM ingestion_runs
           WHERE source_id = $1
           ORDER BY started_at DESC
           LIMIT 1"""
    return await pool.fetchrow(latest_run_query, source_id)
async def schedule_cycle(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
    """One scheduling pass: find due sources and enqueue ingestion jobs.

    For each active source the latest ingestion run decides whether the
    source is due; due sources that pass the per-source-type rate limit
    get one job pushed onto the ingestion queue.

    Returns the number of jobs enqueued.
    """
    # One naive-UTC timestamp for the whole pass so every source is judged
    # against the same instant and shares one rate-limit window.
    now = datetime.utcnow()
    sources = await fetch_active_sources(pool)
    enqueued = 0
    skipped_rate_limit = 0
    skipped_not_due = 0
    # A company usually owns several sources; look its aliases up once per pass.
    aliases_by_company: dict[Any, list[str]] = {}

    for src in sources:
        source_id = src["source_id"]
        source_type = src["source_type"]
        source_config = dict(src["config"]) if src["config"] else None

        # Check last run status and timing
        last_run = await fetch_last_run(pool, source_id)
        last_completed_at = None
        last_status = None
        retry_count = 0
        next_retry_at = None
        if last_run:
            last_status = last_run["status"]
            # A run without completed_at is timed from its start.
            last_completed_at = last_run["completed_at"] or last_run["started_at"]
            retry_count = last_run["retry_count"] or 0
            next_retry_at = last_run["next_retry_at"]

        if not is_source_due(
            source_type=source_type,
            source_config=source_config,
            last_completed_at=last_completed_at,
            last_status=last_status,
            retry_count=retry_count,
            next_retry_at=next_retry_at,
            now=now,
        ):
            skipped_not_due += 1
            continue

        # Rate limit check
        if not await check_rate_limit(rds, source_type, now):
            logger.warning(
                "Rate limit hit for %s, skipping %s/%s",
                source_type, src["ticker"], src["source_name"],
            )
            skipped_rate_limit += 1
            continue

        # Fetch company aliases for downstream entity matching
        company_id = src["company_id"]
        if company_id not in aliases_by_company:
            aliases_by_company[company_id] = await fetch_aliases_for_company(pool, company_id)
        aliases = aliases_by_company[company_id]

        job = build_job_payload(src, aliases, now)
        await rds.rpush(queue_key(QUEUE_INGESTION), json.dumps(job))  # type: ignore[misc]
        enqueued += 1
        logger.debug(
            "Enqueued %s job for %s (%s)",
            source_type, src["ticker"], src["source_name"],
        )

    logger.info(
        "Cycle complete: enqueued=%d skipped_not_due=%d skipped_rate_limit=%d total_sources=%d",
        enqueued, skipped_not_due, skipped_rate_limit, len(sources),
    )
    return enqueued
async def main() -> None:
    """Run the scheduler loop until cancelled or interrupted.

    Each tick tries to take the ``scheduler_cycle`` lock and, if it
    succeeds, runs one scheduling pass. Errors inside a tick are logged
    and the loop continues. The PostgreSQL pool and Redis client are
    closed when the loop exits.
    """
    config = load_config()
    setup_logging("scheduler", level=config.log_level, json_output=config.json_logs)
    pool = await get_pg_pool(config)
    rds = get_redis(config)
    logger.info("Scheduler started (tick=%ds)", SCHEDULER_TICK)
    try:
        while True:
            try:
                if await acquire_lock(rds, "scheduler_cycle", ttl=30):
                    try:
                        await schedule_cycle(pool, rds)
                    finally:
                        # Release even when the pass raised, so the next tick
                        # does not have to wait for the TTL to expire.
                        await release_lock(rds, "scheduler_cycle")
            except Exception:
                logger.exception("Scheduler cycle error")
            await asyncio.sleep(SCHEDULER_TICK)
    finally:
        await pool.close()
        await rds.close()
+342
View File
@@ -0,0 +1,342 @@
"""Operational alerting for Stonks Oracle pipeline health.
Evaluates alert rules against PostgreSQL operational state and emits
structured log events and Prometheus metrics when thresholds are breached.
Alert rules:
- source_failures: sustained source retrieval failures per source
- schema_failure_spike: extraction validation failure rate exceeds threshold
- analytical_lag: lake publication has not completed within threshold
- broker_issues: consecutive broker submission errors
Requirements: 12.3
Design: Section 12 (Observability and Operations)
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any
import asyncpg
from services.shared.config import AlertingConfig
from services.shared.metrics import (
ALERT_ACTIVE,
ALERT_CHECK_DURATION,
ALERTS_FIRED,
ALERTS_RESOLVED,
)
logger = logging.getLogger("alerting")
@dataclass
class Alert:
    """A single alert instance.

    Produced by the check_* rule functions and tracked by AlertState.
    """

    # Name of the rule that produced the alert, e.g. "source_failures".
    rule: str
    severity: str  # "warning" | "critical"
    # One-line human-readable description, used in the firing log message.
    summary: str
    # Rule-specific payload. Its optional "key" entry is combined with
    # ``rule`` by AlertState to tell apart alerts of the same rule.
    details: dict[str, Any] = field(default_factory=dict)
    # Creation time; defaults to the current timezone-aware UTC time.
    fired_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
@dataclass
class AlertState:
    """Remembers which alerts are firing so transitions can be detected.

    Alerts are stored in ``active`` under the string ``"<rule>:<key>"``,
    where ``<key>`` is the alert's ``details["key"]`` (empty if absent).
    """

    active: dict[str, Alert] = field(default_factory=dict)

    @staticmethod
    def _slot(rule: str, key: str = "") -> str:
        """Build the ``active`` dictionary key for a rule/key pair."""
        return f"{rule}:{key}"

    def fire(self, alert: Alert) -> bool:
        """Store ``alert`` as active. Returns True if it was not active before."""
        slot = self._slot(alert.rule, alert.details.get("key", ""))
        was_active = slot in self.active
        # Always overwrite so the stored alert reflects the latest evaluation.
        self.active[slot] = alert
        return not was_active

    def resolve(self, rule: str, key: str = "") -> bool:
        """Drop an active alert. Returns True if there was one to drop."""
        slot = self._slot(rule, key)
        if slot not in self.active:
            return False
        del self.active[slot]
        return True

    def is_firing(self, rule: str, key: str = "") -> bool:
        """Return True if the rule/key pair is currently active."""
        return self._slot(rule, key) in self.active
async def check_source_failures(
    pool: asyncpg.Pool,
    config: AlertingConfig,
) -> list[Alert]:
    """Find sources whose most recent runs have all failed.

    Only runs started within ``config.source_failure_window_hours`` are
    considered. A source is reported when it has at least
    ``config.source_failure_threshold`` runs in that window and the newest
    ``threshold`` of them all have status 'failed'. One warning alert is
    returned per such source, keyed by the source id.
    """
    streak_query = """WITH recent_runs AS (
               SELECT source_id, status,
                      ROW_NUMBER() OVER (PARTITION BY source_id ORDER BY started_at DESC) AS rn
               FROM ingestion_runs
               WHERE started_at >= NOW() - INTERVAL '1 hour' * $1
           ),
           failure_streaks AS (
               SELECT source_id,
                      COUNT(*) FILTER (WHERE status = 'failed') AS consecutive_failures,
                      COUNT(*) AS total_runs
               FROM recent_runs
               WHERE rn <= $2
               GROUP BY source_id
               HAVING COUNT(*) FILTER (WHERE status = 'failed') = COUNT(*)
                  AND COUNT(*) >= $2
           )
           SELECT fs.source_id, fs.consecutive_failures,
                  s.source_type, s.source_name, c.ticker
           FROM failure_streaks fs
           JOIN sources s ON s.id = fs.source_id
           JOIN companies c ON c.id = s.company_id"""
    failing_sources = await pool.fetch(
        streak_query,
        config.source_failure_window_hours,
        config.source_failure_threshold,
    )

    return [
        Alert(
            rule="source_failures",
            severity="warning",
            summary=(
                f"Source {record['source_name']} ({record['source_type']}) for "
                f"{record['ticker']} has {record['consecutive_failures']} consecutive failures"
            ),
            details={
                "key": str(record["source_id"]),
                "source_id": str(record["source_id"]),
                "source_type": record["source_type"],
                "source_name": record["source_name"],
                "ticker": record["ticker"],
                "consecutive_failures": record["consecutive_failures"],
            },
        )
        for record in failing_sources
    ]
async def check_schema_failure_spike(
    pool: asyncpg.Pool,
    config: AlertingConfig,
) -> list[Alert]:
    """Alert when too large a share of recent extractions failed.

    Counts rows in model_performance_metrics recorded within
    ``config.schema_failure_window_hours`` and compares the share with
    ``success`` false against ``config.schema_failure_rate_threshold``.
    Returns at most one alert (key "global"); it is critical when at least
    half of the rows failed and a warning otherwise. No rows in the window
    means no alert.
    """
    stats = await pool.fetchrow(
        """SELECT
               COUNT(*) AS total,
               COUNT(*) FILTER (WHERE NOT success) AS failed
           FROM model_performance_metrics
           WHERE recorded_at >= NOW() - INTERVAL '1 hour' * $1""",
        config.schema_failure_window_hours,
    )
    if not stats or stats["total"] == 0:
        return []

    total = stats["total"]
    failed = stats["failed"]
    failure_rate = failed / total
    if not (failure_rate >= config.schema_failure_rate_threshold):
        return []

    if failure_rate >= 0.5:
        severity = "critical"
    else:
        severity = "warning"

    spike = Alert(
        rule="schema_failure_spike",
        severity=severity,
        summary=(
            f"Extraction schema failure rate is {failure_rate:.1%} "
            f"({failed}/{total}) in the last {config.schema_failure_window_hours}h"
        ),
        details={
            "key": "global",
            "total_extractions": total,
            "failed_extractions": failed,
            "failure_rate": round(failure_rate, 4),
            "threshold": config.schema_failure_rate_threshold,
            "window_hours": config.schema_failure_window_hours,
        },
    )
    return [spike]
async def check_analytical_lag(
    pool: asyncpg.Pool,
    config: AlertingConfig,
) -> list[Alert]:
    """Alert for lake tables whose last successful publish is too old.

    Groups successful 'lake_publish' audit events by table name and keeps
    the tables whose newest event is older than
    ``config.lake_lag_threshold_minutes``. One warning alert is returned
    per stale table, keyed by the table name. Tables with no successful
    publish event at all do not appear in the result.
    """
    # NOTE(review): this query reads a ``details`` JSON column, while the
    # audit writer added in the same commit inserts into ``data`` - confirm
    # that audit_events really has a ``details`` column.
    stale_tables = await pool.fetch(
        """SELECT
               details->>'table_name' AS table_name,
               MAX(created_at) AS last_publish
           FROM audit_events
           WHERE event_type = 'lake_publish'
             AND details->>'status' = 'success'
             AND details->>'table_name' IS NOT NULL
           GROUP BY details->>'table_name'
           HAVING MAX(created_at) < NOW() - INTERVAL '1 minute' * $1""",
        config.lake_lag_threshold_minutes,
    )

    alerts = []
    now = datetime.now(timezone.utc)
    for record in stale_tables:
        table_name = record["table_name"]
        published_at = record["last_publish"]
        # Naive timestamps are taken to be UTC so they can be subtracted
        # from the aware ``now``.
        if published_at.tzinfo is None:
            published_at = published_at.replace(tzinfo=timezone.utc)
        lag_minutes = (now - published_at).total_seconds() / 60

        details = {
            "key": table_name,
            "table_name": table_name,
            "last_publish": published_at.isoformat(),
            "lag_minutes": round(lag_minutes, 1),
            "threshold_minutes": config.lake_lag_threshold_minutes,
        }
        summary = (
            f"Lake table '{table_name}' last published {lag_minutes:.0f}m ago "
            f"(threshold: {config.lake_lag_threshold_minutes}m)"
        )
        alerts.append(
            Alert(
                rule="analytical_lag",
                severity="warning",
                summary=summary,
                details=details,
            )
        )
    return alerts
async def check_broker_issues(
    pool: asyncpg.Pool,
    config: AlertingConfig,
) -> list[Alert]:
    """Alert when enough broker error events occurred recently.

    Counts order_events of type 'broker_error', 'broker_timeout' or
    'connection_failed' created within ``config.broker_error_window_hours``.
    The query only looks at the newest ``config.broker_error_threshold``
    such events, so the reported count never exceeds the threshold, and
    events of other types between the errors are not taken into account.
    Returns one critical alert (key "global") when the count reaches the
    threshold.
    """
    error_query = """WITH recent_events AS (
               SELECT order_id, event_type, created_at,
                      ROW_NUMBER() OVER (ORDER BY created_at DESC) AS rn
               FROM order_events
               WHERE created_at >= NOW() - INTERVAL '1 hour' * $1
                 AND event_type IN ('broker_error', 'broker_timeout', 'connection_failed')
           )
           SELECT COUNT(*) AS error_count
           FROM recent_events
           WHERE rn <= $2"""
    result = await pool.fetch(
        error_query,
        config.broker_error_window_hours,
        config.broker_error_threshold,
    )
    if not result:
        return []

    error_count = result[0]["error_count"]
    if not (error_count >= config.broker_error_threshold):
        return []

    return [
        Alert(
            rule="broker_issues",
            severity="critical",
            summary=(
                f"{error_count} broker errors in the last "
                f"{config.broker_error_window_hours}h"
            ),
            details={
                "key": "global",
                "error_count": error_count,
                "threshold": config.broker_error_threshold,
                "window_hours": config.broker_error_window_hours,
            },
        )
    ]
async def evaluate_alerts(
    pool: asyncpg.Pool,
    config: AlertingConfig,
    state: AlertState,
) -> list[Alert]:
    """Run all alert rules and return newly fired alerts.

    Updates ``state`` to track firing/resolved transitions and emits a log
    event and Prometheus metrics for each transition. An alert that was
    already active is refreshed in ``state`` but not returned again.

    A rule whose check raises is logged and skipped for this pass: its
    previously active alerts stay active instead of being reported as
    resolved, because a failed check says nothing about the condition.
    """
    # (rule name used on the alerts, check function, error log message)
    rule_checks = (
        ("source_failures", check_source_failures, "Error checking source failures"),
        ("schema_failure_spike", check_schema_failure_spike, "Error checking schema failure spike"),
        ("analytical_lag", check_analytical_lag, "Error checking analytical lag"),
        ("broker_issues", check_broker_issues, "Error checking broker issues"),
    )

    all_alerts: list[Alert] = []
    failed_rules: set[str] = set()
    with ALERT_CHECK_DURATION.time():
        # Collect alerts from all rules; one failing rule must not stop the others.
        for rule_name, run_check, error_message in rule_checks:
            try:
                all_alerts.extend(await run_check(pool, config))
            except Exception:
                logger.exception(error_message)
                failed_rules.add(rule_name)

    # Track which rule+key combos are currently firing
    current_keys: set[str] = set()
    newly_fired: list[Alert] = []
    for alert in all_alerts:
        key = f"{alert.rule}:{alert.details.get('key', '')}"
        current_keys.add(key)
        if state.fire(alert):
            # New alert firing
            ALERTS_FIRED.labels(rule=alert.rule, severity=alert.severity).inc()
            ALERT_ACTIVE.labels(rule=alert.rule).set(1)
            newly_fired.append(alert)
            logger.warning(
                "ALERT FIRING: [%s] %s",
                alert.rule,
                alert.summary,
                extra={
                    "alert_rule": alert.rule,
                    "alert_severity": alert.severity,
                    "alert_details": alert.details,
                },
            )

    # Check for resolved alerts
    resolved_keys = set(state.active.keys()) - current_keys
    for key in resolved_keys:
        rule, _, detail_key = key.partition(":")
        if rule in failed_rules:
            # The rule could not be evaluated, so its absence from this pass
            # is not evidence that the condition cleared.
            continue
        if state.resolve(rule, detail_key):
            ALERTS_RESOLVED.labels(rule=rule).inc()
            # Only set gauge to 0 if no more alerts for this rule
            still_firing = any(k.startswith(f"{rule}:") for k in state.active)
            if not still_firing:
                ALERT_ACTIVE.labels(rule=rule).set(0)
            logger.info(
                "ALERT RESOLVED: [%s] key=%s",
                rule,
                detail_key,
            )
    return newly_fired
+493
View File
@@ -0,0 +1,493 @@
"""Execution audit trail - records every step from recommendation to market outcome.
Writes structured audit events to the audit_events table so the full
decision chain is traceable: recommendation → risk evaluation → order
submission → broker response → fill/rejection/cancellation.
Each event captures the entity type, entity ID, event type, actor,
and a JSONB data payload with stage-specific details.
Requirements: 8.3, 11.3
Design: Section 4.9 (Broker Adapter), Section 6.1 (PostgreSQL audit_events)
"""
from __future__ import annotations
import json
import logging
import uuid
from datetime import datetime, timezone
from typing import Any
import asyncpg
logger = logging.getLogger("audit")
# ---------------------------------------------------------------------------
# Event type constants
# ---------------------------------------------------------------------------
# Each constant is the string passed as ``event_type`` to record_audit_event;
# values follow a dotted "<entity>.<action>" form.

# Recommendation stage
AUDIT_RECOMMENDATION_GENERATED = "recommendation.generated"
AUDIT_RECOMMENDATION_SUPPRESSED = "recommendation.suppressed"
# Risk evaluation stage
AUDIT_RISK_EVALUATED = "risk.evaluated"
AUDIT_RISK_REJECTED = "risk.rejected"
# Order lifecycle
AUDIT_ORDER_SUBMITTED = "order.submitted"
AUDIT_ORDER_ACCEPTED = "order.accepted"
AUDIT_ORDER_FILLED = "order.filled"
AUDIT_ORDER_REJECTED = "order.rejected"
AUDIT_ORDER_CANCELLED = "order.cancelled"
AUDIT_ORDER_DUPLICATE = "order.duplicate_prevented"
# Position changes
AUDIT_POSITION_OPENED = "position.opened"
AUDIT_POSITION_CLOSED = "position.closed"
AUDIT_POSITION_UPDATED = "position.updated"
# Trading mode changes
AUDIT_TRADING_MODE_CHANGED = "trading.mode_changed"
# Operator approval workflow
AUDIT_APPROVAL_REQUESTED = "approval.requested"
AUDIT_APPROVAL_APPROVED = "approval.approved"
AUDIT_APPROVAL_REJECTED = "approval.rejected"
AUDIT_APPROVAL_EXPIRED = "approval.expired"
# ---------------------------------------------------------------------------
# Core audit writer
# ---------------------------------------------------------------------------
# Parameterised insert used by record_audit_event. Positional parameters:
# $1 event id (uuid), $2 event type, $3 entity type, $4 entity id (uuid),
# $5 actor, $6 JSON payload (cast to jsonb), $7 creation timestamp.
_INSERT_AUDIT_EVENT = """
INSERT INTO audit_events (id, event_type, entity_type, entity_id, actor, data, created_at)
VALUES ($1::uuid, $2, $3, $4::uuid, $5, $6::jsonb, $7)
"""
async def record_audit_event(
    pool: asyncpg.Pool,
    event_type: str,
    entity_type: str,
    entity_id: str,
    data: dict[str, Any],
    actor: str = "system",
    timestamp: datetime | None = None,
) -> str:
    """Insert one row into audit_events and return its generated UUID.

    ``data`` is serialised to JSON, with values JSON cannot encode turned
    into strings. ``timestamp`` defaults to the current UTC time.

    Writing is best-effort: any failure (serialisation or database) is
    logged as a warning with the traceback and an empty string is returned
    instead of raising.
    """
    event_id = str(uuid.uuid4())
    recorded_at = timestamp or datetime.now(timezone.utc)
    try:
        # Serialise inside the try so an unencodable payload is handled the
        # same way as a database error.
        insert_args = (
            event_id,
            event_type,
            entity_type,
            entity_id,
            actor,
            json.dumps(data, default=str),
            recorded_at,
        )
        await pool.execute(_INSERT_AUDIT_EVENT, *insert_args)
    except Exception:
        logger.warning(
            "Failed to write audit event %s for %s/%s",
            event_type, entity_type, entity_id,
            exc_info=True,
        )
        return ""
    else:
        return event_id
# ---------------------------------------------------------------------------
# Convenience helpers for each execution stage
# ---------------------------------------------------------------------------
async def audit_recommendation_generated(
    pool: asyncpg.Pool,
    recommendation_id: str,
    ticker: str,
    action: str,
    mode: str,
    confidence: float,
    evidence_count: int,
    suppressed: bool = False,
) -> str:
    """Audit a generated recommendation, or a suppressed one.

    The event type is the "suppressed" variant when ``suppressed`` is true.
    Returns the audit event id from record_audit_event.
    """
    if suppressed:
        event_type = AUDIT_RECOMMENDATION_SUPPRESSED
    else:
        event_type = AUDIT_RECOMMENDATION_GENERATED

    payload = {
        "ticker": ticker,
        "action": action,
        "mode": mode,
        "confidence": confidence,
        "evidence_count": evidence_count,
        "suppressed": suppressed,
    }
    return await record_audit_event(
        pool,
        event_type=event_type,
        entity_type="recommendation",
        entity_id=recommendation_id,
        data=payload,
        actor="recommendation_worker",
    )
async def audit_risk_evaluated(
    pool: asyncpg.Pool,
    evaluation_id: str,
    recommendation_id: str | None,
    ticker: str,
    eligible: bool,
    allowed_mode: str,
    rejection_reasons: list[str],
    check_count: int,
) -> str:
    """Audit the outcome of a risk evaluation.

    An ineligible order is recorded with the "rejected" event type, an
    eligible one with "evaluated". Returns the audit event id.
    """
    if eligible:
        event_type = AUDIT_RISK_EVALUATED
    else:
        event_type = AUDIT_RISK_REJECTED

    payload = {
        "recommendation_id": recommendation_id,
        "ticker": ticker,
        "eligible": eligible,
        "allowed_mode": allowed_mode,
        "rejection_reasons": rejection_reasons,
        "check_count": check_count,
    }
    return await record_audit_event(
        pool,
        event_type=event_type,
        entity_type="risk_evaluation",
        entity_id=evaluation_id,
        data=payload,
        actor="risk_engine",
    )
async def audit_order_submitted(
    pool: asyncpg.Pool,
    order_id: str,
    ticker: str,
    side: str,
    quantity: float,
    order_type: str,
    idempotency_key: str,
    recommendation_id: str | None = None,
    evaluation_id: str | None = None,
) -> str:
    """Audit the submission of an order to the broker.

    The payload links the order to its recommendation and risk evaluation
    when those ids are supplied. Returns the audit event id.
    """
    payload = {
        "ticker": ticker,
        "side": side,
        "quantity": quantity,
        "order_type": order_type,
        "idempotency_key": idempotency_key,
        "recommendation_id": recommendation_id,
        "evaluation_id": evaluation_id,
    }
    return await record_audit_event(
        pool,
        event_type=AUDIT_ORDER_SUBMITTED,
        entity_type="order",
        entity_id=order_id,
        data=payload,
        actor="broker_service",
    )
async def audit_order_filled(
    pool: asyncpg.Pool,
    order_id: str,
    ticker: str,
    side: str,
    fill_quantity: float,
    fill_price: float | None,
    broker_order_id: str,
) -> str:
    """Audit a broker fill for an order. Returns the audit event id."""
    payload = {
        "ticker": ticker,
        "side": side,
        "fill_quantity": fill_quantity,
        "fill_price": fill_price,
        "broker_order_id": broker_order_id,
    }
    return await record_audit_event(
        pool,
        event_type=AUDIT_ORDER_FILLED,
        entity_type="order",
        entity_id=order_id,
        data=payload,
        actor="broker_service",
    )
async def audit_order_rejected(
    pool: asyncpg.Pool,
    order_id: str,
    ticker: str,
    reason: str,
    source: str = "broker",
) -> str:
    """Audit the rejection of an order.

    ``source`` names who rejected it and is stored in the payload as
    ``rejection_source``. Returns the audit event id.
    """
    payload = {
        "ticker": ticker,
        "reason": reason,
        "rejection_source": source,
    }
    return await record_audit_event(
        pool,
        event_type=AUDIT_ORDER_REJECTED,
        entity_type="order",
        entity_id=order_id,
        data=payload,
        actor="broker_service",
    )
async def audit_order_cancelled(
    pool: asyncpg.Pool,
    order_id: str,
    ticker: str,
    broker_order_id: str,
) -> str:
    """Audit the cancellation of an order. Returns the audit event id."""
    payload = {
        "ticker": ticker,
        "broker_order_id": broker_order_id,
    }
    return await record_audit_event(
        pool,
        event_type=AUDIT_ORDER_CANCELLED,
        entity_type="order",
        entity_id=order_id,
        data=payload,
        actor="broker_service",
    )
async def audit_duplicate_prevented(
    pool: asyncpg.Pool,
    order_id: str,
    ticker: str,
    idempotency_key: str,
    detected_via: str,
) -> str:
    """Audit a duplicate order that was blocked.

    ``detected_via`` is stored in the payload alongside the idempotency
    key. Returns the audit event id.
    """
    payload = {
        "ticker": ticker,
        "idempotency_key": idempotency_key,
        "detected_via": detected_via,
    }
    return await record_audit_event(
        pool,
        event_type=AUDIT_ORDER_DUPLICATE,
        entity_type="order",
        entity_id=order_id,
        data=payload,
        actor="broker_service",
    )
async def audit_position_change(
    pool: asyncpg.Pool,
    order_id: str,
    ticker: str,
    side: str,
    quantity_before: float,
    quantity_after: float,
    avg_entry_before: float,
    avg_entry_after: float,
) -> str:
    """Write a position opened/closed/updated audit event after a fill.

    The event type is chosen from the before/after quantities:

    * ``quantity_after == 0``                          -> ``AUDIT_POSITION_CLOSED``
    * ``quantity_before == 0`` and ``quantity_after > 0`` -> ``AUDIT_POSITION_OPENED``
    * anything else                                    -> ``AUDIT_POSITION_UPDATED``

    The event is stored with entity type ``"position"`` and ``order_id`` as
    its ``entity_id``.

    Returns:
        Whatever ``record_audit_event`` returns (annotated as ``str``).
    """
    # NOTE(review): a move from 0 to a negative quantity is classified as
    # "updated", not "opened" -- confirm whether short positions are ever
    # represented with negative quantities.
    if quantity_after == 0:
        position_event = AUDIT_POSITION_CLOSED
    elif quantity_before == 0 and quantity_after > 0:
        position_event = AUDIT_POSITION_OPENED
    else:
        position_event = AUDIT_POSITION_UPDATED

    change_details = {
        "ticker": ticker,
        "side": side,
        "quantity_before": quantity_before,
        "quantity_after": quantity_after,
        "avg_entry_before": avg_entry_before,
        "avg_entry_after": avg_entry_after,
    }
    return await record_audit_event(
        pool,
        event_type=position_event,
        entity_type="position",
        entity_id=order_id,
        data=change_details,
        actor="broker_service",
    )
async def audit_approval_requested(
    pool: asyncpg.Pool,
    approval_id: str,
    ticker: str,
    side: str,
    quantity: float,
    estimated_value: float,
    recommendation_id: str | None = None,
    expires_at: str | None = None,
) -> str:
    """Write an ``AUDIT_APPROVAL_REQUESTED`` audit event.

    Args:
        pool: Connection pool handed through to ``record_audit_event``.
        approval_id: Stored as the event's ``entity_id`` (entity type
            ``"approval"``).
        ticker: Symbol of the order awaiting approval.
        side: Order side, recorded as given.
        quantity: Requested order quantity.
        estimated_value: Estimated value of the order.
        recommendation_id: Optional recommendation reference.
        expires_at: Optional expiry of the request, passed as a string.

    Returns:
        Whatever ``record_audit_event`` returns (annotated as ``str``).
    """
    request_details = {
        "ticker": ticker,
        "side": side,
        "quantity": quantity,
        "estimated_value": estimated_value,
        "recommendation_id": recommendation_id,
        "expires_at": expires_at,
    }
    return await record_audit_event(
        pool,
        event_type=AUDIT_APPROVAL_REQUESTED,
        entity_type="approval",
        entity_id=approval_id,
        data=request_details,
        actor="broker_service",
    )
async def audit_approval_reviewed(
    pool: asyncpg.Pool,
    approval_id: str,
    ticker: str,
    approved: bool,
    reviewed_by: str = "operator",
    review_note: str = "",
) -> str:
    """Write the audit event for an operator's decision on an approval.

    The event type is ``AUDIT_APPROVAL_APPROVED`` when ``approved`` is truthy
    and ``AUDIT_APPROVAL_REJECTED`` otherwise. ``reviewed_by`` is used both
    in the event data and as the event's actor.

    Returns:
        Whatever ``record_audit_event`` returns (annotated as ``str``).
    """
    if approved:
        decision_event = AUDIT_APPROVAL_APPROVED
    else:
        decision_event = AUDIT_APPROVAL_REJECTED

    review_details = {
        "ticker": ticker,
        "approved": approved,
        "reviewed_by": reviewed_by,
        "review_note": review_note,
    }
    return await record_audit_event(
        pool,
        event_type=decision_event,
        entity_type="approval",
        entity_id=approval_id,
        data=review_details,
        actor=reviewed_by,
    )
async def audit_approval_expired(
    pool: asyncpg.Pool,
    approval_id: str,
    ticker: str,
) -> str:
    """Write an ``AUDIT_APPROVAL_EXPIRED`` audit event.

    The event is attributed to the ``"system"`` actor and carries only the
    ticker in its data.

    Returns:
        Whatever ``record_audit_event`` returns (annotated as ``str``).
    """
    expiry_details = {"ticker": ticker}
    return await record_audit_event(
        pool,
        event_type=AUDIT_APPROVAL_EXPIRED,
        entity_type="approval",
        entity_id=approval_id,
        data=expiry_details,
        actor="system",
    )
async def audit_trading_mode_changed(
    pool: asyncpg.Pool,
    config_id: str,
    old_mode: str,
    new_mode: str,
    actor: str = "operator",
) -> str:
    """Write an ``AUDIT_TRADING_MODE_CHANGED`` audit event.

    Args:
        pool: Connection pool handed through to ``record_audit_event``.
        config_id: Stored as the event's ``entity_id`` (entity type
            ``"risk_config"``).
        old_mode: Trading mode before the change.
        new_mode: Trading mode after the change.
        actor: Who made the change; defaults to ``"operator"``.

    Returns:
        Whatever ``record_audit_event`` returns (annotated as ``str``).
    """
    mode_transition = {
        "old_mode": old_mode,
        "new_mode": new_mode,
    }
    return await record_audit_event(
        pool,
        event_type=AUDIT_TRADING_MODE_CHANGED,
        entity_type="risk_config",
        entity_id=config_id,
        data=mode_transition,
        actor=actor,
    )
# ---------------------------------------------------------------------------
# Query helpers for audit trail retrieval (Requirement 11.3)
# ---------------------------------------------------------------------------
# All three statements read the same seven columns from audit_events and
# return rows oldest-first (ORDER BY created_at ASC).

# Events for one order: $1 is matched against entity_id (cast to uuid); $2 is
# matched as text against the 'recommendation_id' and 'order_id' keys of the
# JSON data column.
_FETCH_AUDIT_TRAIL_FOR_ORDER = """
SELECT id, event_type, entity_type, entity_id, actor, data, created_at
FROM audit_events
WHERE entity_id = $1::uuid
OR data->>'recommendation_id' = $2
OR data->>'order_id' = $2
ORDER BY created_at ASC
"""
# Events for one entity: $1 is the entity_type, $2 the entity_id (cast to uuid).
_FETCH_AUDIT_TRAIL_BY_ENTITY = """
SELECT id, event_type, entity_type, entity_id, actor, data, created_at
FROM audit_events
WHERE entity_type = $1 AND entity_id = $2::uuid
ORDER BY created_at ASC
"""
# Events whose entity_id is $1, plus every event of any entity that has at
# least one event tagged with recommendation_id = $2.
# NOTE(review): no query helper visible in this section uses this statement --
# confirm whether it is referenced elsewhere.
_FETCH_FULL_EXECUTION_TRAIL = """
SELECT id, event_type, entity_type, entity_id, actor, data, created_at
FROM audit_events
WHERE entity_id = $1::uuid
OR entity_id IN (
SELECT entity_id FROM audit_events
WHERE data->>'recommendation_id' = $2
)
ORDER BY created_at ASC
"""
def _audit_row_to_dict(row: Any) -> dict[str, Any]:
    """Convert one ``audit_events`` row into a plain, JSON-friendly dict.

    ``data`` is returned as-is when the driver already produced a dict,
    decoded with ``json.loads`` when it arrives as text, and replaced by an
    empty dict when the column is NULL (``json.loads(None)`` would raise).
    """
    data = row["data"]
    if data is None:
        data = {}
    elif not isinstance(data, dict):
        data = json.loads(data)

    created_at = row["created_at"]
    return {
        "id": str(row["id"]),
        "event_type": row["event_type"],
        "entity_type": row["entity_type"],
        "entity_id": str(row["entity_id"]),
        "actor": row["actor"],
        "data": data,
        "created_at": created_at.isoformat() if created_at else None,
    }


async def get_order_audit_trail(
    pool: asyncpg.Pool,
    order_id: str,
    recommendation_id: str | None = None,
) -> list[dict[str, Any]]:
    """Fetch the audit trail for an order, oldest event first.

    Runs ``_FETCH_AUDIT_TRAIL_FOR_ORDER``, which matches events whose
    ``entity_id`` is ``order_id`` or whose data references the lookup id via
    ``recommendation_id`` / ``order_id``.

    Args:
        pool: asyncpg pool used to run the query.
        order_id: Order identifier; the query casts it to ``uuid``.
        recommendation_id: Optional id used for the data-field match. When
            omitted (or empty) ``order_id`` is used instead.

    Returns:
        One dict per event with ``id``, ``event_type``, ``entity_type``,
        ``entity_id``, ``actor``, ``data`` and ``created_at`` (ISO-8601 string
        or ``None``).
    """
    ref_id = recommendation_id or order_id
    rows = await pool.fetch(_FETCH_AUDIT_TRAIL_FOR_ORDER, order_id, ref_id)
    return [_audit_row_to_dict(row) for row in rows]


async def get_entity_audit_trail(
    pool: asyncpg.Pool,
    entity_type: str,
    entity_id: str,
) -> list[dict[str, Any]]:
    """Fetch all audit events for one entity, oldest event first.

    Args:
        pool: asyncpg pool used to run the query.
        entity_type: Value matched against the ``entity_type`` column.
        entity_id: Entity identifier; the query casts it to ``uuid``.

    Returns:
        One dict per event, in the same shape as ``get_order_audit_trail``.
    """
    rows = await pool.fetch(_FETCH_AUDIT_TRAIL_BY_ENTITY, entity_type, entity_id)
    return [_audit_row_to_dict(row) for row in rows]
+108
View File
@@ -43,6 +43,10 @@ class OllamaConfig:
base_url: str = "http://localhost:11434"
model: str = "llama3.1:8b"
timeout: int = 120
max_retries: int = 2
retry_base_delay: float = 1.0
retry_max_delay: float = 10.0
retry_backoff_multiplier: float = 2.0
@dataclass
@@ -51,16 +55,82 @@ class TrinoConfig:
port: int = 8080
catalog: str = "lakehouse"
schema: str = "stonks"
iceberg_catalog: str = "iceberg"
@dataclass
class MarketDataConfig:
    """Settings for the market-data provider.

    The defaults select the ``polygon`` provider at ``https://api.polygon.io``
    with an empty API key.
    """

    # Credential for the provider; defaults to an empty string.
    api_key: str = ""
    # Root URL of the provider's HTTP API.
    base_url: str = "https://api.polygon.io"
    # Provider identifier.
    provider: str = "polygon"
@dataclass
class BrokerConfig:
    """Settings for the broker integration.

    The defaults select the ``alpaca`` provider in ``paper`` mode with no
    credentials and no base-URL override.
    """

    # Trading mode: "paper" or "live".
    mode: str = "paper"
    # Broker provider identifier.
    provider: str = "alpaca"
    # Credentials; None when not supplied.
    api_key: Optional[str] = None
    api_secret: Optional[str] = None
    # Optional API root URL; None when not supplied.
    base_url: Optional[str] = None
@dataclass
class RetentionConfig:
    """Default retention periods, in days, for each bucket class.

    Individual buckets can override these values through the
    ``retention_policies`` DB table. ``cleanup_interval_hours`` controls how
    often the retention worker runs.
    """

    # Raw ingested payloads
    raw_market_days: int = 90
    raw_news_days: int = 180
    raw_filings_days: int = 365
    # Normalized documents
    normalized_days: int = 180
    # LLM prompts and results
    llm_prompts_days: int = 365
    llm_results_days: int = 365
    # Lakehouse and audit data
    lakehouse_days: int = 730
    audit_days: int = 730
    # Worker scheduling / batching
    cleanup_interval_hours: int = 24
    batch_size: int = 1000
# Map bucket names to RetentionConfig field names.
# Each value is the name of the RetentionConfig attribute that holds the
# default retention period (in days) for the bucket named by the key.
BUCKET_RETENTION_FIELDS: dict[str, str] = {
    "stonks-raw-market": "raw_market_days",
    "stonks-raw-news": "raw_news_days",
    "stonks-raw-filings": "raw_filings_days",
    "stonks-normalized": "normalized_days",
    "stonks-llm-prompts": "llm_prompts_days",
    "stonks-llm-results": "llm_results_days",
    "stonks-lakehouse": "lakehouse_days",
    "stonks-audit": "audit_days",
}
@dataclass
class AlertingConfig:
    """Thresholds and windows used by the operational alerting rules.

    Requirements: 12.3
    """

    # --- Source failures ---
    # Consecutive failures before an alert is raised.
    source_failure_threshold: int = 3
    # Lookback window, in hours.
    source_failure_window_hours: int = 6

    # --- Schema / extraction failure spikes ---
    # Failure rate that triggers an alert (0.3 == 30%).
    schema_failure_rate_threshold: float = 0.3
    schema_failure_window_hours: int = 1

    # --- Analytical (lake publication) lag ---
    # Minutes since the last successful publish.
    lake_lag_threshold_minutes: int = 60

    # --- Broker issues ---
    # Consecutive broker errors.
    broker_error_threshold: int = 3
    broker_error_window_hours: int = 1

    # --- Evaluation interval ---
    check_interval_seconds: int = 120
@dataclass
class AppConfig:
postgres: PostgresConfig = field(default_factory=PostgresConfig)
@@ -68,8 +138,12 @@ class AppConfig:
minio: MinioConfig = field(default_factory=MinioConfig)
ollama: OllamaConfig = field(default_factory=OllamaConfig)
trino: TrinoConfig = field(default_factory=TrinoConfig)
market_data: MarketDataConfig = field(default_factory=MarketDataConfig)
broker: BrokerConfig = field(default_factory=BrokerConfig)
retention: RetentionConfig = field(default_factory=RetentionConfig)
alerting: AlertingConfig = field(default_factory=AlertingConfig)
log_level: str = "INFO"
json_logs: bool = True
def load_config() -> AppConfig:
@@ -98,18 +172,52 @@ def load_config() -> AppConfig:
base_url=os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"),
model=os.getenv("OLLAMA_MODEL", "llama3.1:8b"),
timeout=int(os.getenv("OLLAMA_TIMEOUT", "120")),
max_retries=int(os.getenv("OLLAMA_MAX_RETRIES", "2")),
retry_base_delay=float(os.getenv("OLLAMA_RETRY_BASE_DELAY", "1.0")),
retry_max_delay=float(os.getenv("OLLAMA_RETRY_MAX_DELAY", "10.0")),
retry_backoff_multiplier=float(os.getenv("OLLAMA_RETRY_BACKOFF_MULTIPLIER", "2.0")),
),
trino=TrinoConfig(
host=os.getenv("TRINO_HOST", "localhost"),
port=int(os.getenv("TRINO_PORT", "8080")),
catalog=os.getenv("TRINO_CATALOG", "lakehouse"),
schema=os.getenv("TRINO_SCHEMA", "stonks"),
iceberg_catalog=os.getenv("TRINO_ICEBERG_CATALOG", "iceberg"),
),
market_data=MarketDataConfig(
api_key=os.getenv("MARKET_DATA_API_KEY", ""),
base_url=os.getenv("MARKET_DATA_BASE_URL", "https://api.polygon.io"),
provider=os.getenv("MARKET_DATA_PROVIDER", "polygon"),
),
broker=BrokerConfig(
mode=os.getenv("BROKER_MODE", "paper"),
provider=os.getenv("BROKER_PROVIDER", "alpaca"),
api_key=os.getenv("BROKER_API_KEY", None),
api_secret=os.getenv("BROKER_API_SECRET", None),
base_url=os.getenv("BROKER_BASE_URL", None),
),
retention=RetentionConfig(
raw_market_days=int(os.getenv("RETENTION_RAW_MARKET_DAYS", "90")),
raw_news_days=int(os.getenv("RETENTION_RAW_NEWS_DAYS", "180")),
raw_filings_days=int(os.getenv("RETENTION_RAW_FILINGS_DAYS", "365")),
normalized_days=int(os.getenv("RETENTION_NORMALIZED_DAYS", "180")),
llm_prompts_days=int(os.getenv("RETENTION_LLM_PROMPTS_DAYS", "365")),
llm_results_days=int(os.getenv("RETENTION_LLM_RESULTS_DAYS", "365")),
lakehouse_days=int(os.getenv("RETENTION_LAKEHOUSE_DAYS", "730")),
audit_days=int(os.getenv("RETENTION_AUDIT_DAYS", "730")),
cleanup_interval_hours=int(os.getenv("RETENTION_CLEANUP_INTERVAL_HOURS", "24")),
batch_size=int(os.getenv("RETENTION_BATCH_SIZE", "1000")),
),
alerting=AlertingConfig(
source_failure_threshold=int(os.getenv("ALERT_SOURCE_FAILURE_THRESHOLD", "3")),
source_failure_window_hours=int(os.getenv("ALERT_SOURCE_FAILURE_WINDOW_HOURS", "6")),
schema_failure_rate_threshold=float(os.getenv("ALERT_SCHEMA_FAILURE_RATE_THRESHOLD", "0.3")),
schema_failure_window_hours=int(os.getenv("ALERT_SCHEMA_FAILURE_WINDOW_HOURS", "1")),
lake_lag_threshold_minutes=int(os.getenv("ALERT_LAKE_LAG_THRESHOLD_MINUTES", "60")),
broker_error_threshold=int(os.getenv("ALERT_BROKER_ERROR_THRESHOLD", "3")),
broker_error_window_hours=int(os.getenv("ALERT_BROKER_ERROR_WINDOW_HOURS", "1")),
check_interval_seconds=int(os.getenv("ALERT_CHECK_INTERVAL_SECONDS", "120")),
),
log_level=os.getenv("LOG_LEVEL", "INFO"),
json_logs=os.getenv("JSON_LOGS", "true").lower() == "true",
)
+43
View File
@@ -0,0 +1,43 @@
"""Canonical URL normalization and content hashing utilities.
Provides consistent URL canonicalization and SHA-256 content hashing
across all ingestion adapters and pipeline stages.
Requirements: 3.2, 3.3
"""
import hashlib
from urllib.parse import parse_qsl, urlencode, urlparse
def normalize_url(url: str) -> str:
    """Return a canonical form of *url* for comparison and deduplication.

    - Lowercases scheme and host
    - Strips fragments and any userinfo
    - Strips trailing slashes from path (preserves root "/")
    - Strips default ports (80, 443)
    - Sorts query parameters for deterministic comparison; parameters with
      blank values (``?flag=``) are kept, so URLs that differ only in such a
      parameter stay distinct
    - Keeps IPv6 literal hosts bracketed (``[::1]``) so the result is valid
    - Defaults scheme to https if missing

    Args:
        url: The URL to normalize.

    Returns:
        The normalized URL string.

    Raises:
        ValueError: If the URL carries a port that ``urlparse`` cannot parse.
    """
    parsed = urlparse(url)
    scheme = (parsed.scheme or "https").lower()

    host = (parsed.hostname or "").lower()
    if ":" in host:
        # ``hostname`` strips the brackets from IPv6 literals; restore them so
        # the host cannot be confused with a "host:port" pair.
        host = f"[{host}]"

    netloc = host
    if parsed.port and parsed.port not in (80, 443):
        netloc = f"{netloc}:{parsed.port}"

    path = parsed.path.rstrip("/") or "/"

    # Sort query params for deterministic ordering. keep_blank_values=True
    # prevents parse_qsl from silently dropping "key=" style parameters.
    query = urlencode(sorted(parse_qsl(parsed.query, keep_blank_values=True)))

    normalized = f"{scheme}://{netloc}{path}"
    if query:
        normalized = f"{normalized}?{query}"
    return normalized
def content_hash(data: bytes) -> str:
    """Return the SHA-256 digest of *data* as a lowercase hex string."""
    digest = hashlib.sha256()
    digest.update(data)
    return digest.hexdigest()
def content_hash_str(text: str, encoding: str = "utf-8") -> str:
    """Return the SHA-256 hex digest of *text* encoded with *encoding*.

    Args:
        text: The string to hash.
        encoding: Codec used to turn the text into bytes (UTF-8 by default).
    """
    encoded = text.encode(encoding)
    digest = hashlib.sha256(encoded)
    return digest.hexdigest()
+134
View File
@@ -0,0 +1,134 @@
"""Dead-letter queue (DLQ) support and replay tooling.
When a worker fails to process a job after exhausting retries, the job
is pushed to a per-queue dead-letter list in Redis. Each DLQ entry
wraps the original payload with failure metadata (error message,
timestamp, attempt count) so operators can inspect and replay later.
Replay moves items from the DLQ back to the source queue for
reprocessing.
Requirements: 12.1 (observability), design section 8 (data flows)
"""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from typing import Any
import redis.asyncio as aioredis
from services.shared.redis_keys import dlq_key, queue_key
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Default max attempts before a job is dead-lettered.
# None of the functions defined in this module read this value.
# NOTE(review): presumably workers import it to decide when to call send_to_dlq -- confirm.
DEFAULT_MAX_ATTEMPTS = 3
def _now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
def wrap_dlq_entry(
    original_payload: dict[str, Any],
    queue_name: str,
    error: str,
    attempt: int = 1,
    worker: str = "",
) -> dict[str, Any]:
    """Build the envelope stored in the DLQ for a failed job.

    Args:
        original_payload: The job payload exactly as it was dequeued.
        queue_name: Name of the queue the job came from.
        error: Description of the failure.
        attempt: Attempt count to record with the entry.
        worker: Optional name of the worker that gave up on the job.

    Returns:
        A new dict holding the payload under ``original_payload`` alongside
        ``queue``, ``error``, ``attempt``, ``worker`` and a
        ``dead_lettered_at`` UTC timestamp.
    """
    envelope: dict[str, Any] = dict(
        original_payload=original_payload,
        queue=queue_name,
        error=error,
        attempt=attempt,
        worker=worker,
    )
    envelope["dead_lettered_at"] = _now_iso()
    return envelope
async def send_to_dlq(
    rds: aioredis.Redis,
    queue_name: str,
    original_payload: dict[str, Any],
    error: str,
    attempt: int = 1,
    worker: str = "",
) -> None:
    """Append a failed job to the dead-letter list for *queue_name*.

    The payload is wrapped with ``wrap_dlq_entry``, serialised to JSON (values
    JSON cannot encode natively are stringified) and pushed onto the tail of
    the DLQ list. A warning is logged for every dead-lettered job.
    """
    target_key = dlq_key(queue_name)
    dead_letter = wrap_dlq_entry(original_payload, queue_name, error, attempt, worker)
    serialized = json.dumps(dead_letter, default=str)
    await rds.rpush(target_key, serialized)

    log_context = {"queue": queue_name, "attempt": attempt}
    logger.warning(
        "Dead-lettered job on %s after %d attempts: %s",
        queue_name,
        attempt,
        error,
        extra=log_context,
    )
async def dlq_length(rds: aioredis.Redis, queue_name: str) -> int:
    """Return how many entries are currently in the DLQ for *queue_name*."""
    key = dlq_key(queue_name)
    depth = await rds.llen(key)
    return depth
async def peek_dlq(
    rds: aioredis.Redis,
    queue_name: str,
    start: int = 0,
    count: int = 10,
) -> list[dict[str, Any]]:
    """Return DLQ entries without removing them (for inspection).

    Args:
        rds: Redis client.
        queue_name: Queue whose DLQ should be inspected.
        start: Index of the first entry to return (0 is the oldest).
        count: Maximum number of entries to return. Zero or a negative value
            yields an empty list.

    Returns:
        The decoded DLQ entries, oldest first.
    """
    if count <= 0:
        # LRANGE's stop index is inclusive and negative indexes count from the
        # tail, so start=0/count=0 would turn into "LRANGE 0 -1" and return
        # the entire list instead of nothing.
        return []
    raw_items = await rds.lrange(dlq_key(queue_name), start, start + count - 1)
    return [json.loads(item) for item in raw_items]
async def replay_one(rds: aioredis.Redis, queue_name: str) -> dict[str, Any] | None:
    """Pop the oldest DLQ entry and re-enqueue its original payload.

    If the popped entry cannot be decoded it is pushed back onto the head of
    the DLQ before the error is re-raised, so a malformed entry is never
    silently discarded.

    Returns:
        The replayed DLQ entry, or None if the DLQ is empty.

    Raises:
        ValueError, TypeError, AttributeError: If the entry is not valid JSON
            or is not shaped like a DLQ envelope. The entry stays in the DLQ.
    """
    source_key = dlq_key(queue_name)
    raw = await rds.lpop(source_key)
    if raw is None:
        return None

    try:
        entry = json.loads(raw)
        # Entries without the envelope are replayed as-is.
        original = entry.get("original_payload", entry)
        serialized = json.dumps(original, default=str)
    except (ValueError, TypeError, AttributeError):
        # The entry was already removed by LPOP; restore it at its old position.
        await rds.lpush(source_key, raw)
        raise

    await rds.rpush(queue_key(queue_name), serialized)
    logger.info("Replayed 1 job from DLQ back to %s", queue_name)
    return entry
async def replay_all(rds: aioredis.Redis, queue_name: str) -> int:
    """Replay the items currently in the DLQ back to the source queue.

    Only the entries present when the call starts are processed. Without that
    bound, a job that fails again and is dead-lettered while the replay is
    running would be popped and replayed again, and the loop might never end.

    Entries that cannot be decoded are moved to the tail of the DLQ and
    skipped, so one malformed entry neither gets lost nor aborts the run.

    Returns:
        The number of items replayed.
    """
    source_key = dlq_key(queue_name)
    target_key = queue_key(queue_name)

    pending = await rds.llen(source_key)
    count = 0
    for _ in range(pending):
        raw = await rds.lpop(source_key)
        if raw is None:
            # Someone else drained the DLQ in the meantime.
            break
        try:
            entry = json.loads(raw)
            # Entries without the envelope are replayed as-is.
            original = entry.get("original_payload", entry)
            serialized = json.dumps(original, default=str)
        except (ValueError, TypeError, AttributeError):
            await rds.rpush(source_key, raw)
            logger.warning(
                "Skipping malformed DLQ entry on %s; entry kept in DLQ",
                queue_name,
                exc_info=True,
            )
            continue
        await rds.rpush(target_key, serialized)
        count += 1

    if count:
        logger.info("Replayed %d jobs from DLQ back to %s", count, queue_name)
    return count
async def purge_dlq(rds: aioredis.Redis, queue_name: str) -> int:
    """Drop every entry from the DLQ for *queue_name*.

    Returns:
        The list length read just before the key was deleted. The length check
        and the delete are two separate Redis calls.
    """
    key = dlq_key(queue_name)
    removed = await rds.llen(key)
    if not removed:
        # Nothing stored under the key; skip the DELETE round-trip.
        return removed
    await rds.delete(key)
    return removed
async def dlq_summary(rds: aioredis.Redis, queue_names: list[str]) -> dict[str, int]:
    """Report the DLQ depth of each queue in *queue_names*.

    The queues are queried one after another, in the order given.

    Returns:
        A dict mapping each queue name to the length of its DLQ list.
    """
    return {name: await rds.llen(dlq_key(name)) for name in queue_names}
+198
View File
@@ -0,0 +1,198 @@
"""Cross-source deduplication for articles and filings.
Detects duplicate documents across different source types (news_api,
filings_api, web_scrape) using a layered approach:
1. Redis fast-path: check content_hash and canonical_url markers for
recently-seen documents (TTL-bounded, cheap).
2. PostgreSQL fallback: query the documents table by canonical_url or
content_hash for durable cross-source matching.
When a duplicate is detected the caller receives the existing document_id
so it can link additional company mentions without re-inserting the document.
Requirements: 3.2, 3.3
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Any
import asyncpg
import redis.asyncio as aioredis
from services.shared.content import content_hash_str, normalize_url
from services.shared.redis_keys import DEDUPE_PREFIX
# Module logger; uses the fixed name "dedupe" rather than the module path.
logger = logging.getLogger("dedupe")
# Redis TTL for dedupe markers (24 hours).
# Passed as the expiry (ex=) for every marker written by _set_dedupe_markers.
DEDUPE_TTL_SECONDS: int = 86400
def _url_dedupe_key(canonical_url: str) -> str:
    """Return the Redis key of the URL dedupe marker for *canonical_url*.

    The URL is hashed with ``content_hash_str`` and the digest, not the URL
    itself, goes into the key.
    """
    url_digest = content_hash_str(canonical_url)
    return f"{DEDUPE_PREFIX}:url:{url_digest}"
def _hash_dedupe_key(content_hash: str) -> str:
    """Return the Redis key of the content-hash dedupe marker."""
    return "{}:{}".format(DEDUPE_PREFIX, content_hash)
@dataclass
class DedupeResult:
    """Outcome of a single deduplication lookup."""

    # True when an existing document matched.
    is_duplicate: bool
    # Id of the matching document; None when nothing matched.
    existing_document_id: str | None = None
    # Which check matched: "content_hash" | "canonical_url" | None
    match_type: str | None = None
def _decode_cached_id(value: Any) -> str:
    """Return a Redis reply as text.

    A client that does not decode responses returns ``bytes``; calling
    ``str()`` on those yields ``"b'...'"`` rather than the stored id, so
    bytes are decoded explicitly.
    """
    if isinstance(value, (bytes, bytearray)):
        return bytes(value).decode("utf-8")
    return str(value)


async def check_duplicate(
    pool: asyncpg.Pool,
    rds: aioredis.Redis,
    *,
    content_hash: str,
    url: str | None = None,
    canonical_url: str | None = None,
) -> DedupeResult:
    """Check whether a document is a duplicate across all source types.

    Checks in order of cost and returns on the first hit:
    1. Redis content_hash marker (fast path)
    2. Redis canonical_url marker (fast path)
    3. PostgreSQL documents.content_hash (durable)
    4. PostgreSQL documents.canonical_url (cross-source)

    A PostgreSQL hit also writes the Redis markers so the next lookup for the
    same document is answered from the cache.

    Args:
        pool: asyncpg pool used for the durable lookups.
        rds: Redis client used for the fast-path markers.
        content_hash: Content digest of the document; an empty value skips
            both content-hash checks.
        url: Raw document URL, normalized with ``normalize_url`` when
            ``canonical_url`` is not given.
        canonical_url: Already-normalized URL; takes precedence over ``url``.

    Returns:
        A DedupeResult carrying the existing document id and the kind of
        match, or ``DedupeResult(is_duplicate=False)`` when nothing matched.
    """
    # Resolve canonical URL if only raw URL provided
    resolved_canonical = canonical_url or (normalize_url(url) if url else None)

    # --- Redis fast path: content hash ---
    if content_hash:
        redis_key = _hash_dedupe_key(content_hash)
        cached_id = await rds.get(redis_key)
        if cached_id:
            logger.debug("Dedupe hit (redis content_hash) for %s", content_hash[:16])
            return DedupeResult(
                is_duplicate=True,
                existing_document_id=_decode_cached_id(cached_id),
                match_type="content_hash",
            )

    # --- Redis fast path: canonical URL ---
    if resolved_canonical:
        url_key = _url_dedupe_key(resolved_canonical)
        cached_id = await rds.get(url_key)
        if cached_id:
            logger.debug("Dedupe hit (redis canonical_url) for %s", resolved_canonical[:60])
            return DedupeResult(
                is_duplicate=True,
                existing_document_id=_decode_cached_id(cached_id),
                match_type="canonical_url",
            )

    # --- PostgreSQL fallback: content hash ---
    if content_hash:
        row = await pool.fetchrow(
            "SELECT id FROM documents WHERE content_hash = $1 LIMIT 1",
            content_hash,
        )
        if row:
            doc_id = str(row["id"])
            # Warm the Redis cache for future checks
            await _set_dedupe_markers(rds, content_hash, resolved_canonical, doc_id)
            logger.debug("Dedupe hit (pg content_hash) for %s", content_hash[:16])
            return DedupeResult(
                is_duplicate=True,
                existing_document_id=doc_id,
                match_type="content_hash",
            )

    # --- PostgreSQL fallback: canonical URL ---
    if resolved_canonical:
        row = await pool.fetchrow(
            "SELECT id FROM documents WHERE canonical_url = $1 LIMIT 1",
            resolved_canonical,
        )
        if row:
            doc_id = str(row["id"])
            await _set_dedupe_markers(rds, content_hash, resolved_canonical, doc_id)
            logger.debug("Dedupe hit (pg canonical_url) for %s", resolved_canonical[:60])
            return DedupeResult(
                is_duplicate=True,
                existing_document_id=doc_id,
                match_type="canonical_url",
            )

    return DedupeResult(is_duplicate=False)
async def mark_as_seen(
    rds: aioredis.Redis,
    *,
    content_hash: str,
    canonical_url: str | None,
    document_id: str,
) -> None:
    """Record a freshly persisted document in the Redis dedupe cache.

    Public, keyword-only entry point that forwards to
    ``_set_dedupe_markers``.
    """
    await _set_dedupe_markers(
        rds,
        content_hash,
        canonical_url,
        document_id,
    )
async def _set_dedupe_markers(
    rds: aioredis.Redis,
    content_hash: str | None,
    canonical_url: str | None,
    document_id: str,
) -> None:
    """Write the Redis dedupe markers that point at *document_id*.

    One marker is written for the content hash and one for the canonical URL.
    Either is skipped when its value is empty or ``None``. Both expire after
    ``DEDUPE_TTL_SECONDS``.
    """
    if content_hash:
        hash_key = _hash_dedupe_key(content_hash)
        await rds.set(hash_key, document_id, ex=DEDUPE_TTL_SECONDS)

    if canonical_url:
        url_key = _url_dedupe_key(canonical_url)
        await rds.set(url_key, document_id, ex=DEDUPE_TTL_SECONDS)
async def dedupe_items(
    pool: asyncpg.Pool,
    rds: aioredis.Redis,
    items: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Split ingestion items into not-yet-seen and already-known documents.

    Each item is looked up with ``check_duplicate`` using its
    ``content_hash``, its ``url`` (falling back to ``link``) and its
    ``canonical_url``. Items found to be duplicates are annotated in place
    with ``_dedupe_match_type`` and ``_dedupe_existing_id``.

    Items are only compared against Redis/PostgreSQL, not against each other.

    Returns:
        ``(new_items, duplicate_items)``, each in input order.
    """
    fresh: list[dict[str, Any]] = []
    known: list[dict[str, Any]] = []

    for candidate in items:
        verdict = await check_duplicate(
            pool,
            rds,
            content_hash=candidate.get("content_hash", ""),
            url=candidate.get("url") or candidate.get("link"),
            canonical_url=candidate.get("canonical_url"),
        )
        if not verdict.is_duplicate:
            fresh.append(candidate)
            continue

        # Record why the item was rejected and which document it matched.
        candidate["_dedupe_match_type"] = verdict.match_type
        candidate["_dedupe_existing_id"] = verdict.existing_document_id
        known.append(candidate)

    return fresh, known
+224
View File
@@ -0,0 +1,224 @@
"""Structured logging and distributed tracing for all Stonks Oracle services.
Provides:
- JSON-formatted structured log output for machine-parseable log aggregation
- Trace context (trace_id, span_id, service) propagated through log records
- Context manager for creating trace spans within a service
- Helper to configure logging for any service worker or API
Requirements: 12.1
Design: Section 12 (Observability and Operations)
"""
from __future__ import annotations
import json
import logging
import time
import uuid
from contextvars import ContextVar
from datetime import datetime, timezone
from typing import Any
# ---------------------------------------------------------------------------
# Trace context stored in contextvars for async-safe propagation
# ---------------------------------------------------------------------------
# Trace and span ids default to "" (no active trace/span); the service name
# defaults to "unknown" until setup_logging() or set_trace_context() sets it.
_trace_id: ContextVar[str] = ContextVar("trace_id", default="")
_span_id: ContextVar[str] = ContextVar("span_id", default="")
_service_name: ContextVar[str] = ContextVar("service_name", default="unknown")
def get_trace_id() -> str:
    """Return the trace id of the current context ("" when none is set)."""
    current = _trace_id.get()
    return current
def get_span_id() -> str:
    """Return the span id of the current context ("" when none is set)."""
    current = _span_id.get()
    return current
def get_service_name() -> str:
    """Return the service name of the current context ("unknown" by default)."""
    current = _service_name.get()
    return current
def set_trace_context(
    trace_id: str | None = None,
    span_id: str | None = None,
    service: str | None = None,
) -> None:
    """Update the trace context of the current async task / thread.

    Only arguments that are not ``None`` are applied; the others keep their
    current value. An empty string is a real value and does overwrite.
    """
    updates = (
        (_trace_id, trace_id),
        (_span_id, span_id),
        (_service_name, service),
    )
    for context_var, new_value in updates:
        if new_value is not None:
            context_var.set(new_value)
def new_trace_id() -> str:
    """Return a fresh trace id: the first 16 hex characters of a random UUID4."""
    full_hex = uuid.uuid4().hex
    return full_hex[:16]
def new_span_id() -> str:
    """Return a fresh span id: the first 8 hex characters of a random UUID4."""
    full_hex = uuid.uuid4().hex
    return full_hex[:8]
# ---------------------------------------------------------------------------
# Span context manager for tracing within a service
# ---------------------------------------------------------------------------
class Span:
    """Minimal tracing span, usable as a (synchronous) context manager.

    Usage::

        with Span("process_document", ticker="AAPL") as span:
            # ... do work ...
            span.set_attribute("doc_count", 5)

    Entering the span makes its trace/span ids the current ones. Leaving it
    logs a ``span.end`` record with duration, status and attributes, then
    restores the previous ids. Exceptions are never suppressed.
    """

    def __init__(self, operation: str, **attributes: Any) -> None:
        self.operation = operation
        # The span active at construction time becomes this span's parent.
        self.parent_span_id = get_span_id()
        self.span_id = new_span_id()
        # Join the active trace if there is one, otherwise start a new trace.
        self.trace_id = get_trace_id() or new_trace_id()
        self.attributes: dict[str, Any] = {**attributes}
        self.start_time: float = 0.0
        self.duration_ms: float = 0.0
        # ContextVar tokens, filled in by __enter__ and consumed by __exit__.
        self._token_trace: Any = None
        self._token_span: Any = None
        logger_name = get_service_name() or "tracing"
        self._logger = logging.getLogger(logger_name)

    def set_attribute(self, key: str, value: Any) -> None:
        """Attach (or overwrite) one attribute reported when the span ends."""
        self.attributes[key] = value

    def __enter__(self) -> Span:
        """Start timing and install this span's ids as the current context."""
        self.start_time = time.monotonic()
        self._token_trace = _trace_id.set(self.trace_id)
        self._token_span = _span_id.set(self.span_id)
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Log the finished span and restore the previous trace context."""
        elapsed_seconds = time.monotonic() - self.start_time
        self.duration_ms = elapsed_seconds * 1000
        outcome = "ok" if not exc_type else "error"

        # Logged before the context is reset, so the record is emitted while
        # this span's ids are still the current ones.
        span_fields = {
            "span_operation": self.operation,
            "span_status": outcome,
            "span_duration_ms": round(self.duration_ms, 2),
            "span_parent_id": self.parent_span_id,
            "span_attributes": self.attributes,
        }
        self._logger.info("span.end", extra=span_fields)

        # Restore parent span context
        if self._token_span is not None:
            _span_id.reset(self._token_span)
        if self._token_trace is not None:
            _trace_id.reset(self._token_trace)
# ---------------------------------------------------------------------------
# JSON log formatter
# ---------------------------------------------------------------------------
class JSONFormatter(logging.Formatter):
    """Emit each log record as a single JSON line with trace context.

    Every line carries timestamp (UTC, ISO-8601), level, logger, message and
    the service / trace id / span id read from the current context. Selected
    ``extra={...}`` fields are copied through when present, and a formatted
    traceback is added under ``exception`` when the record has exception info.
    """

    # Record attributes (set by Span or via ``extra={...}``) that are copied
    # into the JSON output when present and not None. "queue" and "attempt"
    # are the extras attached by the dead-letter-queue warning; without them
    # in this list that context would be dropped from the output.
    _EXTRA_KEYS: tuple[str, ...] = (
        "span_operation", "span_status", "span_duration_ms",
        "span_parent_id", "span_attributes",
        "ticker", "document_id", "source_type", "job_id",
        "duration_ms", "error", "count",
        "queue", "attempt",
    )

    def format(self, record: logging.LogRecord) -> str:
        """Serialise *record* to one JSON document."""
        log_entry: dict[str, Any] = {
            "timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "service": get_service_name(),
            "trace_id": get_trace_id(),
            "span_id": get_span_id(),
        }
        # Merge extra fields from Span or manual extra={} usage
        for key in self._EXTRA_KEYS:
            val = getattr(record, key, None)
            if val is not None:
                log_entry[key] = val
        if record.exc_info and record.exc_info[1]:
            log_entry["exception"] = self.formatException(record.exc_info)
        # default=str keeps logging from failing on values JSON cannot encode.
        return json.dumps(log_entry, default=str)
# ---------------------------------------------------------------------------
# Setup helper
# ---------------------------------------------------------------------------
class _TraceContextFilter(logging.Filter):
    """Copy the current service / trace id / span id onto each log record.

    The human-readable format string refers to ``%(service)s``,
    ``%(trace_id)s`` and ``%(span_id)s``. Nothing else sets those record
    attributes, so without this filter they would always render as the
    formatter defaults. Values supplied by the caller are left untouched.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        record.__dict__.setdefault("service", get_service_name())
        record.__dict__.setdefault("trace_id", get_trace_id())
        record.__dict__.setdefault("span_id", get_span_id())
        return True


def setup_logging(
    service_name: str,
    level: str = "INFO",
    json_output: bool = True,
) -> None:
    """Configure structured logging for a service.

    Call this once at service startup (before any log calls). It replaces all
    handlers on the root logger with a single stream handler.

    Args:
        service_name: Identifies this service in log output (e.g. "ingestion_worker").
        level: Log level string (DEBUG, INFO, WARNING, ERROR). Unknown names
            fall back to INFO.
        json_output: If True, emit JSON lines. If False, use a human-readable format.
    """
    _service_name.set(service_name)
    root = logging.getLogger()
    root.setLevel(getattr(logging, level.upper(), logging.INFO))
    # Remove existing handlers to avoid duplicate output
    root.handlers.clear()
    handler = logging.StreamHandler()
    if json_output:
        handler.setFormatter(JSONFormatter())
    else:
        # Populate the trace fields the text format refers to.
        handler.addFilter(_TraceContextFilter())
        handler.setFormatter(logging.Formatter(
            "%(asctime)s [%(levelname)s] %(name)s (%(service)s) "
            "trace=%(trace_id)s span=%(span_id)s %(message)s",
            defaults={"service": service_name, "trace_id": "", "span_id": ""},
        ))
    root.addHandler(handler)
# ---------------------------------------------------------------------------
# Trace context propagation through job payloads
# ---------------------------------------------------------------------------
def inject_trace_context(payload: dict[str, Any]) -> dict[str, Any]:
    """Stamp the current trace id onto an outgoing job payload.

    Call this before enqueuing a job so the downstream worker can continue
    the same trace. The payload is modified in place and also returned. When
    no trace is active it is returned untouched.
    """
    active_trace = get_trace_id()
    if not active_trace:
        return payload
    payload["_trace_id"] = active_trace
    return payload
def extract_trace_context(payload: dict[str, Any]) -> None:
    """Adopt the trace context carried by an incoming job payload.

    Call this at the start of job processing. The payload's ``_trace_id`` is
    reused when present and non-empty, otherwise a new trace id is generated.
    A new span id is always generated.
    """
    carried_trace = payload.get("_trace_id")
    trace_id = carried_trace if carried_trace else new_trace_id()
    set_trace_context(trace_id=trace_id, span_id=new_span_id())
+696
View File
@@ -0,0 +1,696 @@
"""Metadata persistence for market payloads, documents, and broker events.
Persists structured metadata records to PostgreSQL for all ingested artifacts.
Each source type has its own persistence path:
- market_api → market_snapshots table
- news_api / filings_api / web_scrape → documents + document_company_mentions
- broker → order_events or market_snapshots (for position/account snapshots)
Requirements: 3.3, 3.4, 8.3, 9.2
"""
from __future__ import annotations
import json
import logging
from datetime import datetime, timedelta, timezone
from typing import Any
import asyncpg
from services.shared.content import content_hash_str, normalize_url
# Module logger; uses the fixed name "metadata" rather than the module path.
logger = logging.getLogger("metadata")
async def persist_market_snapshot(
    pool: asyncpg.Pool,
    *,
    company_id: str | None,
    ticker: str,
    snapshot_type: str,
    data: dict[str, Any],
    source_provider: str,
    storage_ref: str,
    content_hash: str,
    captured_at: datetime | None = None,
) -> str:
    """Insert one row into ``market_snapshots`` and return its id.

    Args:
        pool: asyncpg pool used for the insert.
        company_id: Optional company reference for the snapshot.
        ticker: Symbol the snapshot belongs to.
        snapshot_type: Kind of snapshot being stored.
        data: Snapshot payload; serialised with ``json.dumps`` and stored as jsonb.
        source_provider: Name of the provider the data came from.
        storage_ref: Storage reference recorded with the row.
        content_hash: Content digest recorded with the row.
        captured_at: Capture time; the current UTC time is used when omitted.

    Returns:
        The new row's id as a string.
    """
    captured = captured_at or datetime.now(timezone.utc)
    payload_json = json.dumps(data)
    insert_sql = """INSERT INTO market_snapshots
(company_id, ticker, snapshot_type, data, source_provider,
captured_at, storage_ref, content_hash)
VALUES ($1, $2, $3, $4::jsonb, $5, $6, $7, $8)
RETURNING id"""
    snapshot_id = await pool.fetchval(
        insert_sql,
        company_id,
        ticker,
        snapshot_type,
        payload_json,
        source_provider,
        captured,
        storage_ref,
        content_hash,
    )
    logger.debug("Persisted market snapshot %s for %s", snapshot_id, ticker)
    return str(snapshot_id)
async def persist_document(
    pool: asyncpg.Pool,
    *,
    document_type: str,
    source_type: str,
    publisher: str,
    url: str | None,
    canonical_url: str | None,
    title: str,
    published_at: datetime | None,
    content_hash: str,
    storage_ref: str,
    language: str = "en",
) -> str | None:
    """Insert a document metadata row unless its content hash already exists.

    New rows are stored with status ``'ingested'``; ``storage_ref`` is written
    to the ``raw_storage_ref`` column.

    Returns:
        The new document id as a string, or ``None`` when a row with the same
        ``content_hash`` is already present.
    """
    # NOTE(review): the existence check and the insert are separate
    # statements, so two concurrent callers could both pass the check.
    # Confirm whether documents.content_hash has a unique constraint.
    already_stored = await pool.fetchval(
        "SELECT 1 FROM documents WHERE content_hash = $1", content_hash
    )
    if already_stored:
        return None

    insert_sql = """INSERT INTO documents
(document_type, source_type, publisher, url, canonical_url,
title, published_at, content_hash, raw_storage_ref,
language, status)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, 'ingested')
RETURNING id"""
    new_id = await pool.fetchval(
        insert_sql,
        document_type,
        source_type,
        publisher,
        url,
        canonical_url,
        title,
        published_at,
        content_hash,
        storage_ref,
        language,
    )
    # Only the first 60 characters of the title are logged.
    title_preview = title[:60] if title else ""
    logger.debug("Persisted document %s (%s)", new_id, title_preview)
    return str(new_id)
async def update_document_parse_results(
    pool: asyncpg.Pool,
    *,
    document_id: str,
    normalized_storage_ref: str | None,
    parser_output_ref: str | None,
    parse_quality_score: float,
    parse_confidence: str,
    status: str,
) -> None:
    """Store the outcome of the parsing stage on an existing document row.

    Sets the normalized-text and parser-output references, the quality score,
    the confidence label and the status, and bumps ``updated_at`` to ``NOW()``.

    Requirements: 4.1, 4.3, 9.1
    """
    update_sql = """UPDATE documents SET
normalized_storage_ref = $2,
parser_output_ref = $3,
parse_quality_score = $4,
parse_confidence = $5,
status = $6,
updated_at = NOW()
WHERE id = $1"""
    await pool.execute(
        update_sql,
        document_id,
        normalized_storage_ref,
        parser_output_ref,
        parse_quality_score,
        parse_confidence,
        status,
    )
    logger.debug(
        "Updated document %s parse results: quality=%.2f confidence=%s status=%s",
        document_id,
        parse_quality_score,
        parse_confidence,
        status,
    )
async def persist_document_company_mention(
    pool: asyncpg.Pool,
    *,
    document_id: str,
    company_id: str,
    ticker: str,
    mention_type: str = "direct",
    confidence: float = 1.0,
) -> str:
    """Insert a ``document_company_mentions`` row linking a document to a company.

    ``document_id`` and ``company_id`` are cast to ``uuid`` by the statement.

    Returns:
        The new mention row's id as a string.
    """
    insert_sql = """INSERT INTO document_company_mentions
(document_id, company_id, ticker, mention_type, confidence)
VALUES ($1::uuid, $2::uuid, $3, $4, $5)
RETURNING id"""
    new_mention_id = await pool.fetchval(
        insert_sql,
        document_id,
        company_id,
        ticker,
        mention_type,
        confidence,
    )
    return str(new_mention_id)
async def persist_broker_event(
    pool: asyncpg.Pool,
    *,
    ticker: str,
    event_type: str,
    data: dict[str, Any],
    source_provider: str,
    storage_ref: str,
    content_hash: str,
    captured_at: datetime | None = None,
) -> str:
    """Store a broker event as a row in market_snapshots.

    The row's snapshot_type is ``event_type`` prefixed with ``broker_``
    (e.g. broker_positions, broker_account, broker_orders) and ``data`` is
    serialized to JSON before being cast to jsonb.

    Args:
        pool: Connection pool used to run the INSERT.
        ticker: Ticker symbol the event belongs to.
        event_type: Event name; stored with the ``broker_`` prefix.
        data: Event payload to serialize.
        source_provider: Name of the provider that produced the event.
        storage_ref: Reference to the stored raw artifact.
        content_hash: Hash stored with the row.
        captured_at: Capture time; the current UTC time is used when omitted.

    Returns:
        The id of the inserted snapshot row, converted to ``str``.
    """
    when = captured_at if captured_at else datetime.now(timezone.utc)
    statement = """INSERT INTO market_snapshots
(ticker, snapshot_type, data, source_provider,
captured_at, storage_ref, content_hash)
VALUES ($1, $2, $3::jsonb, $4, $5, $6, $7)
RETURNING id"""
    values = (
        ticker,
        f"broker_{event_type}",
        json.dumps(data),
        source_provider,
        when,
        storage_ref,
        content_hash,
    )
    snapshot_id = await pool.fetchval(statement, *values)
    logger.debug("Persisted broker event %s for %s", snapshot_id, ticker)
    return str(snapshot_id)
def _resolve_document_type(source_type: str) -> str:
"""Map source_type to a document_type value."""
mapping = {
"news_api": "article",
"filings_api": "filing",
"web_scrape": "press_release",
}
return mapping.get(source_type, "article")
def _extract_publisher(item: dict[str, Any]) -> str:
"""Extract publisher name from an adapter item dict."""
if item.get("publisher"):
return str(item["publisher"])
source = item.get("source")
if isinstance(source, dict):
return source.get("name", "")
if source:
return str(source)
return ""
def _parse_published_at(item: dict[str, Any]) -> datetime | None:
"""Parse published_at from various adapter item formats."""
raw = item.get("publishedAt") or item.get("published_at")
if not raw:
return None
if isinstance(raw, datetime):
return raw
try:
return datetime.fromisoformat(str(raw).replace("Z", "+00:00"))
except (ValueError, TypeError):
return None
async def persist_ingestion_items(
    pool: asyncpg.Pool,
    *,
    source_type: str,
    ticker: str,
    company_id: str | None,
    items: list[dict[str, Any]],
    storage_ref: str,
    adapter_metadata: dict[str, Any],
    content_hash: str,
) -> tuple[int, list[str]]:
    """Dispatch ingestion items to the persistence helper for their source type.

    ``market_api`` items go to the market snapshot path, ``broker`` items to
    the broker snapshot path, and every other source type (news_api,
    filings_api, web_scrape) to the documents path. The provider and endpoint
    are read from ``adapter_metadata`` with ``"unknown"`` and ``"positions"``
    as fallbacks.

    Returns:
        ``(new_item_count, list_of_new_ids)`` as produced by the chosen helper.
    """
    if source_type == "market_api":
        outcome = await _persist_market_items(
            pool,
            ticker=ticker,
            company_id=company_id,
            items=items,
            storage_ref=storage_ref,
            provider=adapter_metadata.get("provider", "unknown"),
            content_hash=content_hash,
        )
    elif source_type == "broker":
        outcome = await _persist_broker_items(
            pool,
            ticker=ticker,
            items=items,
            storage_ref=storage_ref,
            provider=adapter_metadata.get("provider", "unknown"),
            endpoint=adapter_metadata.get("endpoint", "positions"),
            content_hash=content_hash,
        )
    else:
        # Everything else is treated as a document source.
        outcome = await _persist_document_items(
            pool,
            source_type=source_type,
            ticker=ticker,
            company_id=company_id,
            items=items,
            storage_ref=storage_ref,
        )
    return outcome
async def _persist_market_items(
    pool: asyncpg.Pool,
    *,
    ticker: str,
    company_id: str | None,
    items: list[dict[str, Any]],
    storage_ref: str,
    provider: str,
    content_hash: str,
) -> tuple[int, list[str]]:
    """Write market data items to market_snapshots, skipping known hashes.

    Each item is hashed from its key-sorted JSON form; an item whose hash is
    already present in market_snapshots is not inserted again. The
    ``content_hash`` argument is accepted but not used by this helper.

    Returns:
        ``(count, ids)`` for the rows that were inserted.
    """
    inserted: list[str] = []
    for entry in items:
        digest = content_hash_str(json.dumps(entry, sort_keys=True))
        already_stored = await pool.fetchval(
            "SELECT 1 FROM market_snapshots WHERE content_hash = $1", digest
        )
        if not already_stored:
            new_id = await persist_market_snapshot(
                pool,
                company_id=company_id,
                ticker=ticker,
                snapshot_type=_infer_market_snapshot_type(entry),
                data=entry,
                source_provider=provider,
                storage_ref=storage_ref,
                content_hash=digest,
            )
            inserted.append(new_id)
    return len(inserted), inserted
def _infer_market_snapshot_type(item: dict[str, Any]) -> str:
"""Infer snapshot_type from market data item fields."""
# Polygon aggregate bars have 'o', 'h', 'l', 'c' fields
if all(k in item for k in ("o", "h", "l", "c")):
return "bar"
# Ticker details have 'market_cap' or 'sic_code'
if "market_cap" in item or "sic_code" in item:
return "ticker_details"
# Quote snapshots
if "ask" in item or "bid" in item:
return "quote"
return "snapshot"
async def _persist_broker_items(
    pool: asyncpg.Pool,
    *,
    ticker: str,
    items: list[dict[str, Any]],
    storage_ref: str,
    provider: str,
    endpoint: str,
    content_hash: str,
) -> tuple[int, list[str]]:
    """Write broker fetch items to market_snapshots, skipping known hashes.

    Each item is hashed from its key-sorted JSON form and stored through
    ``persist_broker_event`` with ``endpoint`` as the event type. The
    ``content_hash`` argument is accepted but not used by this helper.

    Returns:
        ``(count, ids)`` for the rows that were inserted.
    """
    inserted: list[str] = []
    for entry in items:
        digest = content_hash_str(json.dumps(entry, sort_keys=True))
        already_stored = await pool.fetchval(
            "SELECT 1 FROM market_snapshots WHERE content_hash = $1", digest
        )
        if not already_stored:
            new_id = await persist_broker_event(
                pool,
                ticker=ticker,
                event_type=endpoint,
                data=entry,
                source_provider=provider,
                storage_ref=storage_ref,
                content_hash=digest,
            )
            inserted.append(new_id)
    return len(inserted), inserted
async def _persist_document_items(
    pool: asyncpg.Pool,
    *,
    source_type: str,
    ticker: str,
    company_id: str | None,
    items: list[dict[str, Any]],
    storage_ref: str,
) -> tuple[int, list[str]]:
    """Write document items (news, filings, web scrape) to the documents table.

    For every item the hash, title, URL, canonical URL, publication time and
    publisher are derived from the item dict and handed to
    ``persist_document``. Items for which ``persist_document`` returns
    ``None`` are skipped. When ``company_id`` is set, each new document is
    also linked to that company.

    Returns:
        ``(count, ids)`` for the documents that were created.
    """
    document_type = _resolve_document_type(source_type)
    created: list[str] = []
    for entry in items:
        # Prefer a hash supplied by the adapter; otherwise hash the item itself.
        digest = entry.get("content_hash")
        if not digest:
            digest = content_hash_str(json.dumps(entry, sort_keys=True))

        heading = entry["title"] if "title" in entry else entry.get("name", "")
        link = entry["url"] if "url" in entry else entry.get("link", "")
        canonical = entry.get("canonical_url")
        if not canonical:
            canonical = normalize_url(link) if link else None

        new_id = await persist_document(
            pool,
            document_type=document_type,
            source_type=source_type,
            publisher=_extract_publisher(entry),
            url=link or None,
            canonical_url=canonical,
            title=heading,
            published_at=_parse_published_at(entry),
            content_hash=digest,
            storage_ref=storage_ref,
        )
        if new_id is not None:
            if company_id:
                # Only link when the caller knows which company this is for.
                await persist_document_company_mention(
                    pool,
                    document_id=new_id,
                    company_id=company_id,
                    ticker=ticker,
                )
            created.append(new_id)
    return len(created), created
# --- Retry and failure tracking (Requirement 3.4) ---
# Backoff constants — match scheduler defaults for consistency
RETRY_BACKOFF_BASE: int = 60
RETRY_BACKOFF_MAX: int = 3600
RETRY_MAX_COUNT: int = 10
def compute_next_retry_at(
retry_count: int,
now: datetime | None = None,
base: int = RETRY_BACKOFF_BASE,
cap: int = RETRY_BACKOFF_MAX,
) -> datetime:
"""Compute the next eligible retry time using exponential backoff.
Args:
retry_count: Current retry count (before incrementing).
now: Reference timestamp (defaults to UTC now).
base: Base delay in seconds.
cap: Maximum delay in seconds.
Returns:
Datetime of the next eligible retry.
"""
ts = now or datetime.now(timezone.utc)
delay = min(base * (2 ** min(retry_count, 8)), cap)
return ts + timedelta(seconds=delay)
async def get_source_retry_count(
    pool: asyncpg.Pool,
    source_id: str,
) -> int:
    """Look up the retry count carried by a source's latest ingestion run.

    Only the single most recent run (by ``started_at``) is inspected.

    Returns:
        That run's ``retry_count`` when its status is ``'failed'`` (a NULL
        count reads as 0); otherwise 0, including when the source has no runs.
    """
    statement = """SELECT status, retry_count
FROM ingestion_runs
WHERE source_id = $1::uuid
ORDER BY started_at DESC
LIMIT 1"""
    latest = await pool.fetchrow(statement, source_id)
    if not latest or latest["status"] != "failed":
        return 0
    return latest["retry_count"] or 0
async def record_retrieval_failure(
    pool: asyncpg.Pool,
    run_id: str,
    source_id: str,
    error_message: str,
    retry_count: int | None = None,
    now: datetime | None = None,
) -> dict[str, Any]:
    """Mark an ingestion run as failed and store its retry bookkeeping.

    The ingestion_runs row identified by ``run_id`` receives status
    ``'failed'``, the error message, the incremented retry count, the next
    eligible retry time (exponential backoff) and a completion timestamp.

    Args:
        pool: Connection pool used for the lookup and the UPDATE.
        run_id: Id of the ingestion_runs row to update.
        source_id: Source the run belongs to.
        error_message: Failure reason to store.
        retry_count: Retry count before this failure. When omitted it is
            fetched with ``get_source_retry_count``.
        now: Reference timestamp (defaults to UTC now).

    Returns:
        A dict with ``run_id``, ``source_id``, ``retry_count``,
        ``next_retry_at`` (ISO string), ``exhausted`` and ``error_message``.

    NOTE(review): ``get_source_retry_count`` inspects the source's latest run
    by ``started_at``. If that latest run is the one being failed here and it
    is not yet marked ``'failed'``, the lookup yields 0 — confirm that this
    is the intended behaviour.

    Requirement 3.4
    """
    ts = now if now else datetime.now(timezone.utc)

    if retry_count is None:
        previous = await get_source_retry_count(pool, source_id)
    else:
        previous = retry_count
    attempt = previous + 1

    # Backoff is derived from the count *before* this failure.
    next_retry = compute_next_retry_at(attempt - 1, now=ts)
    exhausted = attempt >= RETRY_MAX_COUNT

    statement = """UPDATE ingestion_runs
SET status = 'failed',
error_message = $2,
retry_count = $3,
next_retry_at = $4,
completed_at = $5
WHERE id = $1"""
    await pool.execute(statement, run_id, error_message, attempt, next_retry, ts)

    if exhausted:
        logger.warning(
            "Source %s exhausted retries (%d/%d): %s",
            source_id,
            attempt,
            RETRY_MAX_COUNT,
            error_message,
        )
    else:
        logger.info(
            "Source %s failed (retry %d/%d), next retry at %s: %s",
            source_id,
            attempt,
            RETRY_MAX_COUNT,
            next_retry.isoformat(),
            error_message,
        )

    return {
        "run_id": run_id,
        "source_id": source_id,
        "retry_count": attempt,
        "next_retry_at": next_retry.isoformat(),
        "exhausted": exhausted,
        "error_message": error_message,
    }
async def persist_document_intelligence(
    pool: asyncpg.Pool,
    *,
    document_id: str,
    summary: str,
    macro_themes: list[str],
    novelty_score: float,
    source_credibility: float,
    extraction_warnings: list[str],
    confidence: float,
    model_provider: str,
    model_name: str,
    prompt_version: str,
    schema_version: str,
    raw_output_ref: str | None = None,
    prompt_ref: str | None = None,
    validation_status: str = "valid",
    validation_errors: list[str] | None = None,
    retry_count: int = 0,
) -> str:
    """Insert one document_intelligence row for a document.

    ``macro_themes``, ``extraction_warnings`` and ``validation_errors`` are
    serialized with ``json.dumps`` and cast to jsonb by the statement; a
    missing ``validation_errors`` is stored as an empty list.

    Returns:
        The id of the inserted intelligence row, converted to ``str``.

    Requirements: 5.3, 5.4, 9.2
    """
    statement = """INSERT INTO document_intelligence
(document_id, summary, macro_themes, novelty_score,
source_credibility, extraction_warnings, confidence,
model_provider, model_name, prompt_version, schema_version,
raw_output_ref, prompt_ref, validation_status,
validation_errors, retry_count)
VALUES ($1::uuid, $2, $3::jsonb, $4, $5, $6::jsonb, $7,
$8, $9, $10, $11, $12, $13, $14, $15::jsonb, $16)
RETURNING id"""
    # Positional values in the same order as the $1..$16 placeholders.
    values = (
        document_id,
        summary,
        json.dumps(macro_themes),
        novelty_score,
        source_credibility,
        json.dumps(extraction_warnings),
        confidence,
        model_provider,
        model_name,
        prompt_version,
        schema_version,
        raw_output_ref,
        prompt_ref,
        validation_status,
        json.dumps(validation_errors or []),
        retry_count,
    )
    new_id = await pool.fetchval(statement, *values)
    logger.debug("Persisted document intelligence %s for doc %s", new_id, document_id)
    return str(new_id)
async def persist_document_impact(
    pool: asyncpg.Pool,
    *,
    intelligence_id: str,
    company_id: str,
    ticker: str,
    relevance: float,
    sentiment: str,
    impact_score: float,
    impact_horizon: str,
    catalyst_type: str,
    key_facts: list[str],
    risks: list[str],
    evidence_spans: list[str],
) -> str:
    """Insert one per-company impact row tied to a document intelligence row.

    ``key_facts``, ``risks`` and ``evidence_spans`` are serialized with
    ``json.dumps`` and cast to jsonb by the statement.

    Returns:
        The id of the inserted impact record, converted to ``str``.

    Requirements: 5.3, 5.5, 9.2
    """
    statement = """INSERT INTO document_impact_records
(intelligence_id, company_id, ticker, relevance, sentiment,
impact_score, impact_horizon, catalyst_type,
key_facts, risks, evidence_spans)
VALUES ($1::uuid, $2::uuid, $3, $4, $5, $6, $7, $8,
$9::jsonb, $10::jsonb, $11::jsonb)
RETURNING id"""
    # Positional values in the same order as the $1..$11 placeholders.
    values = (
        intelligence_id,
        company_id,
        ticker,
        relevance,
        sentiment,
        impact_score,
        impact_horizon,
        catalyst_type,
        json.dumps(key_facts),
        json.dumps(risks),
        json.dumps(evidence_spans),
    )
    new_id = await pool.fetchval(statement, *values)
    logger.debug("Persisted impact record %s for %s", new_id, ticker)
    return str(new_id)
async def update_document_status(
    pool: asyncpg.Pool,
    *,
    document_id: str,
    status: str,
) -> None:
    """Set the ``status`` column of one document and refresh ``updated_at``.

    Moves a document along the pipeline: ingested → parsed → extracted → failed.

    Args:
        pool: Connection pool used to run the UPDATE.
        document_id: Document identifier; cast to ``uuid`` by the statement.
        status: New status value.

    Requirements: 5.4
    """
    statement = "UPDATE documents SET status = $2, updated_at = NOW() WHERE id = $1::uuid"
    await pool.execute(statement, document_id, status)
    logger.debug("Updated document %s status to %s", document_id, status)
async def reset_source_retry_state(
    pool: asyncpg.Pool,
    source_id: str,
) -> None:
    """Clear the retry bookkeeping on a source's most recent ingestion run.

    The latest run (by ``started_at``) for ``source_id`` gets
    ``retry_count = 0`` and ``next_retry_at = NULL``. Intended to be called
    after a successful ingestion so accumulated backoff is dropped.
    """
    statement = """UPDATE ingestion_runs
SET retry_count = 0, next_retry_at = NULL
WHERE id = (
SELECT id FROM ingestion_runs
WHERE source_id = $1::uuid
ORDER BY started_at DESC
LIMIT 1
)"""
    await pool.execute(statement, source_id)
+317
View File
@@ -0,0 +1,317 @@
"""Prometheus metrics for all Stonks Oracle pipeline stages.
Provides counters, histograms, and gauges covering:
- Ingestion: items fetched, new items, errors, adapter latency
- Parsing: documents parsed, quality scores, low-quality flags
- Extraction: attempts, successes, failures, latency, confidence, retries
- Aggregation: trend windows computed, signal counts, contradiction scores
- Lake publication: facts published per table, write latency
- Trading: orders submitted, rejected, filled, risk evaluations
Requirements: 12.1, 12.2
Design: Section 12 (Observability and Operations)
"""
from __future__ import annotations
from prometheus_client import Counter, Gauge, Histogram, Info
# ---------------------------------------------------------------------------
# Service info
# ---------------------------------------------------------------------------
# Info metric for static service metadata. It is only declared here; the
# key/value pairs are not populated by this section.
SERVICE_INFO = Info("stonks_oracle", "Stonks Oracle service metadata")
# ---------------------------------------------------------------------------
# Ingestion metrics
# ---------------------------------------------------------------------------
# Every ingestion metric is labelled by source_type; the job counter is
# additionally labelled by status.
INGESTION_JOBS_TOTAL = Counter(
    "stonks_ingestion_jobs_total",
    "Total ingestion jobs processed",
    ["source_type", "status"],
)
INGESTION_ITEMS_FETCHED = Counter(
    "stonks_ingestion_items_fetched_total",
    "Total items fetched from external sources",
    ["source_type"],
)
INGESTION_ITEMS_NEW = Counter(
    "stonks_ingestion_items_new_total",
    "New (non-duplicate) items ingested",
    ["source_type"],
)
INGESTION_ITEMS_DEDUPED = Counter(
    "stonks_ingestion_items_deduped_total",
    "Items skipped due to deduplication",
    ["source_type"],
)
INGESTION_ERRORS = Counter(
    "stonks_ingestion_errors_total",
    "Ingestion errors by source type",
    ["source_type"],
)
# Latency histogram; bucket upper bounds run from 0.1 s to 60 s.
INGESTION_ADAPTER_DURATION = Histogram(
    "stonks_ingestion_adapter_duration_seconds",
    "Adapter fetch latency in seconds",
    ["source_type"],
    buckets=(0.1, 0.5, 1, 2, 5, 10, 30, 60),
)
# ---------------------------------------------------------------------------
# Parsing metrics
# ---------------------------------------------------------------------------
# Parse job outcomes, labelled by status.
PARSE_JOBS_TOTAL = Counter(
    "stonks_parse_jobs_total",
    "Total parse jobs processed",
    ["status"],
)
# Unlabelled histogram with bucket bounds at every tenth up to 1.0
# (presumably scores lie in [0, 1] — confirm against the parser).
PARSE_QUALITY_SCORE = Histogram(
    "stonks_parse_quality_score",
    "Distribution of parser quality scores",
    buckets=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0),
)
PARSE_LOW_QUALITY_TOTAL = Counter(
    "stonks_parse_low_quality_total",
    "Documents flagged as low quality by the parser",
)
# Duration histogram; bucket upper bounds run from 0.05 s to 10 s.
PARSE_DURATION = Histogram(
    "stonks_parse_duration_seconds",
    "Parse job duration in seconds",
    buckets=(0.05, 0.1, 0.25, 0.5, 1, 2, 5, 10),
)
# ---------------------------------------------------------------------------
# Extraction metrics
# ---------------------------------------------------------------------------
# Extraction job outcomes, labelled by status.
EXTRACTION_JOBS_TOTAL = Counter(
    "stonks_extraction_jobs_total",
    "Total extraction jobs processed",
    ["status"],
)
# Attempts include retries; retries are also counted separately below.
EXTRACTION_ATTEMPTS = Counter(
    "stonks_extraction_attempts_total",
    "Total Ollama extraction attempts (including retries)",
)
EXTRACTION_RETRIES = Counter(
    "stonks_extraction_retries_total",
    "Extraction retry count",
)
# Duration histogram; bucket upper bounds run from 1 s to 120 s.
EXTRACTION_DURATION = Histogram(
    "stonks_extraction_duration_seconds",
    "Extraction total duration in seconds",
    buckets=(1, 2, 5, 10, 20, 30, 60, 120),
)
# Bucket bounds at every tenth up to 1.0.
EXTRACTION_CONFIDENCE = Histogram(
    "stonks_extraction_confidence",
    "Distribution of extraction confidence scores",
    buckets=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0),
)
EXTRACTION_VALIDATION_ERRORS = Counter(
    "stonks_extraction_validation_errors_total",
    "Total validation errors across extractions",
)
# Labelled by direction; the accepted label values are not defined here
# (presumably input vs. output tokens — confirm at the call sites).
EXTRACTION_TOKEN_ESTIMATE = Counter(
    "stonks_extraction_tokens_total",
    "Estimated token usage",
    ["direction"],
)
# ---------------------------------------------------------------------------
# Aggregation metrics
# ---------------------------------------------------------------------------
# Counters and the duration histogram are labelled by window; the
# contradiction-score histogram is unlabelled.
AGGREGATION_WINDOWS_COMPUTED = Counter(
    "stonks_aggregation_windows_total",
    "Trend windows computed",
    ["window"],
)
AGGREGATION_SIGNALS_PROCESSED = Counter(
    "stonks_aggregation_signals_total",
    "Signals processed during aggregation",
    ["window"],
)
# Buckets are finer near zero (0.05 steps up to 0.2) and coarser towards 1.0.
AGGREGATION_CONTRADICTION_SCORE = Histogram(
    "stonks_aggregation_contradiction_score",
    "Distribution of contradiction scores in trend windows",
    buckets=(0.0, 0.05, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0),
)
# Duration histogram; bucket upper bounds run from 0.05 s to 10 s.
AGGREGATION_DURATION = Histogram(
    "stonks_aggregation_duration_seconds",
    "Aggregation job duration in seconds",
    ["window"],
    buckets=(0.05, 0.1, 0.25, 0.5, 1, 2, 5, 10),
)
# ---------------------------------------------------------------------------
# Recommendation metrics
# ---------------------------------------------------------------------------
# Generated recommendations, labelled by action and mode.
RECOMMENDATION_GENERATED = Counter(
    "stonks_recommendations_total",
    "Recommendations generated",
    ["action", "mode"],
)
RECOMMENDATION_SUPPRESSED = Counter(
    "stonks_recommendations_suppressed_total",
    "Recommendations suppressed due to low data quality",
)
# Bucket bounds at every tenth up to 1.0.
RECOMMENDATION_CONFIDENCE = Histogram(
    "stonks_recommendation_confidence",
    "Distribution of recommendation confidence scores",
    buckets=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0),
)
# ---------------------------------------------------------------------------
# Lake publication metrics
# ---------------------------------------------------------------------------
# All lake publication metrics are labelled by table_name.
LAKE_FACTS_PUBLISHED = Counter(
    "stonks_lake_facts_published_total",
    "Analytical facts published to the lakehouse",
    ["table_name"],
)
# Write latency histogram; bucket upper bounds run from 0.01 s to 5 s.
LAKE_PUBLISH_DURATION = Histogram(
    "stonks_lake_publish_duration_seconds",
    "Lake publication write latency in seconds",
    ["table_name"],
    buckets=(0.01, 0.05, 0.1, 0.25, 0.5, 1, 2, 5),
)
LAKE_PUBLISH_ERRORS = Counter(
    "stonks_lake_publish_errors_total",
    "Lake publication errors",
    ["table_name"],
)
LAKE_PUBLISH_BYTES = Counter(
    "stonks_lake_publish_bytes_total",
    "Total bytes written to the lakehouse",
    ["table_name"],
)
# ---------------------------------------------------------------------------
# Trading / broker metrics
# ---------------------------------------------------------------------------
# Orders handed to the broker, labelled by side, order_type and mode.
ORDERS_SUBMITTED = Counter(
    "stonks_orders_submitted_total",
    "Orders submitted to broker",
    ["side", "order_type", "mode"],
)
# Rejections counted here happen before broker submission.
ORDERS_REJECTED = Counter(
    "stonks_orders_rejected_total",
    "Orders rejected before broker submission",
    ["reason_category"],
)
ORDERS_FILLED = Counter(
    "stonks_orders_filled_total",
    "Orders filled by broker",
    ["side"],
)
# Labelled by detected_via, i.e. which idempotency check caught the duplicate.
ORDERS_DUPLICATES_PREVENTED = Counter(
    "stonks_orders_duplicates_prevented_total",
    "Duplicate orders prevented by idempotency checks",
    ["detected_via"],
)
# Whole evaluations are labelled by result; individual failing checks are
# counted separately, labelled by check_name.
RISK_EVALUATIONS_TOTAL = Counter(
    "stonks_risk_evaluations_total",
    "Risk evaluations performed",
    ["result"],
)
RISK_CHECK_FAILURES = Counter(
    "stonks_risk_check_failures_total",
    "Individual risk check failures",
    ["check_name"],
)
POSITIONS_SYNCED = Counter(
    "stonks_positions_synced_total",
    "Position sync operations completed",
)
# ---------------------------------------------------------------------------
# Active gauges
# ---------------------------------------------------------------------------
# Gauge (can go up and down) of in-flight jobs, labelled by pipeline stage.
ACTIVE_JOBS = Gauge(
    "stonks_active_jobs",
    "Currently processing jobs by stage",
    ["stage"],
)
# ---------------------------------------------------------------------------
# Alerting metrics
# ---------------------------------------------------------------------------
# Fired alerts are labelled by rule and severity; resolved alerts by rule only.
ALERTS_FIRED = Counter(
    "stonks_alerts_fired_total",
    "Total alerts fired by rule",
    ["rule", "severity"],
)
ALERTS_RESOLVED = Counter(
    "stonks_alerts_resolved_total",
    "Total alerts resolved by rule",
    ["rule"],
)
# Duration of one evaluation cycle; bucket upper bounds run from 0.01 s to 5 s.
ALERT_CHECK_DURATION = Histogram(
    "stonks_alert_check_duration_seconds",
    "Duration of alert evaluation cycle",
    buckets=(0.01, 0.05, 0.1, 0.25, 0.5, 1, 2, 5),
)
# Per-rule state gauge: 1 while the rule is firing, 0 once resolved.
ALERT_ACTIVE = Gauge(
    "stonks_alert_active",
    "Whether an alert rule is currently firing (1) or resolved (0)",
    ["rule"],
)
# ---------------------------------------------------------------------------
# Dead-letter queue metrics
# ---------------------------------------------------------------------------
# All dead-letter metrics are labelled by queue. The two counters track
# traffic into and out of the queues; the gauge tracks their current size.
DLQ_ITEMS_TOTAL = Counter(
    "stonks_dlq_items_total",
    "Jobs sent to dead-letter queues",
    ["queue"],
)
DLQ_REPLAYED_TOTAL = Counter(
    "stonks_dlq_replayed_total",
    "Jobs replayed from dead-letter queues",
    ["queue"],
)
DLQ_DEPTH = Gauge(
    "stonks_dlq_depth",
    "Current dead-letter queue depth",
    ["queue"],
)
+10
View File
@@ -46,6 +46,15 @@ def retry_key(job_id: str) -> str:
return f"{RETRY_PREFIX}:{job_id}"
# Dead-letter queues
DLQ_PREFIX = f"{PREFIX}:dlq"


def dlq_key(queue_name: str) -> str:
    """Build the dead-letter key for a source queue.

    The key is ``DLQ_PREFIX`` and ``queue_name`` joined by a colon.
    """
    return "{}:{}".format(DLQ_PREFIX, queue_name)
# --- Queue names ---
QUEUE_INGESTION = "ingestion"
QUEUE_PARSING = "parsing"
@@ -54,3 +63,4 @@ QUEUE_AGGREGATION = "aggregation"
QUEUE_RECOMMENDATION = "recommendation"
QUEUE_LAKE_PUBLISH = "lake_publish"
QUEUE_TRADE = "trade"
QUEUE_BROKER = "broker_orders"
+306
View File
@@ -0,0 +1,306 @@
"""Data retention and lifecycle controls for raw and derived artifacts.
Provides configurable per-bucket retention policies, expired object cleanup
from MinIO, and expired metadata cleanup from PostgreSQL.
Requirements: N3 (preserve source metadata, access policy, and retention policy)
Design ref: Section 5.2 (MinIO bucket layout), Section 10 (Reliability and Safety)
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
import asyncpg
from minio import Minio
from services.shared.config import BUCKET_RETENTION_FIELDS, RetentionConfig
from services.shared.storage import ALL_BUCKETS
logger = logging.getLogger("retention")
@dataclass
class RetentionPolicy:
    """Resolved retention policy for a single bucket."""

    # Name of the bucket the policy applies to.
    bucket_name: str
    # Age in days beyond which objects in the bucket count as expired.
    retention_days: int
    # Loaded from the policy source alongside the other fields.
    # NOTE(review): the cleanup functions in this module do not read this
    # flag — confirm where archiving is handled.
    archive_before_delete: bool = False
@dataclass
class CleanupResult:
    """Result of a single bucket cleanup run."""

    # Bucket the counts below refer to.
    bucket_name: str
    # Set by cleanup_bucket to the number of expired objects it listed.
    objects_scanned: int = 0
    # Number of objects successfully removed.
    objects_deleted: int = 0
    # NOTE(review): bytes_freed and db_rows_deleted keep their default of 0
    # in cleanup_bucket — confirm whether another caller fills them in.
    bytes_freed: int = 0
    db_rows_deleted: int = 0
def default_retention_days(bucket: str, config: RetentionConfig) -> int:
    """Look up a bucket's default retention period in the config.

    ``BUCKET_RETENTION_FIELDS`` names the config attribute that holds the
    bucket's retention days. When the bucket has no entry there, or the
    config lacks that attribute, the result is 365.
    """
    attribute = BUCKET_RETENTION_FIELDS.get(bucket)
    if not attribute:
        return 365
    return getattr(config, attribute, 365)
def resolve_policies(config: RetentionConfig) -> list[RetentionPolicy]:
    """Create one config-default policy per bucket in ``ALL_BUCKETS``.

    Policies come back in ``ALL_BUCKETS`` order, each carrying the retention
    days reported by ``default_retention_days``.
    """
    policies: list[RetentionPolicy] = []
    for name in ALL_BUCKETS:
        days = default_retention_days(name, config)
        policies.append(RetentionPolicy(bucket_name=name, retention_days=days))
    return policies
async def load_db_policies(pool: asyncpg.Pool) -> dict[str, RetentionPolicy]:
    """Read active retention policy overrides from the retention_policies table.

    Only rows with ``active = TRUE`` and ``artifact_class = 'default'`` are
    selected.

    Returns:
        A dict mapping bucket name to its policy. If several rows share a
        bucket name, the last one returned by the query wins.
    """
    statement = """SELECT bucket_name, retention_days, archive_before_delete
FROM retention_policies
WHERE active = TRUE AND artifact_class = 'default'"""
    overrides: dict[str, RetentionPolicy] = {}
    for record in await pool.fetch(statement):
        overrides[record["bucket_name"]] = RetentionPolicy(
            bucket_name=record["bucket_name"],
            retention_days=record["retention_days"],
            archive_before_delete=record["archive_before_delete"],
        )
    return overrides
def merge_policies(
    config_policies: list[RetentionPolicy],
    db_policies: dict[str, RetentionPolicy],
) -> list[RetentionPolicy]:
    """Overlay database overrides on the config-default policies.

    The result has one entry per config policy, in the same order. Whenever
    ``db_policies`` holds an entry for the same bucket, that entry replaces
    the config one. Overrides for buckets absent from ``config_policies`` are
    not added.
    """
    return [
        db_policies.get(default.bucket_name, default)
        for default in config_policies
    ]
def cutoff_date(retention_days: int, now: datetime | None = None) -> datetime:
    """Return the instant before which objects count as expired.

    Args:
        retention_days: Number of days objects are kept.
        now: Reference timestamp (defaults to UTC now).

    Returns:
        ``now`` moved back by ``retention_days`` days.
    """
    reference = now if now else datetime.now(timezone.utc)
    keep_for = timedelta(days=retention_days)
    return reference - keep_for
def list_expired_objects(
    client: Minio,
    bucket: str,
    retention_days: int,
    batch_size: int = 1000,
    now: datetime | None = None,
) -> list[str]:
    """Collect names of objects whose ``last_modified`` is before the cutoff.

    The bucket is listed recursively and the walk stops as soon as
    ``batch_size`` names have been gathered. Objects lacking a
    ``last_modified`` timestamp or a name are ignored. Any error raised
    while listing is logged, and the names gathered so far are returned.

    Args:
        client: MinIO client used for the listing.
        bucket: Bucket to inspect.
        retention_days: Retention period that determines the cutoff.
        batch_size: Maximum number of names to return.
        now: Reference timestamp (defaults to UTC now).

    Returns:
        Up to ``batch_size`` expired object names.
    """
    threshold = cutoff_date(retention_days, now)
    names: list[str] = []
    try:
        for entry in client.list_objects(bucket, recursive=True):
            if not entry.last_modified or entry.last_modified >= threshold:
                continue
            if not entry.object_name:
                continue
            names.append(entry.object_name)
            if len(names) >= batch_size:
                break
    except Exception:
        logger.exception("Error listing objects in bucket %s", bucket)
    return names
def delete_expired_objects(
    client: Minio,
    bucket: str,
    object_names: list[str],
) -> int:
    """Remove the named objects from a bucket, one request per object.

    A failure on one object is logged as a warning (with traceback) and does
    not stop the remaining removals.

    Returns:
        How many objects were removed without error.
    """
    removed = 0
    for object_name in object_names:
        try:
            client.remove_object(bucket, object_name)
        except Exception:
            logger.warning("Failed to delete %s/%s", bucket, object_name, exc_info=True)
        else:
            removed += 1
    return removed
def cleanup_bucket(
    client: Minio,
    policy: RetentionPolicy,
    batch_size: int = 1000,
    now: datetime | None = None,
) -> CleanupResult:
    """Apply one retention policy to its bucket.

    Lists up to ``batch_size`` expired objects and deletes them.
    ``objects_scanned`` in the result is the number of expired objects that
    were listed, and ``objects_deleted`` the number actually removed.

    Args:
        client: MinIO client used for listing and deleting.
        policy: Policy naming the bucket and its retention period.
        batch_size: Maximum number of objects handled in this run.
        now: Reference timestamp (defaults to UTC now).

    Returns:
        A ``CleanupResult`` for the policy's bucket.
    """
    bucket = policy.bucket_name
    days = policy.retention_days
    outcome = CleanupResult(bucket_name=bucket)

    candidates = list_expired_objects(
        client, bucket, days, batch_size=batch_size, now=now
    )
    outcome.objects_scanned = len(candidates)

    if not candidates:
        logger.debug("Bucket %s: no expired objects (retention=%dd)", bucket, days)
        return outcome

    outcome.objects_deleted = delete_expired_objects(client, bucket, candidates)
    logger.info(
        "Bucket %s: scanned=%d deleted=%d (retention=%dd)",
        bucket,
        outcome.objects_scanned,
        outcome.objects_deleted,
        days,
    )
    return outcome
# --- PostgreSQL metadata cleanup ---
# (table name, DELETE statement) pairs executed by cleanup_expired_db_records.
# Each statement takes the cutoff timestamp as $1 and filters on the table's
# own time column: started_at for ingestion_runs, captured_at for
# market_snapshots. The table name is only used in log messages.
DB_CLEANUP_QUERIES: list[tuple[str, str]] = [
    (
        "ingestion_runs",
        "DELETE FROM ingestion_runs WHERE started_at < $1",
    ),
    (
        "market_snapshots",
        "DELETE FROM market_snapshots WHERE captured_at < $1",
    ),
]
async def cleanup_expired_db_records(
    pool: asyncpg.Pool,
    retention_days: int,
    now: datetime | None = None,
) -> int:
    """Run every statement in ``DB_CLEANUP_QUERIES`` against one cutoff.

    All statements run on a single connection acquired from the pool. A
    failure on one table is logged and the remaining tables are still
    processed.

    Args:
        pool: Connection pool to acquire the connection from.
        retention_days: Retention period that determines the cutoff.
        now: Reference timestamp (defaults to UTC now).

    Returns:
        The total number of rows deleted across all tables.
    """
    threshold = cutoff_date(retention_days, now)
    deleted_rows = 0
    async with pool.acquire() as conn:
        for table, statement in DB_CLEANUP_QUERIES:
            try:
                status_tag = await conn.execute(statement, threshold)
                # The status string ends with the affected row count ("DELETE N").
                if status_tag:
                    affected = int(status_tag.split()[-1])
                else:
                    affected = 0
                deleted_rows += affected
                if affected > 0:
                    logger.info(
                        "Cleaned %d expired rows from %s (cutoff=%s)",
                        affected,
                        table,
                        threshold.isoformat(),
                    )
            except Exception:
                logger.exception("Error cleaning table %s", table)
    return deleted_rows
async def record_retention_run(
    pool: asyncpg.Pool,
    bucket_name: str,
    result: CleanupResult,
    status: str = "completed",
    error_message: str | None = None,
) -> None:
    """Insert one row into retention_runs describing a cleanup run.

    The counters come from ``result``; ``completed_at`` is set to the
    database clock.

    Args:
        pool: Connection pool used to run the INSERT.
        bucket_name: Bucket the run was for.
        result: Counters gathered during the run.
        status: Outcome label; defaults to ``"completed"``.
        error_message: Optional failure description.
    """
    statement = """INSERT INTO retention_runs
(bucket_name, objects_scanned, objects_deleted, bytes_freed,
db_rows_deleted, completed_at, status, error_message)
VALUES ($1, $2, $3, $4, $5, NOW(), $6, $7)"""
    values = (
        bucket_name,
        result.objects_scanned,
        result.objects_deleted,
        result.bytes_freed,
        result.db_rows_deleted,
        status,
        error_message,
    )
    await pool.execute(statement, *values)
async def run_retention_cleanup(
    minio_client: Minio,
    pool: asyncpg.Pool,
    config: RetentionConfig,
    now: datetime | None = None,
) -> list[CleanupResult]:
    """Run the full retention cleanup cycle.

    1. Resolve policies from config defaults + DB overrides
    2. Clean up expired MinIO objects per bucket
    3. Record each bucket run for observability
    4. Clean up expired PostgreSQL metadata

    Cleaning a bucket and recording its run are handled independently: a
    bookkeeping failure is logged but neither marks a successful cleanup as
    failed nor adds a second result for the same bucket.

    The PostgreSQL cutoff uses the shortest retention among the raw buckets
    (names starting with ``stonks-raw-``), so a short retention on a derived
    bucket cannot shorten the life of ingestion metadata. If no raw bucket
    has a policy, the shortest retention overall is used.

    Args:
        minio_client: MinIO client used for object cleanup.
        pool: Connection pool for policy loading, run records and DB cleanup.
        config: Retention configuration (defaults and ``batch_size``).
        now: Reference timestamp (defaults to UTC now).

    Returns:
        Exactly one ``CleanupResult`` per policy, in policy order. A bucket
        whose cleanup raised is represented by an all-zero result.
    """
    # Resolve policies; DB overrides are optional, config defaults always exist.
    config_policies = resolve_policies(config)
    try:
        db_policies = await load_db_policies(pool)
    except Exception:
        logger.warning(
            "Could not load DB retention policies, using config defaults",
            exc_info=True,
        )
        db_policies = {}
    policies = merge_policies(config_policies, db_policies)

    results: list[CleanupResult] = []
    for policy in policies:
        status = "completed"
        error_message: str | None = None
        try:
            result = cleanup_bucket(
                minio_client,
                policy,
                batch_size=config.batch_size,
                now=now,
            )
        except Exception:
            logger.exception("Retention cleanup failed for bucket %s", policy.bucket_name)
            result = CleanupResult(bucket_name=policy.bucket_name)
            status, error_message = "failed", "See logs"
        results.append(result)

        # Recording is best-effort: it must not change the outcome above or
        # stop the remaining buckets from being processed.
        try:
            await record_retention_run(
                pool,
                policy.bucket_name,
                result,
                status=status,
                error_message=error_message,
            )
        except Exception:
            logger.exception(
                "Could not record retention run for bucket %s", policy.bucket_name
            )

    if not policies:
        # Nothing to derive a cutoff from; min() over an empty set would raise.
        return results

    # Clean up expired DB records using the shortest raw retention period.
    raw_retention = [
        p.retention_days
        for p in policies
        if p.bucket_name.startswith("stonks-raw-")
    ]
    if raw_retention:
        min_retention = min(raw_retention)
    else:
        min_retention = min(p.retention_days for p in policies)
    try:
        db_deleted = await cleanup_expired_db_records(pool, min_retention, now=now)
        if db_deleted > 0:
            logger.info("Total DB rows cleaned: %d", db_deleted)
    except Exception:
        logger.exception("DB retention cleanup failed")
    return results
+37
View File
@@ -108,6 +108,41 @@ class DocumentIntelligence(BaseModel):
# --- Trend Summary ---
class MarketContext(BaseModel):
    """Recent market data features for a symbol, used to enrich aggregation.

    Every feature is optional and defaults to ``None``; ``bars_available``
    defaults to 0, so a default-constructed instance reports no data.
    """

    ticker: str = ""
    price_change_pct: Optional[float] = None  # % change over the window
    avg_volume: Optional[float] = None  # average daily volume
    volume_change_pct: Optional[float] = None  # volume vs prior period
    volatility: Optional[float] = None  # intra-window price std dev
    latest_close: Optional[float] = None
    latest_bar_at: Optional[datetime] = None
    bars_available: int = 0

    @property
    def has_data(self) -> bool:
        """Return True when ``bars_available`` is greater than zero."""
        return self.bars_available > 0
class DisagreementDetail(BaseModel):
    """Represents an explicit disagreement between document signals.

    Rather than collapsing contradictory signals into a single score,
    this captures the nature of the disagreement so downstream consumers
    can inspect *why* signals conflict.

    Requirements: 6.4
    """

    dimension: str = ""  # e.g. "sentiment", "catalyst", "impact_horizon"
    # Ids of the documents on each side of the disagreement; each instance
    # gets its own fresh empty list by default.
    positive_doc_ids: List[str] = Field(default_factory=list)
    negative_doc_ids: List[str] = Field(default_factory=list)
    # Weight attributed to each side; how it is computed is not defined here.
    positive_weight: float = 0.0
    negative_weight: float = 0.0
    description: str = ""
class TrendSummary(BaseModel):
entity_type: str = "company"
entity_id: str = ""
@@ -120,6 +155,8 @@ class TrendSummary(BaseModel):
dominant_catalysts: List[str] = Field(default_factory=list)
material_risks: List[str] = Field(default_factory=list)
contradiction_score: float = Field(ge=0, le=1, default=0.0)
disagreement_details: List[DisagreementDetail] = Field(default_factory=list)
market_context: Optional[MarketContext] = None
generated_at: datetime = Field(default_factory=datetime.utcnow)
+352
View File
@@ -0,0 +1,352 @@
"""Raw artifact upload to MinIO.
Provides a reusable storage layer for uploading raw artifacts (API payloads,
HTML, normalized text, model outputs) to MinIO with consistent path conventions,
bucket management, and content-type handling.
Bucket layout follows the design spec:
- stonks-raw-market — raw market API payloads
- stonks-raw-news — raw news API payloads and article HTML
- stonks-raw-filings — raw filings and issuer event payloads
- stonks-normalized — cleaned text and parser outputs
- stonks-llm-prompts — prompts and schemas used
- stonks-llm-results — raw model outputs and validation reports
- stonks-lakehouse — partitioned analytical datasets and table metadata
- stonks-audit — execution traces and exported reports
Object path pattern:
/{stage}/{symbol}/{yyyy}/{mm}/{dd}/{document_id}/{artifact_type}.{ext}
Requirements: 3.1, 3.2, 3.3, 9.1
"""
import io
import logging
from datetime import datetime, timezone
from typing import Mapping
from minio import Minio
from minio.error import S3Error
logger = logging.getLogger("storage")

# Complete list of buckets used by the platform, one entry per storage stage.
ALL_BUCKETS = [
    "stonks-raw-market",   # raw market API payloads
    "stonks-raw-news",     # raw news API payloads and article HTML
    "stonks-raw-filings",  # raw filings and issuer event payloads
    "stonks-normalized",   # cleaned text and parser outputs
    "stonks-llm-prompts",  # prompts and schemas used
    "stonks-llm-results",  # raw model outputs and validation reports
    "stonks-lakehouse",    # partitioned analytical datasets and table metadata
    "stonks-audit",        # execution traces and exported reports
]

# source_type -> raw bucket that receives payloads from that source.
SOURCE_BUCKET_MAP: dict[str, str] = dict(
    market_api="stonks-raw-market",
    news_api="stonks-raw-news",
    filings_api="stonks-raw-filings",
    web_scrape="stonks-raw-news",
    broker="stonks-raw-market",
)

# artifact type -> (MIME content type, file extension without the dot).
ARTIFACT_CONTENT_TYPES: dict[str, tuple[str, str]] = dict(
    raw_json=("application/json", "json"),
    raw_html=("text/html", "html"),
    raw_text=("text/plain", "txt"),
    raw_payload=("application/octet-stream", "bin"),
)
def bucket_for_source(source_type: str) -> str:
    """Resolve the raw bucket that stores payloads of ``source_type``.

    Source types that are not listed in ``SOURCE_BUCKET_MAP`` resolve to
    ``"stonks-raw-market"``.
    """
    try:
        return SOURCE_BUCKET_MAP[source_type]
    except KeyError:
        return "stonks-raw-market"
def build_artifact_path(
source_type: str,
ticker: str,
document_id: str,
artifact_name: str = "raw",
ext: str = "json",
timestamp: datetime | None = None,
) -> str:
"""Build a MinIO object path following the design convention.
Pattern: {source_type}/{ticker}/{yyyy}/{mm}/{dd}/{document_id}/{artifact_name}.{ext}
"""
ts = timestamp or datetime.now(timezone.utc)
return (
f"{source_type}/{ticker}/"
f"{ts.year}/{ts.month:02d}/{ts.day:02d}/"
f"{document_id}/{artifact_name}.{ext}"
)
def storage_ref(bucket: str, path: str) -> str:
    """Render the ``s3://`` URI that addresses ``path`` inside ``bucket``."""
    return "s3://{}/{}".format(bucket, path)
def ensure_buckets(client: Minio, buckets: list[str] | None = None) -> list[str]:
    """Create any missing buckets and report which ones were created.

    Args:
        client: MinIO client instance.
        buckets: Bucket names to check. ``None`` means every bucket in
            ``ALL_BUCKETS``; an explicit empty list means there is nothing
            to check and no bucket is created.

    Returns:
        Names of the buckets that did not exist and were created, in the
        order they were processed.

    Raises:
        S3Error: Re-raised after being logged when checking or creating a
            bucket fails.
    """
    # Compare against None instead of relying on truthiness so that an
    # explicit empty list is honoured rather than expanding to ALL_BUCKETS.
    target_buckets = ALL_BUCKETS if buckets is None else buckets
    created: list[str] = []
    for bucket in target_buckets:
        try:
            if not client.bucket_exists(bucket):
                client.make_bucket(bucket)
                created.append(bucket)
                logger.info("Created bucket: %s", bucket)
        except S3Error as e:
            logger.error("Failed to ensure bucket %s: %s", bucket, e)
            raise
    return created
def upload_artifact(
    client: Minio,
    bucket: str,
    path: str,
    data: bytes,
    content_type: str = "application/json",
    metadata: Mapping[str, str] | None = None,
) -> str:
    """Upload raw bytes to MinIO and return the s3:// storage reference.

    Args:
        client: MinIO client instance.
        bucket: Target bucket name.
        path: Object path within the bucket.
        data: Raw bytes to upload.
        content_type: MIME type for the object.
        metadata: Optional user metadata to attach to the object.

    Returns:
        s3:// URI pointing to the uploaded object.
    """
    # put_object takes a readable stream plus an explicit length, so the
    # in-memory payload is wrapped in BytesIO. Its return value is not needed.
    client.put_object(
        bucket,
        path,
        io.BytesIO(data),
        length=len(data),
        content_type=content_type,
        metadata=metadata,
    )
    ref = storage_ref(bucket, path)
    logger.debug("Uploaded %d bytes to %s", len(data), ref)
    return ref
def upload_raw_artifact(
    client: Minio,
    source_type: str,
    ticker: str,
    document_id: str,
    data: bytes,
    artifact_type: str = "raw_json",
    timestamp: datetime | None = None,
    metadata: Mapping[str, str] | None = None,
) -> str:
    """Store a raw payload using the standard bucket, path and content type.

    Intended as the main entry point for ingestion workers. The bucket is
    derived from ``source_type``; the content type and file extension are
    derived from ``artifact_type``. An artifact type that is not in
    ``ARTIFACT_CONTENT_TYPES`` is stored as ``application/octet-stream``
    with a ``.bin`` extension.

    Args:
        client: MinIO client instance.
        source_type: One of market_api, news_api, filings_api, web_scrape, broker.
        ticker: Company ticker symbol.
        document_id: Unique document or run identifier.
        data: Raw bytes to upload.
        artifact_type: One of raw_json, raw_html, raw_text, raw_payload.
        timestamp: Override timestamp for path generation (defaults to now UTC).
        metadata: Optional user metadata dict.

    Returns:
        s3:// URI pointing to the uploaded object.
    """
    fallback = ("application/octet-stream", "bin")
    mime_type, extension = ARTIFACT_CONTENT_TYPES.get(artifact_type, fallback)
    target_bucket = bucket_for_source(source_type)
    object_path = build_artifact_path(
        source_type, ticker, document_id, "raw", extension, timestamp
    )
    return upload_artifact(
        client,
        target_bucket,
        object_path,
        data,
        content_type=mime_type,
        metadata=metadata,
    )
def upload_html_artifact(
    client: Minio,
    ticker: str,
    document_id: str,
    html_bytes: bytes,
    timestamp: datetime | None = None,
    metadata: Mapping[str, str] | None = None,
) -> str:
    """Store the raw HTML of a scraped web page.

    The object goes to the bucket mapped to the ``web_scrape`` source type,
    under the ``web_scrape`` path prefix, as ``raw.html`` with content type
    ``text/html``.

    Returns:
        s3:// URI pointing to the uploaded object.
    """
    source = "web_scrape"
    object_path = build_artifact_path(
        source, ticker, document_id, "raw", "html", timestamp
    )
    return upload_artifact(
        client,
        bucket_for_source(source),
        object_path,
        html_bytes,
        content_type="text/html",
        metadata=metadata,
    )
def upload_normalized_text(
    client: Minio,
    ticker: str,
    document_id: str,
    text_bytes: bytes,
    timestamp: datetime | None = None,
    metadata: Mapping[str, str] | None = None,
) -> str:
    """Store normalized (parsed) text in the stonks-normalized bucket.

    Object key: parsed/{ticker}/{yyyy}/{mm}/{dd}/{document_id}/normalized.txt
    The date segments use ``timestamp``, or the current UTC time when omitted.

    Returns:
        s3:// URI pointing to the uploaded object.
    """
    when = timestamp if timestamp else datetime.now(timezone.utc)
    date_part = f"{when.year}/{when.month:02d}/{when.day:02d}"
    object_path = f"parsed/{ticker}/{date_part}/{document_id}/normalized.txt"
    return upload_artifact(
        client,
        "stonks-normalized",
        object_path,
        text_bytes,
        content_type="text/plain",
        metadata=metadata,
    )
def upload_parser_output(
    client: Minio,
    ticker: str,
    document_id: str,
    output_bytes: bytes,
    timestamp: datetime | None = None,
    metadata: Mapping[str, str] | None = None,
) -> str:
    """Store structured parser output JSON in the stonks-normalized bucket.

    Object key: parsed/{ticker}/{yyyy}/{mm}/{dd}/{document_id}/parser_output.json
    The date segments use ``timestamp``, or the current UTC time when omitted.

    Returns:
        s3:// URI pointing to the uploaded object.
    """
    when = timestamp if timestamp else datetime.now(timezone.utc)
    object_path = "/".join(
        [
            "parsed",
            f"{ticker}",
            f"{when.year}",
            f"{when.month:02d}",
            f"{when.day:02d}",
            f"{document_id}",
            "parser_output.json",
        ]
    )
    return upload_artifact(
        client,
        "stonks-normalized",
        object_path,
        output_bytes,
        content_type="application/json",
        metadata=metadata,
    )
def upload_extraction_prompt(
    client: Minio,
    ticker: str,
    document_id: str,
    prompt_data: bytes,
    timestamp: datetime | None = None,
    metadata: Mapping[str, str] | None = None,
) -> str:
    """Store the extraction prompt and schema in stonks-llm-prompts.

    Object key: extraction/{ticker}/{yyyy}/{mm}/{dd}/{document_id}/prompt.json
    The date segments use ``timestamp``, or the current UTC time when omitted.

    Returns:
        s3:// URI pointing to the uploaded object.
    """
    when = timestamp if timestamp else datetime.now(timezone.utc)
    object_path = "extraction/{}/{}/{:02d}/{:02d}/{}/prompt.json".format(
        ticker, when.year, when.month, when.day, document_id
    )
    return upload_artifact(
        client,
        "stonks-llm-prompts",
        object_path,
        prompt_data,
        content_type="application/json",
        metadata=metadata,
    )
def upload_extraction_raw_output(
    client: Minio,
    ticker: str,
    document_id: str,
    output_data: bytes,
    attempt_index: int = 0,
    timestamp: datetime | None = None,
    metadata: Mapping[str, str] | None = None,
) -> str:
    """Store one raw model output in stonks-llm-results.

    Object key:
    extraction/{ticker}/{yyyy}/{mm}/{dd}/{document_id}/raw_output_{attempt}.json
    where ``{attempt}`` is ``attempt_index``. The date segments use
    ``timestamp``, or the current UTC time when omitted.

    Returns:
        s3:// URI pointing to the uploaded object.
    """
    when = timestamp if timestamp else datetime.now(timezone.utc)
    folder = (
        f"extraction/{ticker}/{when.year}/{when.month:02d}/{when.day:02d}/"
        f"{document_id}"
    )
    file_name = f"raw_output_{attempt_index}.json"
    return upload_artifact(
        client,
        "stonks-llm-results",
        f"{folder}/{file_name}",
        output_data,
        content_type="application/json",
        metadata=metadata,
    )
def upload_extraction_validation(
    client: Minio,
    ticker: str,
    document_id: str,
    validation_data: bytes,
    timestamp: datetime | None = None,
    metadata: Mapping[str, str] | None = None,
) -> str:
    """Store a validation report in stonks-llm-results.

    Object key: extraction/{ticker}/{yyyy}/{mm}/{dd}/{document_id}/validation.json
    The date segments use ``timestamp``, or the current UTC time when omitted.

    Returns:
        s3:// URI pointing to the uploaded object.
    """
    when = timestamp if timestamp else datetime.now(timezone.utc)
    year, month, day = when.year, when.month, when.day
    object_path = (
        f"extraction/{ticker}/{year}/{month:02d}/{day:02d}/"
        f"{document_id}/validation.json"
    )
    return upload_artifact(
        client,
        "stonks-llm-results",
        object_path,
        validation_data,
        content_type="application/json",
        metadata=metadata,
    )
def upload_extraction_intelligence(
    client: Minio,
    ticker: str,
    document_id: str,
    intelligence_data: bytes,
    timestamp: datetime | None = None,
    metadata: Mapping[str, str] | None = None,
) -> str:
    """Store the final intelligence object in stonks-llm-results.

    Object key: extraction/{ticker}/{yyyy}/{mm}/{dd}/{document_id}/intelligence.json
    The date segments use ``timestamp``, or the current UTC time when omitted.

    Returns:
        s3:// URI pointing to the uploaded object.
    """
    when = timestamp if timestamp else datetime.now(timezone.utc)
    key_parts = (
        "extraction",
        f"{ticker}",
        f"{when.year}",
        f"{when.month:02d}",
        f"{when.day:02d}",
        f"{document_id}",
        "intelligence.json",
    )
    return upload_artifact(
        client,
        "stonks-llm-results",
        "/".join(key_parts),
        intelligence_data,
        content_type="application/json",
        metadata=metadata,
    )
def download_artifact(client: Minio, bucket: str, path: str) -> bytes:
    """Fetch the object at ``path`` in ``bucket`` and return its full contents.

    The response is closed and its connection released whether or not the
    read succeeds.
    """
    obj = client.get_object(bucket, path)
    try:
        payload = obj.read()
    finally:
        # Always hand the connection back, even when read() raises.
        obj.close()
        obj.release_conn()
    return payload
+2
View File
@@ -10,6 +10,7 @@ from pydantic import BaseModel, field_validator
from services.shared.config import load_config
from services.shared.db import get_pg_pool
from services.shared.logging import setup_logging
config = load_config()
pool: Optional[asyncpg.Pool] = None
@@ -18,6 +19,7 @@ pool: Optional[asyncpg.Pool] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
global pool
setup_logging("symbol_registry", level=config.log_level, json_output=config.json_logs)
pool = await get_pg_pool(config)
yield
await pool.close()
+2 -1
View File
@@ -13,8 +13,8 @@ import asyncpg
from services.shared.config import load_config
from services.shared.db import get_pg_pool
from services.shared.logging import setup_logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("seed")
# --- Seed Companies ---
@@ -173,6 +173,7 @@ async def seed(pool: asyncpg.Pool) -> None:
async def main() -> None:
config = load_config()
setup_logging("seed", level=config.log_level, json_output=config.json_logs)
pool = await get_pg_pool(config)
try:
await seed(pool)