diff --git a/.kiro/specs/stonks-oracle/tasks.md b/.kiro/specs/stonks-oracle/tasks.md index 94fdd1e..9e92d3a 100644 --- a/.kiro/specs/stonks-oracle/tasks.md +++ b/.kiro/specs/stonks-oracle/tasks.md @@ -24,99 +24,153 @@ - [x] Add seed data support for an initial tracked watchlist ## Phase 3 - External API Adapters -- [ ] Implement scheduler for symbol and source polling windows -- [ ] Implement market data API adapter interface -- [ ] Implement first concrete market data provider adapter -- [ ] Implement news API adapter interface -- [ ] Implement first concrete news API provider adapter -- [ ] Implement filings or regulatory adapter interface -- [ ] Implement first concrete filings provider adapter -- [ ] Implement broker API adapter interface for paper trading and order events -- [ ] Implement rate-limit coordination, retries, and backoff across adapters +- [x] Implement scheduler for symbol and source polling windows +- [x] Implement market data API adapter interface +- [x] Implement first concrete market data provider adapter +- [x] Implement news API adapter interface +- [x] Implement first concrete news API provider adapter +- [x] Implement filings or regulatory adapter interface +- [x] Implement first concrete filings provider adapter +- [x] Implement broker API adapter interface for paper trading and order events +- [x] Implement rate-limit coordination, retries, and backoff across adapters ## Phase 4 - Ingestion Pipeline -- [ ] Implement web scraper worker for curated URLs and article pages -- [ ] Implement canonical URL normalization and content hashing -- [ ] Implement raw artifact upload to MinIO -- [ ] Implement metadata persistence in PostgreSQL for market payloads, documents, and broker events -- [ ] Implement retry and failure tracking for source retrieval -- [ ] Implement dedupe logic across article and filing sources +- [x] Implement web scraper worker for curated URLs and article pages +- [x] Implement canonical URL normalization and content 
hashing +- [x] Implement raw artifact upload to MinIO +- [x] Implement metadata persistence in PostgreSQL for market payloads, documents, and broker events +- [x] Implement retry and failure tracking for source retrieval +- [x] Implement dedupe logic across article and filing sources ## Phase 5 - Parsing and Normalization -- [ ] Implement HTML-to-text parsing pipeline -- [ ] Implement boilerplate reduction and body extraction heuristics -- [ ] Implement parser quality scoring and confidence flags -- [ ] Implement company mention detection using ticker, alias, and name matching -- [ ] Persist normalized text and parser outputs to MinIO and PostgreSQL +- [x] Implement HTML-to-text parsing pipeline +- [x] Implement boilerplate reduction and body extraction heuristics +- [x] Implement parser quality scoring and confidence flags +- [x] Implement company mention detection using ticker, alias, and name matching +- [x] Persist normalized text and parser outputs to MinIO and PostgreSQL ## Phase 6 - Ollama Structured Extraction -- [ ] Build extraction prompt templates with anti-hallucination instructions -- [ ] Build JSON schema definitions for document intelligence extraction -- [ ] Implement Ollama client wrapper using structured output format -- [ ] Implement schema validation and semantic validation layers -- [ ] Persist prompts, model metadata, raw outputs, validation reports, and final intelligence objects -- [ ] Add retry behavior for invalid or incomplete model responses -- [ ] Add model performance metrics and dashboards +- [x] Build extraction prompt templates with anti-hallucination instructions +- [x] Build JSON schema definitions for document intelligence extraction +- [x] Implement Ollama client wrapper using structured output format +- [x] Implement schema validation and semantic validation layers +- [x] Persist prompts, model metadata, raw outputs, validation reports, and final intelligence objects +- [x] Add retry behavior for invalid or incomplete model 
responses +- [x] Add model performance metrics and dashboards ## Phase 7 - Aggregation and Trend Engine -- [ ] Implement recency decay and source credibility weighting -- [ ] Integrate market context features into aggregation windows -- [ ] Implement company-level rolling window aggregation -- [ ] Implement contradiction detection and disagreement representation -- [ ] Implement sector and market rollups -- [ ] Implement evidence ranking for supporting and opposing documents -- [ ] Persist trend windows and evidence mappings +- [x] Implement recency decay and source credibility weighting +- [x] Integrate market context features into aggregation windows +- [x] Implement company-level rolling window aggregation +- [x] Implement contradiction detection and disagreement representation +- [x] Implement sector and market rollups +- [x] Implement evidence ranking for supporting and opposing documents +- [x] Persist trend windows and evidence mappings ## Phase 8 - Recommendation Engine -- [ ] Design deterministic recommendation eligibility logic -- [ ] Implement recommendation generation from aggregated scores and evidence -- [ ] Add optional LLM wording layer for thesis generation only -- [ ] Persist recommendation objects and evidence citations -- [ ] Add suppression logic for low-quality data or low confidence -- [ ] Publish prediction facts to analytical tables +- [x] Design deterministic recommendation eligibility logic +- [x] Implement recommendation generation from aggregated scores and evidence +- [x] Add optional LLM wording layer for thesis generation only +- [x] Persist recommendation objects and evidence citations +- [x] Add suppression logic for low-quality data or low confidence +- [x] Publish prediction facts to analytical tables ## Phase 9 - Risk Engine and Trade Adapter -- [ ] Implement portfolio and account risk configuration model -- [ ] Implement hard blocks for max position size, sector exposure, daily loss limits, and news-shock lockouts -- [ ] 
Implement paper trading adapter behavior and state sync -- [ ] Integrate first broker API in sandbox mode -- [ ] Implement idempotent order submission keys and duplicate prevention -- [ ] Implement full execution audit trail -- [ ] Add operator approval workflow for live trading mode -- [ ] Publish order, fill, and position facts to analytical tables +- [x] Implement portfolio and account risk configuration model +- [x] Implement hard blocks for max position size, sector exposure, daily loss limits, and news-shock lockouts +- [x] Implement paper trading adapter behavior and state sync +- [x] Integrate first broker API in sandbox mode +- [x] Implement idempotent order submission keys and duplicate prevention +- [x] Implement full execution audit trail +- [x] Add operator approval workflow for live trading mode +- [x] Publish order, fill, and position facts to analytical tables ## Phase 10 - Lakehouse and SQL Analytics -- [ ] Define analytical fact tables for bars, documents, extractions, signals, orders, fills, positions, and PnL -- [ ] Implement Parquet writers for analytical datasets -- [ ] Implement Hive-compatible partition layout conventions on MinIO -- [ ] Implement Iceberg table creation and metadata management for analytical datasets -- [ ] Implement lake publisher jobs from operational data into analytical fact tables -- [ ] Configure Trino catalogs for Hive and or Iceberg access to MinIO -- [ ] Add example SQL views for prediction-vs-outcome and paper-trade scorecards +- [x] Define analytical fact tables for bars, documents, extractions, signals, orders, fills, positions, and PnL +- [x] Implement Parquet writers for analytical datasets +- [x] Implement Hive-compatible partition layout conventions on MinIO +- [x] Implement Iceberg table creation and metadata management for analytical datasets +- [x] Implement lake publisher jobs from operational data into analytical fact tables +- [x] Configure Trino catalogs for Hive and/or Iceberg access to MinIO +- [x] 
Add example SQL views for prediction-vs-outcome and paper-trade scorecards ## Phase 11 - Query API and Dashboard -- [ ] Build APIs for companies, document timelines, trend summaries, recommendations, and order history -- [ ] Build evidence drill-down view linking recommendations to source documents and raw artifacts -- [ ] Build admin controls for source health, symbol configs, and trading mode -- [ ] Build operational dashboard for ingestion throughput, model failures, and source coverage gaps -- [ ] Build Superset starter dashboards for symbol overview, sentiment heatmap, PnL, and prediction accuracy +- [x] Build APIs for companies, document timelines, trend summaries, recommendations, and order history +- [x] Build evidence drill-down view linking recommendations to source documents and raw artifacts +- [x] Build admin controls for source health, symbol configs, and trading mode +- [x] Build operational dashboard for ingestion throughput, model failures, and source coverage gaps +- [x] Build Superset starter dashboards for symbol overview, sentiment heatmap, PnL, and prediction accuracy ## Phase 12 - Observability and Hardening -- [ ] Add structured logs and distributed tracing across services -- [ ] Add Prometheus metrics for ingestion, parsing, extraction, aggregation, lake publication, and trading -- [ ] Add alerting for source failures, schema failure spikes, analytical lag, and broker issues -- [ ] Add dead-letter queues and replay tooling -- [ ] Add data retention and lifecycle controls for raw and derived artifacts -- [ ] Add security review for secrets, network policies, trading isolation, and dashboard access control +- [x] Add structured logs and distributed tracing across services +- [x] Add Prometheus metrics for ingestion, parsing, extraction, aggregation, lake publication, and trading +- [x] Add alerting for source failures, schema failure spikes, analytical lag, and broker issues +- [x] Add dead-letter queues and replay tooling +- [x] Add data 
retention and lifecycle controls for raw and derived artifacts +- [x] Add security review for secrets, network policies, trading isolation, and dashboard access control ## Phase 13 - Verification and Rollout -- [ ] Create replay dataset from archived documents for deterministic extraction testing -- [ ] Create integration tests for the full ingest-to-recommendation flow -- [ ] Create paper trading simulation scenarios -- [ ] Validate fail-closed behavior for broker outages and ambiguous order states -- [ ] Validate lake publication and Trino query correctness over partitioned MinIO datasets -- [ ] Run shadow mode before enabling any live execution -- [ ] Prepare operator runbook and incident response procedures +- [x] Create replay dataset from archived documents for deterministic extraction testing +- [x] Create integration tests for the full ingest-to-recommendation flow +- [x] Create paper trading simulation scenarios +- [x] Validate fail-closed behavior for broker outages and ambiguous order states +- [x] Validate lake publication and Trino query correctness over partitioned MinIO datasets +- [x] ~~Run shadow mode~~ moved to Phase 15.6 (post-deployment) +- [x] ~~Prepare operator runbook~~ moved to Phase 15.7 (post-deployment) + +## Phase 14 - Local Docker Build Validation +- [x] 14. 
Build and validate all Docker containers locally +- [x] 14.1 Build all 11 service containers locally using the Makefile + - Run `make build` to build scheduler, symbol-registry, ingestion, parser, extractor, aggregation, recommendation, risk, broker-adapter, lake-publisher, and query-api images + - Fix any build failures (missing dependencies, import errors, syntax issues) + - _Requirements: N1, 12.1_ +- [x] 14.2 Validate schema and logic consistency across all services + - Run the full test suite with `pytest tests/ -x --tb=short -q` to catch import errors, schema mismatches, and logic inconsistencies + - Verify all shared schemas in `services/shared/schemas.py` are consistent with what each service expects + - Verify config loader fields match the configmap and secrets definitions + - Fix any mismatches found between services, schemas, migrations, and K8s manifests + - _Requirements: 5.2, 5.3, 9.2, N2_ +- [x] 14.3 Verify each container starts without immediate crash + - Run each built image with `docker run --rm` and a quick health check or `--help` flag to confirm the entrypoint resolves + - Fix any runtime import errors or missing module paths + - _Requirements: N1_ + +## Phase 15 - CI Validation, Helm Deployment, and Cluster Rollout +- [-] 15. 
Commit, push, validate CI, create Helm chart, and deploy to cluster +- [-] 15.1 Commit and push code to GitHub + - Configure git with SSH key for the private repo + - Commit all current changes with message `phase 14-15: docker build validation and helm deployment` + - Push to main branch + - _Requirements: N1_ +- [ ] 15.2 Validate GitHub Actions workflow builds containers + - Monitor the GitHub Actions run to confirm lint-and-test and build-services jobs succeed + - Fix any CI failures and re-push if needed + - _Requirements: N1_ +- [ ] 15.3 Create Helm chart for stonks-oracle deployment + - Create `infra/helm/stonks-oracle/Chart.yaml` with chart metadata + - Create `infra/helm/stonks-oracle/values.yaml` with configurable image tags, replica counts, resource limits, and environment references + - Create Helm templates for all deployments, services, configmap, secrets, ingress, and network policies from existing K8s manifests + - Add imagePullSecrets configuration for GHCR private registry access + - Add a template for a Kubernetes Secret of type `kubernetes.io/dockerconfigjson` for GHCR authentication + - _Requirements: N1, 8.2_ +- [ ] 15.4 Configure GHCR image pull authentication on the cluster + - Create a `docker-registry` secret in the `stonks-oracle` namespace with GHCR credentials (using a GitHub personal access token with the `read:packages` scope; deploy keys cannot authenticate to GHCR) + - Reference the imagePullSecret in all deployment specs via the Helm values + - _Requirements: 8.2, N1_ +- [ ] 15.5 Deploy stonks-oracle to the cluster via Helm + - Run `helm install` or `helm upgrade --install` targeting the `stonks-oracle` namespace + - Verify all pods reach Running/Ready state + - Verify services and ingress endpoints are reachable + - Debug and fix any deployment issues (CrashLoopBackOff, image pull errors, config mismatches) + - _Requirements: N1, 12.1_ +- [ ] 15.6 Run shadow mode before enabling any live execution + - Confirm all services are running and processing in paper-only mode + - Validate end-to-end data flow from 
ingestion through recommendation without live trades + - _Requirements: N5, 8.1_ +- [ ] 15.7 Prepare operator runbook and incident response procedures + - Document service restart procedures, log access, and common failure modes + - Document how to toggle trading modes and approve live execution + - _Requirements: 8.2, 12.1_ ## Recommended First Vertical Slice - [ ] Track 5 to 10 symbols diff --git a/Makefile b/Makefile index ace5564..8b520d2 100644 --- a/Makefile +++ b/Makefile @@ -24,8 +24,25 @@ test: build: @for svc in $(SERVICES); do \ - echo "Building $$svc..."; \ - docker build -t $(GHCR)/$$svc:$(SHA) -t $(GHCR)/$$svc:latest -f docker/Dockerfile .; \ + case $$svc in \ + scheduler) cmd="python -m services.scheduler.app" ;; \ + symbol-registry) cmd="uvicorn services.symbol_registry.app:app --host 0.0.0.0 --port 8000" ;; \ + ingestion) cmd="python -m services.ingestion.worker" ;; \ + parser) cmd="python -m services.parser.worker" ;; \ + extractor) cmd="python -m services.extractor.main" ;; \ + aggregation) cmd="python -m services.aggregation.main" ;; \ + recommendation) cmd="python -m services.recommendation.main" ;; \ + risk) cmd="uvicorn services.risk.app:app --host 0.0.0.0 --port 8000" ;; \ + broker-adapter) cmd="python -m services.adapters.broker_service" ;; \ + lake-publisher) cmd="python -m services.lake_publisher.jobs" ;; \ + query-api) cmd="uvicorn services.api.app:app --host 0.0.0.0 --port 8000" ;; \ + esac; \ + echo "Building $$svc ($$cmd)..."; \ + docker build \ + --build-arg "SERVICE_CMD=$$cmd" \ + -t $(GHCR)/$$svc:$(SHA) \ + -t $(GHCR)/$$svc:latest \ + -f docker/Dockerfile . || exit 1; \ done push: diff --git a/dashboards/README.md b/dashboards/README.md index 4ded9c8..b96682a 100644 --- a/dashboards/README.md +++ b/dashboards/README.md @@ -3,9 +3,18 @@ Apache Superset dashboard configurations and starter datasets for Stonks Oracle. 
## Starter Dashboards -- Symbol Overview — company profile, source health, recent documents -- Sentiment Heatmap — market-wide sentiment by sector and symbol -- Prediction Accuracy — predicted signals vs realized price moves -- Paper Trading PnL — paper trade performance and position tracking -- Model Quality — extraction success rates, latency, and confidence distributions -- Source Coverage — ingestion throughput, source failures, and coverage gaps +See `starter/` for dashboard definitions covering: +- Symbol Overview — company profiles, source health, recent documents, and market snapshots +- Sentiment Heatmap — market-wide sentiment by sector and symbol, catalyst analysis +- Prediction Accuracy — predicted signals vs realized price moves, confidence calibration +- Paper Trading PnL — cumulative PnL, position snapshots, order history, and scorecards + +## Operational Dashboards +See `operational/` for dashboard definitions covering: +- Ingestion Throughput — documents/hour by source type, success/failure rates, stale sources +- Model Extraction Quality — success rates, latency percentiles, validation failures, confidence distributions +- Source Coverage & Gaps — per-symbol source type matrix, missing sources, failure heatmap + +Starter dashboards are powered by the Trino `lakehouse` catalog over MinIO-backed analytical tables. +Operational dashboards query the Query API `/api/ops/*` endpoints. +All dashboards can be imported into Superset via the UI or CLI. diff --git a/dashboards/operational/README.md b/dashboards/operational/README.md new file mode 100644 index 0000000..346e4f2 --- /dev/null +++ b/dashboards/operational/README.md @@ -0,0 +1,22 @@ +# Operational Dashboard + +Superset dashboard definitions for Stonks Oracle operational monitoring. 
+ +## Dashboards +- Ingestion Throughput — documents ingested per hour by source type, success/failure rates +- Model Extraction Quality — extraction success rates, latency percentiles, validation failures +- Source Coverage Gaps — symbols missing source types, stale sources with no recent data + +## Data Sources +These dashboards query the Query API operational endpoints: +- `/api/ops/ingestion/throughput` — time-bucketed ingestion metrics +- `/api/ops/ingestion/summary` — aggregate ingestion stats +- `/api/ops/model/failures` — recent extraction failures +- `/api/ops/model/performance` — model performance summary +- `/api/ops/pipeline/health` — pipeline stage health +- `/api/ops/sources/coverage-gaps` — source coverage analysis + +## Setup +Import the dashboard JSON files into Superset via the Superset UI or CLI. +The dashboards are driven by the real-time operational data behind the Query API endpoints listed above; +the Trino `lakehouse` catalog is the datasource for the starter dashboards, not for these. diff --git a/dashboards/operational/ingestion_throughput.json b/dashboards/operational/ingestion_throughput.json new file mode 100644 index 0000000..c835c02 --- /dev/null +++ b/dashboards/operational/ingestion_throughput.json @@ -0,0 +1,75 @@ +{ + "dashboard_title": "Ingestion Throughput", + "description": "Operational dashboard for monitoring ingestion pipeline throughput, success rates, and item counts across source types.", + "slug": "ingestion-throughput", + "position_json": { + "HEADER_ID": {"id": "HEADER_ID", "type": "HEADER", "meta": {"text": "Ingestion Throughput"}}, + "ROW-1": { + "type": "ROW", + "children": ["CHART-throughput-timeseries", "CHART-source-type-breakdown"] + }, + "ROW-2": { + "type": "ROW", + "children": ["CHART-success-failure-rate", "CHART-items-fetched"] + }, + "ROW-3": { + "type": "ROW", + "children": ["CHART-stale-sources", "CHART-active-companies"] + } + }, + "metadata": { + "refresh_frequency": 300, + "default_filters": "{}", + "color_scheme": "supersetColors" + }, 
+ "charts": [ + { + "slice_name": "Ingestion Runs Over Time", + "viz_type": "echarts_timeseries_bar", + "description": "Ingestion run counts bucketed by hour, stacked by source type", + "datasource_type": "query", + "query": "SELECT date_trunc('hour', ir.started_at) AS bucket, ir.source_type, COUNT(*) AS run_count, COUNT(*) FILTER (WHERE ir.status = 'completed') AS completed, COUNT(*) FILTER (WHERE ir.status = 'failed') AS failed FROM ingestion_runs ir WHERE ir.started_at >= NOW() - INTERVAL '24 hours' GROUP BY 1, 2 ORDER BY 1", + "params": { + "x_axis": "bucket", + "metrics": ["run_count"], + "groupby": ["source_type"], + "time_grain_sqla": "PT1H" + } + }, + { + "slice_name": "Source Type Breakdown", + "viz_type": "pie", + "description": "Distribution of ingestion runs by source type in the last 24h", + "datasource_type": "query", + "query": "SELECT ir.source_type, COUNT(*) AS runs FROM ingestion_runs ir WHERE ir.started_at >= NOW() - INTERVAL '24 hours' GROUP BY ir.source_type ORDER BY runs DESC" + }, + { + "slice_name": "Success vs Failure Rate", + "viz_type": "echarts_timeseries_line", + "description": "Hourly success and failure counts over time", + "datasource_type": "query", + "query": "SELECT date_trunc('hour', ir.started_at) AS bucket, COUNT(*) FILTER (WHERE ir.status = 'completed') AS completed, COUNT(*) FILTER (WHERE ir.status = 'failed') AS failed, ROUND(COUNT(*) FILTER (WHERE ir.status = 'completed')::numeric / NULLIF(COUNT(*), 0), 3) AS success_rate FROM ingestion_runs ir WHERE ir.started_at >= NOW() - INTERVAL '24 hours' GROUP BY 1 ORDER BY 1" + }, + { + "slice_name": "Items Fetched Over Time", + "viz_type": "echarts_timeseries_bar", + "description": "Total items fetched and new items per hour", + "datasource_type": "query", + "query": "SELECT date_trunc('hour', ir.started_at) AS bucket, COALESCE(SUM(ir.items_fetched), 0) AS items_fetched, COALESCE(SUM(ir.items_new), 0) AS items_new FROM ingestion_runs ir WHERE ir.started_at >= NOW() - INTERVAL '24 
hours' GROUP BY 1 ORDER BY 1" + }, + { + "slice_name": "Stale Sources", + "viz_type": "table", + "description": "Sources with no successful run in the last 24 hours", + "datasource_type": "query", + "query": "SELECT c.ticker, s.source_type, s.source_name, MAX(ir.started_at) FILTER (WHERE ir.status = 'completed') AS last_success, COUNT(*) FILTER (WHERE ir.status = 'failed' AND ir.started_at >= NOW() - INTERVAL '24 hours') AS recent_failures FROM sources s JOIN companies c ON c.id = s.company_id LEFT JOIN ingestion_runs ir ON ir.source_id = s.id WHERE s.active = TRUE AND c.active = TRUE GROUP BY c.ticker, s.source_type, s.source_name HAVING MAX(ir.started_at) FILTER (WHERE ir.status = 'completed') < NOW() - INTERVAL '24 hours' OR MAX(ir.started_at) FILTER (WHERE ir.status = 'completed') IS NULL ORDER BY c.ticker" + }, + { + "slice_name": "Active Companies Ingested", + "viz_type": "big_number_total", + "description": "Count of distinct companies with ingestion activity in the last 24h", + "datasource_type": "query", + "query": "SELECT COUNT(DISTINCT company_id) AS active_companies FROM ingestion_runs WHERE started_at >= NOW() - INTERVAL '24 hours'" + } + ] +} diff --git a/dashboards/operational/model_quality.json b/dashboards/operational/model_quality.json new file mode 100644 index 0000000..4c872be --- /dev/null +++ b/dashboards/operational/model_quality.json @@ -0,0 +1,94 @@ +{ + "dashboard_title": "Model Extraction Quality", + "description": "Operational dashboard for monitoring Ollama extraction success rates, latency, validation failures, and confidence distributions.", + "slug": "model-extraction-quality", + "position_json": { + "HEADER_ID": {"id": "HEADER_ID", "type": "HEADER", "meta": {"text": "Model Extraction Quality"}}, + "ROW-1": { + "type": "ROW", + "children": ["CHART-success-rate-kpi", "CHART-avg-latency-kpi", "CHART-avg-confidence-kpi", "CHART-retry-rate-kpi"] + }, + "ROW-2": { + "type": "ROW", + "children": ["CHART-extraction-timeseries", 
"CHART-validation-status-pie"] + }, + "ROW-3": { + "type": "ROW", + "children": ["CHART-latency-percentiles", "CHART-confidence-distribution"] + }, + "ROW-4": { + "type": "ROW", + "children": ["CHART-recent-failures-table"] + } + }, + "metadata": { + "refresh_frequency": 300, + "default_filters": "{}", + "color_scheme": "supersetColors" + }, + "charts": [ + { + "slice_name": "Extraction Success Rate", + "viz_type": "big_number_total", + "description": "Overall extraction success rate in the last 24h", + "datasource_type": "query", + "query": "SELECT ROUND(COUNT(*) FILTER (WHERE success)::numeric / NULLIF(COUNT(*), 0), 4) AS success_rate FROM model_performance_metrics WHERE recorded_at >= NOW() - INTERVAL '24 hours'" + }, + { + "slice_name": "Avg Extraction Latency", + "viz_type": "big_number_total", + "description": "Average extraction duration in milliseconds", + "datasource_type": "query", + "query": "SELECT ROUND(AVG(total_duration_ms)::numeric, 0) AS avg_latency_ms FROM model_performance_metrics WHERE recorded_at >= NOW() - INTERVAL '24 hours'" + }, + { + "slice_name": "Avg Confidence Score", + "viz_type": "big_number_total", + "description": "Average confidence of successful extractions", + "datasource_type": "query", + "query": "SELECT ROUND(AVG(confidence)::numeric, 3) AS avg_confidence FROM model_performance_metrics WHERE recorded_at >= NOW() - INTERVAL '24 hours' AND success = TRUE" + }, + { + "slice_name": "Avg Retry Count", + "viz_type": "big_number_total", + "description": "Average retries per extraction attempt", + "datasource_type": "query", + "query": "SELECT ROUND(AVG(retry_count)::numeric, 2) AS avg_retries FROM model_performance_metrics WHERE recorded_at >= NOW() - INTERVAL '24 hours'" + }, + { + "slice_name": "Extractions Over Time", + "viz_type": "echarts_timeseries_bar", + "description": "Hourly extraction counts split by success/failure", + "datasource_type": "query", + "query": "SELECT date_trunc('hour', recorded_at) AS bucket, COUNT(*) 
FILTER (WHERE success) AS successful, COUNT(*) FILTER (WHERE NOT success) AS failed FROM model_performance_metrics WHERE recorded_at >= NOW() - INTERVAL '24 hours' GROUP BY 1 ORDER BY 1" + }, + { + "slice_name": "Validation Status Distribution", + "viz_type": "pie", + "description": "Breakdown of extraction validation outcomes", + "datasource_type": "query", + "query": "SELECT validation_status, COUNT(*) AS count FROM model_performance_metrics WHERE recorded_at >= NOW() - INTERVAL '24 hours' GROUP BY validation_status" + }, + { + "slice_name": "Latency Percentiles Over Time", + "viz_type": "echarts_timeseries_line", + "description": "P50, P95, P99 extraction latency per hour", + "datasource_type": "query", + "query": "SELECT date_trunc('hour', recorded_at) AS bucket, ROUND(PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 0) AS p50_ms, ROUND(PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 0) AS p95_ms, ROUND(PERCENTILE_CONT(0.99) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 0) AS p99_ms FROM model_performance_metrics WHERE recorded_at >= NOW() - INTERVAL '24 hours' GROUP BY 1 ORDER BY 1" + }, + { + "slice_name": "Confidence Distribution", + "viz_type": "histogram", + "description": "Distribution of extraction confidence scores", + "datasource_type": "query", + "query": "SELECT CASE WHEN confidence >= 0.9 THEN '0.9-1.0' WHEN confidence >= 0.8 THEN '0.8-0.9' WHEN confidence >= 0.7 THEN '0.7-0.8' WHEN confidence >= 0.6 THEN '0.6-0.7' WHEN confidence >= 0.5 THEN '0.5-0.6' ELSE '<0.5' END AS confidence_bucket, COUNT(*) AS count FROM model_performance_metrics WHERE recorded_at >= NOW() - INTERVAL '24 hours' AND success = TRUE GROUP BY 1 ORDER BY 1" + }, + { + "slice_name": "Recent Extraction Failures", + "viz_type": "table", + "description": "Most recent failed extractions with error details", + "datasource_type": "query", + "query": "SELECT mpm.ticker, mpm.model_name, mpm.validation_status, 
mpm.validation_error_count, mpm.attempt_count, mpm.total_duration_ms, mpm.recorded_at, d.title, d.document_type FROM model_performance_metrics mpm LEFT JOIN documents d ON d.id = mpm.document_id WHERE mpm.success = FALSE AND mpm.recorded_at >= NOW() - INTERVAL '24 hours' ORDER BY mpm.recorded_at DESC LIMIT 50" + } + ] +} diff --git a/dashboards/operational/source_coverage.json b/dashboards/operational/source_coverage.json new file mode 100644 index 0000000..b7d8946 --- /dev/null +++ b/dashboards/operational/source_coverage.json @@ -0,0 +1,51 @@ +{ + "dashboard_title": "Source Coverage & Gaps", + "description": "Operational dashboard for identifying source coverage gaps, stale sources, and symbols missing expected data feeds.", + "slug": "source-coverage-gaps", + "position_json": { + "HEADER_ID": {"id": "HEADER_ID", "type": "HEADER", "meta": {"text": "Source Coverage & Gaps"}}, + "ROW-1": { + "type": "ROW", + "children": ["CHART-coverage-matrix", "CHART-missing-types-table"] + }, + "ROW-2": { + "type": "ROW", + "children": ["CHART-stale-sources-table", "CHART-failure-heatmap"] + } + }, + "metadata": { + "refresh_frequency": 600, + "default_filters": "{}", + "color_scheme": "supersetColors" + }, + "charts": [ + { + "slice_name": "Source Coverage Matrix", + "viz_type": "table", + "description": "Per-symbol source type coverage showing active source counts", + "datasource_type": "query", + "query": "SELECT c.ticker, c.legal_name, c.sector, COUNT(s.id) FILTER (WHERE s.active) AS active_sources, COUNT(s.id) FILTER (WHERE s.source_type = 'market_api' AND s.active) AS market_sources, COUNT(s.id) FILTER (WHERE s.source_type = 'news_api' AND s.active) AS news_sources, COUNT(s.id) FILTER (WHERE s.source_type = 'filings_api' AND s.active) AS filings_sources, COUNT(s.id) FILTER (WHERE s.source_type = 'web_scrape' AND s.active) AS web_scrape_sources, COUNT(s.id) FILTER (WHERE s.source_type = 'broker' AND s.active) AS broker_sources FROM companies c LEFT JOIN sources s ON 
s.company_id = c.id WHERE c.active = TRUE GROUP BY c.ticker, c.legal_name, c.sector ORDER BY c.ticker" + }, + { + "slice_name": "Symbols Missing Source Types", + "viz_type": "table", + "description": "Companies that lack one or more expected source types (market_api, news_api, filings_api)", + "datasource_type": "query", + "query": "SELECT c.ticker, c.legal_name, c.sector, ARRAY_AGG(DISTINCT s.source_type) FILTER (WHERE s.active) AS active_types FROM companies c LEFT JOIN sources s ON s.company_id = c.id AND s.active = TRUE WHERE c.active = TRUE GROUP BY c.ticker, c.legal_name, c.sector HAVING NOT ARRAY['market_api', 'news_api', 'filings_api'] <@ ARRAY_AGG(DISTINCT s.source_type) FILTER (WHERE s.active) OR ARRAY_AGG(DISTINCT s.source_type) FILTER (WHERE s.active) IS NULL ORDER BY c.ticker" + }, + { + "slice_name": "Stale Sources (No Success in 24h)", + "viz_type": "table", + "description": "Active sources that have not completed a successful ingestion run in the last 24 hours", + "datasource_type": "query", + "query": "SELECT c.ticker, s.source_type, s.source_name, MAX(ir.started_at) FILTER (WHERE ir.status = 'completed') AS last_success, MAX(ir.started_at) AS last_attempt, COUNT(*) FILTER (WHERE ir.status = 'failed' AND ir.started_at >= NOW() - INTERVAL '24 hours') AS recent_failures FROM sources s JOIN companies c ON c.id = s.company_id LEFT JOIN ingestion_runs ir ON ir.source_id = s.id WHERE s.active = TRUE AND c.active = TRUE GROUP BY c.ticker, s.source_type, s.source_name HAVING MAX(ir.started_at) FILTER (WHERE ir.status = 'completed') < NOW() - INTERVAL '24 hours' OR MAX(ir.started_at) FILTER (WHERE ir.status = 'completed') IS NULL ORDER BY c.ticker, s.source_type" + }, + { + "slice_name": "Source Failure Heatmap", + "viz_type": "heatmap", + "description": "Failure counts by source type and ticker in the last 24h", + "datasource_type": "query", + "query": "SELECT c.ticker, ir.source_type, COUNT(*) FILTER (WHERE ir.status = 'failed') AS failures FROM 
ingestion_runs ir JOIN companies c ON c.id = ir.company_id WHERE ir.started_at >= NOW() - INTERVAL '24 hours' GROUP BY c.ticker, ir.source_type HAVING COUNT(*) FILTER (WHERE ir.status = 'failed') > 0 ORDER BY failures DESC" + } + ] +} diff --git a/dashboards/starter/README.md b/dashboards/starter/README.md new file mode 100644 index 0000000..ef28757 --- /dev/null +++ b/dashboards/starter/README.md @@ -0,0 +1,29 @@ +# Starter Dashboards + +Superset dashboard definitions for Stonks Oracle research, analysis, and trading review. + +## Dashboards +- Symbol Overview — company profiles, source health, recent documents, and market snapshots +- Sentiment Heatmap — market-wide sentiment by sector and symbol, catalyst analysis, contradiction tracking +- Prediction Accuracy — predicted signals vs realized price moves, confidence calibration, per-symbol accuracy +- Paper Trading PnL — cumulative PnL, daily performance, position snapshots, order history, and scorecards + +## Data Sources +These dashboards query the Trino `lakehouse` catalog over MinIO-backed analytical fact tables: +- `lakehouse.stonks.documents` — ingested document metadata +- `lakehouse.stonks.document_extractions` — AI extraction outputs +- `lakehouse.stonks.trade_signals` — aggregated trend signals +- `lakehouse.stonks.market_bars` — OHLCV bar data +- `lakehouse.stonks.prediction_vs_outcome` — prediction accuracy tracking +- `lakehouse.stonks.pnl_daily` — daily PnL records +- `lakehouse.stonks.positions_daily` — end-of-day position snapshots +- `lakehouse.stonks.trade_orders` — order submission records +- `lakehouse.stonks.trade_fills` — fill and execution records + +## Setup +1. Import the dashboard JSON files into Superset via the Superset UI or CLI +2. Ensure the Trino datasource is configured: `trino://trino@trino:8080/lakehouse/stonks` +3. 
Create the lakehouse views from `lakehouse/views/` for additional drill-down capability + +## Trino Connection +The dashboards use the default Superset Trino connection configured in `infra/superset/superset_config.py`. diff --git a/dashboards/starter/paper_trading_pnl.json b/dashboards/starter/paper_trading_pnl.json new file mode 100644 index 0000000..a71346a --- /dev/null +++ b/dashboards/starter/paper_trading_pnl.json @@ -0,0 +1,124 @@ +{ + "dashboard_title": "Paper Trading PnL", + "description": "Paper trading performance tracking with PnL curves, position snapshots, order history, and trade detail drill-down.", + "slug": "paper-trading-pnl", + "position_json": { + "HEADER_ID": {"id": "HEADER_ID", "type": "HEADER", "meta": {"text": "Paper Trading PnL"}}, + "ROW-1": { + "type": "ROW", + "children": ["CHART-total-net-pnl-kpi", "CHART-win-rate-kpi", "CHART-total-orders-kpi", "CHART-active-positions-kpi"] + }, + "ROW-2": { + "type": "ROW", + "children": ["CHART-cumulative-pnl-timeseries", "CHART-daily-pnl-bar"] + }, + "ROW-3": { + "type": "ROW", + "children": ["CHART-pnl-by-symbol", "CHART-order-status-pie"] + }, + "ROW-4": { + "type": "ROW", + "children": ["CHART-positions-table"] + }, + "ROW-5": { + "type": "ROW", + "children": ["CHART-scorecard-table"] + }, + "ROW-6": { + "type": "ROW", + "children": ["CHART-recent-orders-table"] + } + }, + "metadata": { + "refresh_frequency": 300, + "default_filters": "{}", + "color_scheme": "supersetColors" + }, + "charts": [ + { + "slice_name": "Total Net PnL", + "viz_type": "big_number_total", + "description": "Cumulative net PnL across all paper trading activity", + "datasource_type": "trino", + "query": "SELECT ROUND(SUM(net_pnl), 2) AS total_net_pnl FROM lakehouse.stonks.pnl_daily WHERE execution_mode = 'paper'" + }, + { + "slice_name": "Win Rate", + "viz_type": "big_number_total", + "description": "Fraction of pnl_daily rows (daily PnL records, not distinct calendar days) with positive net PnL", + "datasource_type": "trino", + "query": "SELECT ROUND(CAST(COUNT(CASE WHEN 
net_pnl > 0 THEN 1 END) AS DOUBLE) / NULLIF(COUNT(*), 0), 4) AS win_rate FROM lakehouse.stonks.pnl_daily WHERE execution_mode = 'paper'" + }, + { + "slice_name": "Total Orders", + "viz_type": "big_number_total", + "description": "Total paper trade orders submitted", + "datasource_type": "trino", + "query": "SELECT COUNT(DISTINCT order_id) AS total_orders FROM lakehouse.stonks.trade_orders WHERE execution_mode = 'paper'" + }, + { + "slice_name": "Active Positions", + "viz_type": "big_number_total", + "description": "Number of symbols with open positions as of the latest snapshot", + "datasource_type": "trino", + "query": "SELECT COUNT(DISTINCT ticker) AS active_positions FROM lakehouse.stonks.positions_daily WHERE execution_mode = 'paper' AND quantity <> 0 AND dt = (SELECT MAX(dt) FROM lakehouse.stonks.positions_daily WHERE execution_mode = 'paper')" + }, + { + "slice_name": "Cumulative PnL Over Time", + "viz_type": "echarts_timeseries_line", + "description": "Running cumulative net PnL across all paper trades", + "datasource_type": "trino", + "query": "SELECT dt AS bucket, SUM(net_pnl) AS daily_net_pnl, SUM(SUM(net_pnl)) OVER (ORDER BY dt) AS cumulative_pnl FROM lakehouse.stonks.pnl_daily WHERE execution_mode = 'paper' GROUP BY dt ORDER BY dt" + }, + { + "slice_name": "Daily PnL", + "viz_type": "echarts_timeseries_bar", + "description": "Daily net PnL for paper trading, colored by positive/negative", + "datasource_type": "trino", + "query": "SELECT dt AS bucket, ROUND(SUM(net_pnl), 2) AS daily_pnl, ROUND(SUM(realized_pnl), 2) AS realized, ROUND(SUM(unrealized_pnl), 2) AS unrealized FROM lakehouse.stonks.pnl_daily WHERE execution_mode = 'paper' GROUP BY dt ORDER BY dt", + "params": { + "x_axis": "bucket", + "metrics": ["daily_pnl"] + } + }, + { + "slice_name": "PnL by Symbol", + "viz_type": "echarts_timeseries_bar", + "description": "Total net PnL per symbol for paper trading", + "datasource_type": "trino", + "query": "SELECT ticker, ROUND(SUM(net_pnl), 2) AS 
total_pnl, ROUND(SUM(realized_pnl), 2) AS realized_pnl, ROUND(SUM(fees), 2) AS total_fees FROM lakehouse.stonks.pnl_daily WHERE execution_mode = 'paper' GROUP BY ticker ORDER BY total_pnl DESC", + "params": { + "x_axis": "ticker", + "metrics": ["total_pnl"] + } + }, + { + "slice_name": "Order Status Distribution", + "viz_type": "pie", + "description": "Breakdown of paper trade order statuses", + "datasource_type": "trino", + "query": "SELECT status, COUNT(*) AS count FROM lakehouse.stonks.trade_orders WHERE execution_mode = 'paper' GROUP BY status ORDER BY count DESC" + }, + { + "slice_name": "Current Positions", + "viz_type": "table", + "description": "Latest position snapshot for all paper trading symbols", + "datasource_type": "trino", + "query": "SELECT p.ticker, p.quantity, ROUND(p.avg_entry_price, 2) AS avg_entry, ROUND(p.close_price, 2) AS close_price, ROUND(p.market_value, 2) AS market_value, ROUND(p.unrealized_pnl, 2) AS unrealized_pnl, p.snapshot_at FROM lakehouse.stonks.positions_daily p WHERE p.execution_mode = 'paper' AND p.dt = (SELECT MAX(dt) FROM lakehouse.stonks.positions_daily WHERE execution_mode = 'paper') ORDER BY ABS(p.unrealized_pnl) DESC" + }, + { + "slice_name": "Paper Trade Scorecard", + "viz_type": "table", + "description": "Per-symbol paper trading scorecard with win rates, PnL, and order counts", + "datasource_type": "trino", + "query": "SELECT pnl.ticker, COUNT(DISTINCT pnl.dt) AS trading_days, ROUND(SUM(pnl.net_pnl), 2) AS total_net_pnl, ROUND(AVG(pnl.net_pnl), 2) AS avg_daily_pnl, ROUND(CAST(COUNT(CASE WHEN pnl.net_pnl > 0 THEN 1 END) AS DOUBLE) / NULLIF(COUNT(*), 0), 4) AS win_rate, ROUND(MIN(pnl.net_pnl), 2) AS worst_day, ROUND(MAX(pnl.net_pnl), 2) AS best_day, ROUND(SUM(pnl.fees), 2) AS total_fees, MIN(pnl.dt) AS first_trade, MAX(pnl.dt) AS last_trade FROM lakehouse.stonks.pnl_daily pnl WHERE pnl.execution_mode = 'paper' GROUP BY pnl.ticker ORDER BY total_net_pnl DESC" + }, + { + "slice_name": "Recent Orders", + "viz_type": 
"table", + "description": "Most recent paper trade orders with fill details", + "datasource_type": "trino", + "query": "SELECT o.ticker, o.side, o.order_type, o.quantity, ROUND(o.limit_price, 2) AS limit_price, o.status, f.fill_price, f.fill_quantity, f.commission, o.submitted_at, f.filled_at FROM lakehouse.stonks.trade_orders o LEFT JOIN lakehouse.stonks.trade_fills f ON o.order_id = f.order_id AND o.dt = f.dt WHERE o.execution_mode = 'paper' ORDER BY o.submitted_at DESC LIMIT 50" + } + ] +} diff --git a/dashboards/starter/prediction_accuracy.json b/dashboards/starter/prediction_accuracy.json new file mode 100644 index 0000000..b8a9426 --- /dev/null +++ b/dashboards/starter/prediction_accuracy.json @@ -0,0 +1,125 @@ +{ + "dashboard_title": "Prediction Accuracy", + "description": "Predicted signals vs realized price moves, confidence calibration, and model accuracy tracking.", + "slug": "prediction-accuracy", + "position_json": { + "HEADER_ID": {"id": "HEADER_ID", "type": "HEADER", "meta": {"text": "Prediction Accuracy"}}, + "ROW-1": { + "type": "ROW", + "children": ["CHART-overall-hit-rate-kpi", "CHART-total-predictions-kpi", "CHART-avg-confidence-kpi", "CHART-avg-move-kpi"] + }, + "ROW-2": { + "type": "ROW", + "children": ["CHART-hit-rate-timeseries", "CHART-outcome-distribution-pie"] + }, + "ROW-3": { + "type": "ROW", + "children": ["CHART-confidence-calibration", "CHART-confidence-vs-move-scatter"] + }, + "ROW-4": { + "type": "ROW", + "children": ["CHART-accuracy-by-symbol", "CHART-accuracy-by-action"] + }, + "ROW-5": { + "type": "ROW", + "children": ["CHART-recent-predictions-table"] + } + }, + "metadata": { + "refresh_frequency": 600, + "default_filters": "{}", + "color_scheme": "supersetColors" + }, + "charts": [ + { + "slice_name": "Overall Hit Rate", + "viz_type": "big_number_total", + "description": "Fraction of predictions with correct directional outcome over the last 30 days", + "datasource_type": "trino", + "query": "SELECT ROUND(CAST(COUNT(CASE WHEN 
outcome = 'correct' THEN 1 END) AS DOUBLE) / NULLIF(COUNT(*), 0), 4) AS hit_rate FROM lakehouse.stonks.prediction_vs_outcome WHERE dt >= CURRENT_DATE - INTERVAL '30' DAY" + }, + { + "slice_name": "Total Predictions (30d)", + "viz_type": "big_number_total", + "description": "Total evaluated predictions in the last 30 days", + "datasource_type": "trino", + "query": "SELECT COUNT(*) AS total_predictions FROM lakehouse.stonks.prediction_vs_outcome WHERE dt >= CURRENT_DATE - INTERVAL '30' DAY" + }, + { + "slice_name": "Avg Predicted Confidence", + "viz_type": "big_number_total", + "description": "Average confidence of predictions in the last 30 days", + "datasource_type": "trino", + "query": "SELECT ROUND(AVG(predicted_confidence), 3) AS avg_confidence FROM lakehouse.stonks.prediction_vs_outcome WHERE dt >= CURRENT_DATE - INTERVAL '30' DAY" + }, + { + "slice_name": "Avg Realized Move", + "viz_type": "big_number_total", + "description": "Average absolute realized price move percentage", + "datasource_type": "trino", + "query": "SELECT ROUND(AVG(ABS(actual_move_pct)), 3) AS avg_abs_move FROM lakehouse.stonks.prediction_vs_outcome WHERE dt >= CURRENT_DATE - INTERVAL '30' DAY" + }, + { + "slice_name": "Daily Hit Rate", + "viz_type": "echarts_timeseries_line", + "description": "Daily prediction hit rate over the last 30 days", + "datasource_type": "trino", + "query": "SELECT dt AS bucket, COUNT(*) AS total, COUNT(CASE WHEN outcome = 'correct' THEN 1 END) AS correct, ROUND(CAST(COUNT(CASE WHEN outcome = 'correct' THEN 1 END) AS DOUBLE) / NULLIF(COUNT(*), 0), 4) AS hit_rate FROM lakehouse.stonks.prediction_vs_outcome WHERE dt >= CURRENT_DATE - INTERVAL '30' DAY GROUP BY dt ORDER BY dt" + }, + { + "slice_name": "Outcome Distribution", + "viz_type": "pie", + "description": "Breakdown of prediction outcomes (correct, incorrect, neutral) over the last 30 days", + "datasource_type": "trino", + "query": "SELECT outcome, COUNT(*) AS count FROM lakehouse.stonks.prediction_vs_outcome 
WHERE dt >= CURRENT_DATE - INTERVAL '30' DAY GROUP BY outcome ORDER BY count DESC" + }, + { + "slice_name": "Confidence Calibration", + "viz_type": "echarts_timeseries_bar", + "description": "Hit rate by confidence bucket to assess calibration quality", + "datasource_type": "trino", + "query": "SELECT CASE WHEN predicted_confidence >= 0.8 THEN '0.8-1.0 (high)' WHEN predicted_confidence >= 0.6 THEN '0.6-0.8 (medium)' WHEN predicted_confidence >= 0.4 THEN '0.4-0.6 (low)' ELSE '0.0-0.4 (very low)' END AS confidence_bucket, COUNT(*) AS total, COUNT(CASE WHEN outcome = 'correct' THEN 1 END) AS correct, ROUND(CAST(COUNT(CASE WHEN outcome = 'correct' THEN 1 END) AS DOUBLE) / NULLIF(COUNT(*), 0), 4) AS hit_rate FROM lakehouse.stonks.prediction_vs_outcome WHERE dt >= CURRENT_DATE - INTERVAL '30' DAY GROUP BY 1 ORDER BY 1", + "params": { + "x_axis": "confidence_bucket", + "metrics": ["hit_rate"] + } + }, + { + "slice_name": "Confidence vs Realized Move", + "viz_type": "echarts_timeseries_scatter", + "description": "Scatter plot of predicted confidence vs actual realized move percentage", + "datasource_type": "trino", + "query": "SELECT ticker, predicted_confidence, actual_move_pct, predicted_action, outcome, dt FROM lakehouse.stonks.prediction_vs_outcome WHERE dt >= CURRENT_DATE - INTERVAL '30' DAY ORDER BY dt DESC", + "params": { + "x_axis": "predicted_confidence", + "y_axis": "actual_move_pct", + "groupby": ["outcome"] + } + }, + { + "slice_name": "Accuracy by Symbol", + "viz_type": "table", + "description": "Per-symbol prediction accuracy summary", + "datasource_type": "trino", + "query": "SELECT ticker, COUNT(*) AS predictions, COUNT(CASE WHEN outcome = 'correct' THEN 1 END) AS correct, COUNT(CASE WHEN outcome = 'incorrect' THEN 1 END) AS incorrect, ROUND(CAST(COUNT(CASE WHEN outcome = 'correct' THEN 1 END) AS DOUBLE) / NULLIF(COUNT(*), 0), 4) AS hit_rate, ROUND(AVG(predicted_confidence), 3) AS avg_confidence, ROUND(AVG(actual_move_pct), 3) AS avg_move_pct, 
ROUND(AVG(ABS(actual_move_pct)), 3) AS avg_abs_move_pct FROM lakehouse.stonks.prediction_vs_outcome WHERE dt >= CURRENT_DATE - INTERVAL '30' DAY GROUP BY ticker ORDER BY hit_rate DESC" + }, + { + "slice_name": "Accuracy by Action Type", + "viz_type": "echarts_timeseries_bar", + "description": "Hit rate broken down by predicted action (buy, sell, hold, watch)", + "datasource_type": "trino", + "query": "SELECT predicted_action, COUNT(*) AS total, COUNT(CASE WHEN outcome = 'correct' THEN 1 END) AS correct, ROUND(CAST(COUNT(CASE WHEN outcome = 'correct' THEN 1 END) AS DOUBLE) / NULLIF(COUNT(*), 0), 4) AS hit_rate, ROUND(AVG(predicted_confidence), 3) AS avg_confidence FROM lakehouse.stonks.prediction_vs_outcome WHERE dt >= CURRENT_DATE - INTERVAL '30' DAY GROUP BY predicted_action ORDER BY predicted_action", + "params": { + "x_axis": "predicted_action", + "metrics": ["hit_rate"] + } + }, + { + "slice_name": "Recent Predictions", + "viz_type": "table", + "description": "Most recent evaluated predictions with outcomes", + "datasource_type": "trino", + "query": "SELECT ticker, predicted_action, ROUND(predicted_confidence, 3) AS confidence, ROUND(actual_move_pct, 3) AS actual_move_pct, outcome, horizon_days, model_version, predicted_at, evaluated_at FROM lakehouse.stonks.prediction_vs_outcome WHERE dt >= CURRENT_DATE - INTERVAL '14' DAY ORDER BY evaluated_at DESC LIMIT 50" + } + ] +} diff --git a/dashboards/starter/sentiment_heatmap.json b/dashboards/starter/sentiment_heatmap.json new file mode 100644 index 0000000..7b01b59 --- /dev/null +++ b/dashboards/starter/sentiment_heatmap.json @@ -0,0 +1,120 @@ +{ + "dashboard_title": "Sentiment Heatmap", + "description": "Market-wide sentiment visualization by sector and symbol, with trend direction and catalyst analysis.", + "slug": "sentiment-heatmap", + "position_json": { + "HEADER_ID": {"id": "HEADER_ID", "type": "HEADER", "meta": {"text": "Sentiment Heatmap"}}, + "ROW-1": { + "type": "ROW", + "children": 
["CHART-bullish-count-kpi", "CHART-bearish-count-kpi", "CHART-mixed-count-kpi", "CHART-avg-contradiction-kpi"] + }, + "ROW-2": { + "type": "ROW", + "children": ["CHART-sentiment-heatmap"] + }, + "ROW-3": { + "type": "ROW", + "children": ["CHART-sentiment-timeseries", "CHART-catalyst-breakdown"] + }, + "ROW-4": { + "type": "ROW", + "children": ["CHART-contradiction-scatter", "CHART-sentiment-distribution"] + }, + "ROW-5": { + "type": "ROW", + "children": ["CHART-symbol-sentiment-detail"] + } + }, + "metadata": { + "refresh_frequency": 300, + "default_filters": "{}", + "color_scheme": "supersetColors" + }, + "charts": [ + { + "slice_name": "Bullish Signals (7d)", + "viz_type": "big_number_total", + "description": "Count of bullish trend signals in the last 7 days", + "datasource_type": "trino", + "query": "SELECT COUNT(*) AS bullish_count FROM lakehouse.stonks.trade_signals WHERE trend_direction = 'bullish' AND dt >= CURRENT_DATE - INTERVAL '7' DAY" + }, + { + "slice_name": "Bearish Signals (7d)", + "viz_type": "big_number_total", + "description": "Count of bearish trend signals in the last 7 days", + "datasource_type": "trino", + "query": "SELECT COUNT(*) AS bearish_count FROM lakehouse.stonks.trade_signals WHERE trend_direction = 'bearish' AND dt >= CURRENT_DATE - INTERVAL '7' DAY" + }, + { + "slice_name": "Mixed Signals (7d)", + "viz_type": "big_number_total", + "description": "Count of mixed or neutral trend signals in the last 7 days", + "datasource_type": "trino", + "query": "SELECT COUNT(*) AS mixed_count FROM lakehouse.stonks.trade_signals WHERE trend_direction IN ('mixed', 'neutral') AND dt >= CURRENT_DATE - INTERVAL '7' DAY" + }, + { + "slice_name": "Avg Contradiction Score (7d)", + "viz_type": "big_number_total", + "description": "Average contradiction score across all signals in the last 7 days", + "datasource_type": "trino", + "query": "SELECT ROUND(AVG(contradiction_score), 3) AS avg_contradiction FROM lakehouse.stonks.trade_signals WHERE dt >= 
CURRENT_DATE - INTERVAL '7' DAY" + }, + { + "slice_name": "Sentiment Heatmap by Symbol", + "viz_type": "heatmap", + "description": "Daily average sentiment score (positive = 1, negative = -1, otherwise 0) by symbol over the last 14 days", + "datasource_type": "trino", + "query": "SELECT de.ticker, de.dt, ROUND(AVG(de.impact_score), 3) AS avg_impact, AVG(CASE WHEN de.sentiment = 'positive' THEN 1.0 WHEN de.sentiment = 'negative' THEN -1.0 ELSE 0.0 END) AS sentiment_score FROM lakehouse.stonks.document_extractions de WHERE de.dt >= CURRENT_DATE - INTERVAL '14' DAY GROUP BY de.ticker, de.dt ORDER BY de.ticker, de.dt", + "params": { + "x_axis": "dt", + "y_axis": "ticker", + "metric": "sentiment_score" + } + }, + { + "slice_name": "Sentiment Trend Over Time", + "viz_type": "echarts_timeseries_line", + "description": "Daily average sentiment score across all symbols over the last 30 days", + "datasource_type": "trino", + "query": "SELECT de.dt AS bucket, ROUND(AVG(CASE WHEN de.sentiment = 'positive' THEN 1.0 WHEN de.sentiment = 'negative' THEN -1.0 ELSE 0.0 END), 3) AS avg_sentiment, COUNT(*) AS extraction_count FROM lakehouse.stonks.document_extractions de WHERE de.dt >= CURRENT_DATE - INTERVAL '30' DAY GROUP BY de.dt ORDER BY de.dt" + }, + { + "slice_name": "Catalyst Type Breakdown", + "viz_type": "pie", + "description": "Distribution of catalyst types across extractions in the last 14 days", + "datasource_type": "trino", + "query": "SELECT catalyst_type, COUNT(*) AS count FROM lakehouse.stonks.document_extractions WHERE dt >= CURRENT_DATE - INTERVAL '14' DAY AND catalyst_type IS NOT NULL GROUP BY catalyst_type ORDER BY count DESC" + }, + { + "slice_name": "Contradiction vs Confidence", + "viz_type": "echarts_timeseries_scatter", + "description": "Scatter of contradiction score vs confidence for recent signals", + "datasource_type": "trino", + "query": "SELECT ticker, confidence, contradiction_score, trend_strength, trend_direction, dt FROM lakehouse.stonks.trade_signals WHERE dt >= CURRENT_DATE - 
INTERVAL '14' DAY ORDER BY dt DESC", + "params": { + "x_axis": "confidence", + "y_axis": "contradiction_score", + "groupby": ["trend_direction"] + } + }, + { + "slice_name": "Sentiment Distribution by Symbol", + "viz_type": "echarts_timeseries_bar", + "description": "Count of positive, negative, and neutral extractions per symbol in the last 14 days", + "datasource_type": "trino", + "query": "SELECT ticker, sentiment, COUNT(*) AS count FROM lakehouse.stonks.document_extractions WHERE dt >= CURRENT_DATE - INTERVAL '14' DAY GROUP BY ticker, sentiment ORDER BY ticker, sentiment", + "params": { + "x_axis": "ticker", + "metrics": ["count"], + "groupby": ["sentiment"] + } + }, + { + "slice_name": "Symbol Sentiment Detail", + "viz_type": "table", + "description": "Per-symbol sentiment summary for the last 14 days with extraction counts, average impact, confidence, and novelty, per-sentiment counts, and the latest trend direction and strength", + "datasource_type": "trino", + "query": "SELECT de.ticker, COUNT(*) AS extractions, ROUND(AVG(de.impact_score), 3) AS avg_impact, ROUND(AVG(de.confidence), 3) AS avg_confidence, ROUND(AVG(de.novelty_score), 3) AS avg_novelty, COUNT(CASE WHEN de.sentiment = 'positive' THEN 1 END) AS positive_count, COUNT(CASE WHEN de.sentiment = 'negative' THEN 1 END) AS negative_count, COUNT(CASE WHEN de.sentiment = 'neutral' THEN 1 END) AS neutral_count, ts.trend_direction AS latest_trend, ts.trend_strength AS latest_trend_strength FROM lakehouse.stonks.document_extractions de LEFT JOIN lakehouse.stonks.trade_signals ts ON de.ticker = ts.ticker AND ts.dt = (SELECT MAX(dt) FROM lakehouse.stonks.trade_signals WHERE ticker = de.ticker) WHERE de.dt >= CURRENT_DATE - INTERVAL '14' DAY GROUP BY de.ticker, ts.trend_direction, ts.trend_strength ORDER BY de.ticker" + } + ] +} diff --git a/dashboards/starter/symbol_overview.json b/dashboards/starter/symbol_overview.json new file mode 100644 index 0000000..f534a8c --- /dev/null +++ b/dashboards/starter/symbol_overview.json @@ -0,0 +1,104 @@ +{ + "dashboard_title": "Symbol Overview", + 
"description": "Company profiles, source health, recent documents, and market snapshot for tracked symbols.", + "slug": "symbol-overview", + "position_json": { + "HEADER_ID": {"id": "HEADER_ID", "type": "HEADER", "meta": {"text": "Symbol Overview"}}, + "ROW-1": { + "type": "ROW", + "children": ["CHART-tracked-symbols-kpi", "CHART-total-documents-kpi", "CHART-total-extractions-kpi", "CHART-active-signals-kpi"] + }, + "ROW-2": { + "type": "ROW", + "children": ["CHART-company-summary-table"] + }, + "ROW-3": { + "type": "ROW", + "children": ["CHART-recent-documents-timeseries", "CHART-document-type-breakdown"] + }, + "ROW-4": { + "type": "ROW", + "children": ["CHART-latest-prices-table"] + }, + "ROW-5": { + "type": "ROW", + "children": ["CHART-recent-documents-table"] + } + }, + "metadata": { + "refresh_frequency": 300, + "default_filters": "{}", + "color_scheme": "supersetColors" + }, + "charts": [ + { + "slice_name": "Tracked Symbols", + "viz_type": "big_number_total", + "description": "Count of distinct symbols with documents in the last 30 days", + "datasource_type": "trino", + "query": "SELECT COUNT(DISTINCT ticker) AS tracked_symbols FROM lakehouse.stonks.documents WHERE dt >= CURRENT_DATE - INTERVAL '30' DAY" + }, + { + "slice_name": "Total Documents (30d)", + "viz_type": "big_number_total", + "description": "Total documents ingested in the last 30 days", + "datasource_type": "trino", + "query": "SELECT COUNT(*) AS total_documents FROM lakehouse.stonks.documents WHERE dt >= CURRENT_DATE - INTERVAL '30' DAY" + }, + { + "slice_name": "Total Extractions (30d)", + "viz_type": "big_number_total", + "description": "Total AI extractions completed in the last 30 days", + "datasource_type": "trino", + "query": "SELECT COUNT(*) AS total_extractions FROM lakehouse.stonks.document_extractions WHERE dt >= CURRENT_DATE - INTERVAL '30' DAY" + }, + { + "slice_name": "Active Signals (7d)", + "viz_type": "big_number_total", + "description": "Trade signals generated in the last 7 
days", + "datasource_type": "trino", + "query": "SELECT COUNT(*) AS active_signals FROM lakehouse.stonks.trade_signals WHERE dt >= CURRENT_DATE - INTERVAL '7' DAY" + }, + { + "slice_name": "Company Summary", + "viz_type": "table", + "description": "Per-symbol summary with document counts, extraction counts, latest signal, and latest price", + "datasource_type": "trino", + "query": "SELECT d.ticker, COUNT(DISTINCT d.document_id) AS documents_30d, COUNT(DISTINCT de.document_id) AS extractions_30d, MAX(d.published_at) AS latest_document_at, MAX(ts.generated_at) AS latest_signal_at, MAX(ts.trend_direction) AS latest_trend, MAX(mb.close_price) AS latest_close FROM lakehouse.stonks.documents d LEFT JOIN lakehouse.stonks.document_extractions de ON d.ticker = de.ticker AND de.dt >= CURRENT_DATE - INTERVAL '30' DAY LEFT JOIN lakehouse.stonks.trade_signals ts ON d.ticker = ts.ticker AND ts.dt = (SELECT MAX(dt) FROM lakehouse.stonks.trade_signals WHERE ticker = d.ticker) LEFT JOIN lakehouse.stonks.market_bars mb ON d.ticker = mb.ticker AND mb.dt = (SELECT MAX(dt) FROM lakehouse.stonks.market_bars WHERE ticker = d.ticker) WHERE d.dt >= CURRENT_DATE - INTERVAL '30' DAY GROUP BY d.ticker ORDER BY d.ticker" + }, + { + "slice_name": "Documents Ingested Over Time", + "viz_type": "echarts_timeseries_bar", + "description": "Daily document ingestion counts by source type over the last 30 days", + "datasource_type": "trino", + "query": "SELECT dt AS bucket, source_type, COUNT(*) AS doc_count FROM lakehouse.stonks.documents WHERE dt >= CURRENT_DATE - INTERVAL '30' DAY GROUP BY dt, source_type ORDER BY dt", + "params": { + "x_axis": "bucket", + "metrics": ["doc_count"], + "groupby": ["source_type"], + "time_grain_sqla": "P1D" + } + }, + { + "slice_name": "Document Type Breakdown", + "viz_type": "pie", + "description": "Distribution of documents by type in the last 30 days", + "datasource_type": "trino", + "query": "SELECT document_type, COUNT(*) AS count FROM lakehouse.stonks.documents 
WHERE dt >= CURRENT_DATE - INTERVAL '30' DAY GROUP BY document_type ORDER BY count DESC" + }, + { + "slice_name": "Latest Prices by Symbol", + "viz_type": "table", + "description": "Most recent closing prices and volume for each tracked symbol", + "datasource_type": "trino", + "query": "SELECT mb.ticker, mb.close_price, mb.open_price, mb.high_price, mb.low_price, mb.volume, mb.vwap, mb.bar_timestamp FROM lakehouse.stonks.market_bars mb INNER JOIN (SELECT ticker, MAX(bar_timestamp) AS max_ts FROM lakehouse.stonks.market_bars GROUP BY ticker) latest ON mb.ticker = latest.ticker AND mb.bar_timestamp = latest.max_ts ORDER BY mb.ticker" + }, + { + "slice_name": "Recent Documents", + "viz_type": "table", + "description": "Most recently ingested documents across all symbols", + "datasource_type": "trino", + "query": "SELECT ticker, document_type, source_type, title, publisher, published_at, retrieved_at, confidence FROM lakehouse.stonks.documents WHERE dt >= CURRENT_DATE - INTERVAL '7' DAY ORDER BY retrieved_at DESC LIMIT 50" + } + ] +} diff --git a/docker-compose.yml b/docker-compose.yml index 20c47ed..31ef001 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -72,6 +72,9 @@ services: image: trinodb/trino:latest ports: - "8080:8080" + environment: + MINIO_ACCESS_KEY: minioadmin + MINIO_SECRET_KEY: minioadmin volumes: - ./infra/trino/catalog:/etc/trino/catalog depends_on: @@ -83,11 +86,14 @@ services: environment: SERVICE_NAME: metastore DB_DRIVER: derby - SERVICE_OPTS: "-Djavax.jdo.option.ConnectionURL=jdbc:derby:/opt/hive/data/metastore_db;create=true" ports: - "9083:9083" volumes: - hive_data:/opt/hive/data + - ./infra/hive/core-site.xml:/opt/hive/conf/core-site.xml:ro + - ./infra/hive/metastore-site.xml:/opt/hive/conf/metastore-site.xml:ro + depends_on: + - minio superset: image: apache/superset:latest diff --git a/infra/hive/core-site.xml b/infra/hive/core-site.xml new file mode 100644 index 0000000..b1ca0f4 --- /dev/null +++ b/infra/hive/core-site.xml @@ 
-0,0 +1,27 @@ + + + + fs.s3a.endpoint + http://minio:9000 + + + fs.s3a.access.key + minioadmin + + + fs.s3a.secret.key + minioadmin + + + fs.s3a.path.style.access + true + + + fs.s3a.impl + org.apache.hadoop.fs.s3a.S3AFileSystem + + + fs.s3a.connection.ssl.enabled + false + + diff --git a/infra/hive/metastore-site.xml b/infra/hive/metastore-site.xml new file mode 100644 index 0000000..59d2aa8 --- /dev/null +++ b/infra/hive/metastore-site.xml @@ -0,0 +1,27 @@ + + + + metastore.thrift.uris + thrift://0.0.0.0:9083 + + + metastore.task.threads.always + org.apache.hadoop.hive.metastore.events.EventCleanerTask + + + metastore.expression.proxy + org.apache.hadoop.hive.metastore.DefaultPartitionExpressionProxy + + + javax.jdo.option.ConnectionDriverName + org.apache.derby.jdbc.EmbeddedDriver + + + javax.jdo.option.ConnectionURL + jdbc:derby:/opt/hive/data/metastore_db;create=true + + + metastore.warehouse.dir + s3a://stonks-lakehouse/warehouse + + diff --git a/infra/k8s/aggregation-worker.yaml b/infra/k8s/aggregation-worker.yaml index 4374298..06a4d58 100644 --- a/infra/k8s/aggregation-worker.yaml +++ b/infra/k8s/aggregation-worker.yaml @@ -6,6 +6,7 @@ metadata: labels: app: aggregation-worker app.kubernetes.io/part-of: stonks-oracle + stonks-oracle/tier: processing spec: replicas: 1 selector: @@ -15,16 +16,30 @@ spec: metadata: labels: app: aggregation-worker + stonks-oracle/tier: processing spec: + automountServiceAccountToken: false + securityContext: + runAsNonRoot: true + runAsUser: 1000 + runAsGroup: 1000 + fsGroup: 1000 + seccompProfile: + type: RuntimeDefault containers: - name: aggregation-worker image: ghcr.io/celesrenata/stonks-oracle/aggregation:latest imagePullPolicy: Always + securityContext: + allowPrivilegeEscalation: false + readOnlyRootFilesystem: true + capabilities: + drop: ["ALL"] envFrom: - configMapRef: name: stonks-config - secretRef: - name: stonks-secrets + name: stonks-core-secrets resources: requests: cpu: 100m @@ -32,3 +47,10 @@ spec: limits: 
cpu: 500m memory: 256Mi + volumeMounts: + - name: tmp + mountPath: /tmp + volumes: + - name: tmp + emptyDir: + sizeLimit: 10Mi diff --git a/infra/k8s/broker-adapter.yaml b/infra/k8s/broker-adapter.yaml index 043dd2e..0c48eb1 100644 --- a/infra/k8s/broker-adapter.yaml +++ b/infra/k8s/broker-adapter.yaml @@ -6,6 +6,7 @@ metadata: labels: app: broker-adapter app.kubernetes.io/part-of: stonks-oracle + stonks-oracle/tier: trading spec: replicas: 1 selector: @@ -15,16 +16,32 @@ spec: metadata: labels: app: broker-adapter + stonks-oracle/tier: trading spec: + automountServiceAccountToken: false + securityContext: + runAsNonRoot: true + runAsUser: 1000 + runAsGroup: 1000 + fsGroup: 1000 + seccompProfile: + type: RuntimeDefault containers: - name: broker-adapter image: ghcr.io/celesrenata/stonks-oracle/broker-adapter:latest imagePullPolicy: Always + securityContext: + allowPrivilegeEscalation: false + readOnlyRootFilesystem: true + capabilities: + drop: ["ALL"] envFrom: - configMapRef: name: stonks-config - secretRef: - name: stonks-secrets + name: stonks-core-secrets + - secretRef: + name: stonks-broker-secrets resources: requests: cpu: 50m @@ -32,3 +49,10 @@ spec: limits: cpu: 200m memory: 128Mi + volumeMounts: + - name: tmp + mountPath: /tmp + volumes: + - name: tmp + emptyDir: + sizeLimit: 10Mi diff --git a/infra/k8s/configmap.yaml b/infra/k8s/configmap.yaml index eeddc7e..7613c90 100644 --- a/infra/k8s/configmap.yaml +++ b/infra/k8s/configmap.yaml @@ -25,15 +25,48 @@ data: OLLAMA_BASE_URL: "http://ollama.ollama-service.svc.cluster.local:11434" OLLAMA_MODEL: "llama3.1:8b" OLLAMA_TIMEOUT: "120" + OLLAMA_MAX_RETRIES: "2" + OLLAMA_RETRY_BASE_DELAY: "1.0" + OLLAMA_RETRY_MAX_DELAY: "10.0" + OLLAMA_RETRY_BACKOFF_MULTIPLIER: "2.0" # Trino — deployed in stonks-oracle namespace TRINO_HOST: "trino.stonks-oracle.svc.cluster.local" TRINO_PORT: "8080" TRINO_CATALOG: "lakehouse" TRINO_SCHEMA: "stonks" + TRINO_ICEBERG_CATALOG: "iceberg" # Broker BROKER_MODE: "paper" + BROKER_PROVIDER: 
"alpaca" + + # Market Data + MARKET_DATA_BASE_URL: "https://api.polygon.io" + MARKET_DATA_PROVIDER: "polygon" + + # Retention (days per bucket class) + RETENTION_RAW_MARKET_DAYS: "90" + RETENTION_RAW_NEWS_DAYS: "180" + RETENTION_RAW_FILINGS_DAYS: "365" + RETENTION_NORMALIZED_DAYS: "180" + RETENTION_LLM_PROMPTS_DAYS: "365" + RETENTION_LLM_RESULTS_DAYS: "365" + RETENTION_LAKEHOUSE_DAYS: "730" + RETENTION_AUDIT_DAYS: "730" + RETENTION_CLEANUP_INTERVAL_HOURS: "24" + RETENTION_BATCH_SIZE: "1000" # General LOG_LEVEL: "INFO" + JSON_LOGS: "true" + + # Alerting thresholds + ALERT_SOURCE_FAILURE_THRESHOLD: "3" + ALERT_SOURCE_FAILURE_WINDOW_HOURS: "6" + ALERT_SCHEMA_FAILURE_RATE_THRESHOLD: "0.3" + ALERT_SCHEMA_FAILURE_WINDOW_HOURS: "1" + ALERT_LAKE_LAG_THRESHOLD_MINUTES: "60" + ALERT_BROKER_ERROR_THRESHOLD: "3" + ALERT_BROKER_ERROR_WINDOW_HOURS: "1" + ALERT_CHECK_INTERVAL_SECONDS: "120" diff --git a/infra/k8s/extractor-worker.yaml b/infra/k8s/extractor-worker.yaml index 2b0a066..09d76f6 100644 --- a/infra/k8s/extractor-worker.yaml +++ b/infra/k8s/extractor-worker.yaml @@ -6,6 +6,7 @@ metadata: labels: app: extractor-worker app.kubernetes.io/part-of: stonks-oracle + stonks-oracle/tier: processing spec: replicas: 1 selector: @@ -15,16 +16,30 @@ spec: metadata: labels: app: extractor-worker + stonks-oracle/tier: processing spec: + automountServiceAccountToken: false + securityContext: + runAsNonRoot: true + runAsUser: 1000 + runAsGroup: 1000 + fsGroup: 1000 + seccompProfile: + type: RuntimeDefault containers: - name: extractor-worker image: ghcr.io/celesrenata/stonks-oracle/extractor:latest imagePullPolicy: Always + securityContext: + allowPrivilegeEscalation: false + readOnlyRootFilesystem: true + capabilities: + drop: ["ALL"] envFrom: - configMapRef: name: stonks-config - secretRef: - name: stonks-secrets + name: stonks-core-secrets resources: requests: cpu: 200m @@ -32,3 +47,10 @@ spec: limits: cpu: "1" memory: 512Mi + volumeMounts: + - name: tmp + mountPath: /tmp + volumes: 
+ - name: tmp + emptyDir: + sizeLimit: 10Mi diff --git a/infra/k8s/hive-metastore.yaml b/infra/k8s/hive-metastore.yaml index 57ec740..0866b75 100644 --- a/infra/k8s/hive-metastore.yaml +++ b/infra/k8s/hive-metastore.yaml @@ -6,6 +6,7 @@ metadata: labels: app: hive-metastore app.kubernetes.io/part-of: stonks-oracle + stonks-oracle/tier: analytics spec: replicas: 1 selector: @@ -15,22 +16,121 @@ spec: metadata: labels: app: hive-metastore + stonks-oracle/tier: analytics spec: + automountServiceAccountToken: false + securityContext: + runAsNonRoot: true + runAsUser: 1000 + runAsGroup: 1000 + fsGroup: 1000 + seccompProfile: + type: RuntimeDefault + initContainers: + - name: hive-config-init + image: busybox:1.36 + securityContext: + allowPrivilegeEscalation: false + capabilities: + drop: ["ALL"] + command: ["sh", "-c"] + args: + - | + cat > /hive-config/core-site.xml < + + + fs.s3a.endpoint + http://minio.minio-service.svc.cluster.local:80 + + + fs.s3a.access.key + ${MINIO_ACCESS_KEY} + + + fs.s3a.secret.key + ${MINIO_SECRET_KEY} + + + fs.s3a.path.style.access + true + + + fs.s3a.impl + org.apache.hadoop.fs.s3a.S3AFileSystem + + + fs.s3a.connection.ssl.enabled + false + + + EOF + cat > /hive-config/metastore-site.xml < + + + metastore.thrift.uris + thrift://0.0.0.0:9083 + + + metastore.task.threads.always + org.apache.hadoop.hive.metastore.events.EventCleanerTask + + + metastore.expression.proxy + org.apache.hadoop.hive.metastore.DefaultPartitionExpressionProxy + + + javax.jdo.option.ConnectionDriverName + org.apache.derby.jdbc.EmbeddedDriver + + + javax.jdo.option.ConnectionURL + jdbc:derby:/opt/hive/data/metastore_db;create=true + + + metastore.warehouse.dir + s3a://stonks-lakehouse/warehouse + + + EOF + env: + - name: MINIO_ACCESS_KEY + valueFrom: + secretKeyRef: + name: stonks-core-secrets + key: MINIO_ACCESS_KEY + - name: MINIO_SECRET_KEY + valueFrom: + secretKeyRef: + name: stonks-core-secrets + key: MINIO_SECRET_KEY + volumeMounts: + - name: hive-config + 
mountPath: /hive-config containers: - name: hive-metastore image: apache/hive:4.0.0 ports: - containerPort: 9083 + securityContext: + allowPrivilegeEscalation: false + capabilities: + drop: ["ALL"] env: - name: SERVICE_NAME value: metastore - name: DB_DRIVER value: derby - - name: SERVICE_OPTS - value: "-Djavax.jdo.option.ConnectionURL=jdbc:derby:/opt/hive/data/metastore_db;create=true" volumeMounts: - name: hive-data mountPath: /opt/hive/data + - name: hive-config + mountPath: /opt/hive/conf/core-site.xml + subPath: core-site.xml + - name: hive-config + mountPath: /opt/hive/conf/metastore-site.xml + subPath: metastore-site.xml resources: requests: cpu: 200m @@ -42,6 +142,8 @@ spec: - name: hive-data persistentVolumeClaim: claimName: hive-metastore-data + - name: hive-config + emptyDir: {} --- apiVersion: v1 kind: Service diff --git a/infra/k8s/ingestion-worker.yaml b/infra/k8s/ingestion-worker.yaml index 26341ca..555b273 100644 --- a/infra/k8s/ingestion-worker.yaml +++ b/infra/k8s/ingestion-worker.yaml @@ -6,6 +6,7 @@ metadata: labels: app: ingestion-worker app.kubernetes.io/part-of: stonks-oracle + stonks-oracle/tier: ingestion spec: replicas: 2 selector: @@ -15,16 +16,32 @@ spec: metadata: labels: app: ingestion-worker + stonks-oracle/tier: ingestion spec: + automountServiceAccountToken: false + securityContext: + runAsNonRoot: true + runAsUser: 1000 + runAsGroup: 1000 + fsGroup: 1000 + seccompProfile: + type: RuntimeDefault containers: - name: ingestion-worker image: ghcr.io/celesrenata/stonks-oracle/ingestion:latest imagePullPolicy: Always + securityContext: + allowPrivilegeEscalation: false + readOnlyRootFilesystem: true + capabilities: + drop: ["ALL"] envFrom: - configMapRef: name: stonks-config - secretRef: - name: stonks-secrets + name: stonks-core-secrets + - secretRef: + name: stonks-market-secrets resources: requests: cpu: 100m @@ -32,3 +49,10 @@ spec: limits: cpu: 500m memory: 256Mi + volumeMounts: + - name: tmp + mountPath: /tmp + volumes: + - name: 
tmp + emptyDir: + sizeLimit: 10Mi diff --git a/infra/k8s/lake-publisher.yaml b/infra/k8s/lake-publisher.yaml index 9446823..78189b1 100644 --- a/infra/k8s/lake-publisher.yaml +++ b/infra/k8s/lake-publisher.yaml @@ -6,6 +6,7 @@ metadata: labels: app: lake-publisher app.kubernetes.io/part-of: stonks-oracle + stonks-oracle/tier: analytics spec: replicas: 1 selector: @@ -15,16 +16,30 @@ spec: metadata: labels: app: lake-publisher + stonks-oracle/tier: analytics spec: + automountServiceAccountToken: false + securityContext: + runAsNonRoot: true + runAsUser: 1000 + runAsGroup: 1000 + fsGroup: 1000 + seccompProfile: + type: RuntimeDefault containers: - name: lake-publisher image: ghcr.io/celesrenata/stonks-oracle/lake-publisher:latest imagePullPolicy: Always + securityContext: + allowPrivilegeEscalation: false + readOnlyRootFilesystem: true + capabilities: + drop: ["ALL"] envFrom: - configMapRef: name: stonks-config - secretRef: - name: stonks-secrets + name: stonks-core-secrets resources: requests: cpu: 100m @@ -32,3 +47,10 @@ spec: limits: cpu: 500m memory: 256Mi + volumeMounts: + - name: tmp + mountPath: /tmp + volumes: + - name: tmp + emptyDir: + sizeLimit: 10Mi diff --git a/infra/k8s/namespace.yaml b/infra/k8s/namespace.yaml index 79991b4..4168eaa 100644 --- a/infra/k8s/namespace.yaml +++ b/infra/k8s/namespace.yaml @@ -4,3 +4,4 @@ metadata: name: stonks-oracle labels: app.kubernetes.io/part-of: stonks-oracle + kubernetes.io/metadata.name: stonks-oracle diff --git a/infra/k8s/network-policies.yaml b/infra/k8s/network-policies.yaml new file mode 100644 index 0000000..089fb43 --- /dev/null +++ b/infra/k8s/network-policies.yaml @@ -0,0 +1,173 @@ +## +## Stonks Oracle — Network Policies +## +## Default-deny ingress for the namespace, then allow only the +## traffic patterns each component actually needs. 
+## +## Requirements: 8.2 (trading isolation), 12.1 (observability) +## + +# ── Default deny all ingress in the namespace ────────────────────────── +apiVersion: networking.k8s.io/v1 +kind: NetworkPolicy +metadata: + name: default-deny-ingress + namespace: stonks-oracle +spec: + podSelector: {} + policyTypes: + - Ingress +--- +# ── Query API: accept from Traefik ingress only ─────────────────────── +apiVersion: networking.k8s.io/v1 +kind: NetworkPolicy +metadata: + name: allow-query-api-ingress + namespace: stonks-oracle +spec: + podSelector: + matchLabels: + app: query-api + policyTypes: + - Ingress + ingress: + - from: + - namespaceSelector: + matchLabels: + kubernetes.io/metadata.name: kube-system + ports: + - protocol: TCP + port: 8000 +--- +# ── Symbol Registry API: accept from Traefik ingress only ───────────── +apiVersion: networking.k8s.io/v1 +kind: NetworkPolicy +metadata: + name: allow-symbol-registry-ingress + namespace: stonks-oracle +spec: + podSelector: + matchLabels: + app: symbol-registry-api + policyTypes: + - Ingress + ingress: + - from: + - namespaceSelector: + matchLabels: + kubernetes.io/metadata.name: kube-system + ports: + - protocol: TCP + port: 8000 +--- +# ── Risk Engine: accept from broker-adapter and query-api ────────────── +apiVersion: networking.k8s.io/v1 +kind: NetworkPolicy +metadata: + name: allow-risk-engine-ingress + namespace: stonks-oracle +spec: + podSelector: + matchLabels: + app: risk-engine + policyTypes: + - Ingress + ingress: + - from: + - podSelector: + matchLabels: + app: broker-adapter + - podSelector: + matchLabels: + app: query-api + ports: + - protocol: TCP + port: 8000 +--- +# ── Superset: accept from Traefik ingress only ──────────────────────── +apiVersion: networking.k8s.io/v1 +kind: NetworkPolicy +metadata: + name: allow-superset-ingress + namespace: stonks-oracle +spec: + podSelector: + matchLabels: + app: superset + policyTypes: + - Ingress + ingress: + - from: + - namespaceSelector: + matchLabels: +
kubernetes.io/metadata.name: kube-system + ports: + - protocol: TCP + port: 8088 +--- +# ── Trino: accept from Superset, query-api, and kube-system ─────────── +apiVersion: networking.k8s.io/v1 +kind: NetworkPolicy +metadata: + name: allow-trino-ingress + namespace: stonks-oracle +spec: + podSelector: + matchLabels: + app: trino + policyTypes: + - Ingress + ingress: + - from: + - podSelector: + matchLabels: + app: superset + - podSelector: + matchLabels: + app: query-api + - namespaceSelector: + matchLabels: + kubernetes.io/metadata.name: kube-system + ports: + - protocol: TCP + port: 8080 +--- +# ── Hive Metastore: accept from Trino and lake-publisher ────────────── +apiVersion: networking.k8s.io/v1 +kind: NetworkPolicy +metadata: + name: allow-hive-metastore-ingress + namespace: stonks-oracle +spec: + podSelector: + matchLabels: + app: hive-metastore + policyTypes: + - Ingress + ingress: + - from: + - podSelector: + matchLabels: + app: trino + - podSelector: + matchLabels: + app: lake-publisher + ports: + - protocol: TCP + port: 9083 +--- +# ── Broker adapter: isolated — no inbound from other pods ────────────── +# The broker-adapter only makes outbound calls to the broker API +# and reads from Redis queues. No pod needs to call into it.
+apiVersion: networking.k8s.io/v1 +kind: NetworkPolicy +metadata: + name: deny-broker-adapter-ingress + namespace: stonks-oracle +spec: + podSelector: + matchLabels: + app: broker-adapter + policyTypes: + - Ingress + ingress: [] diff --git a/infra/k8s/parser-worker.yaml b/infra/k8s/parser-worker.yaml index cea1b68..3818429 100644 --- a/infra/k8s/parser-worker.yaml +++ b/infra/k8s/parser-worker.yaml @@ -6,6 +6,7 @@ metadata: labels: app: parser-worker app.kubernetes.io/part-of: stonks-oracle + stonks-oracle/tier: processing spec: replicas: 2 selector: @@ -15,16 +16,30 @@ spec: metadata: labels: app: parser-worker + stonks-oracle/tier: processing spec: + automountServiceAccountToken: false + securityContext: + runAsNonRoot: true + runAsUser: 1000 + runAsGroup: 1000 + fsGroup: 1000 + seccompProfile: + type: RuntimeDefault containers: - name: parser-worker image: ghcr.io/celesrenata/stonks-oracle/parser:latest imagePullPolicy: Always + securityContext: + allowPrivilegeEscalation: false + readOnlyRootFilesystem: true + capabilities: + drop: ["ALL"] envFrom: - configMapRef: name: stonks-config - secretRef: - name: stonks-secrets + name: stonks-core-secrets resources: requests: cpu: 100m @@ -32,3 +47,10 @@ spec: limits: cpu: 500m memory: 256Mi + volumeMounts: + - name: tmp + mountPath: /tmp + volumes: + - name: tmp + emptyDir: + sizeLimit: 10Mi diff --git a/infra/k8s/query-api.yaml b/infra/k8s/query-api.yaml index 9ff669d..aea7adf 100644 --- a/infra/k8s/query-api.yaml +++ b/infra/k8s/query-api.yaml @@ -6,6 +6,7 @@ metadata: labels: app: query-api app.kubernetes.io/part-of: stonks-oracle + stonks-oracle/tier: api spec: replicas: 1 selector: @@ -15,18 +16,32 @@ spec: metadata: labels: app: query-api + stonks-oracle/tier: api spec: + automountServiceAccountToken: false + securityContext: + runAsNonRoot: true + runAsUser: 1000 + runAsGroup: 1000 + fsGroup: 1000 + seccompProfile: + type: RuntimeDefault containers: - name: query-api image: 
ghcr.io/celesrenata/stonks-oracle/query-api:latest imagePullPolicy: Always ports: - containerPort: 8000 + securityContext: + allowPrivilegeEscalation: false + readOnlyRootFilesystem: true + capabilities: + drop: ["ALL"] envFrom: - configMapRef: name: stonks-config - secretRef: - name: stonks-secrets + name: stonks-core-secrets resources: requests: cpu: 100m @@ -40,6 +55,13 @@ spec: port: 8000 initialDelaySeconds: 5 periodSeconds: 10 + volumeMounts: + - name: tmp + mountPath: /tmp + volumes: + - name: tmp + emptyDir: + sizeLimit: 10Mi --- apiVersion: v1 kind: Service diff --git a/infra/k8s/recommendation-worker.yaml b/infra/k8s/recommendation-worker.yaml index 893c8c1..15d16d7 100644 --- a/infra/k8s/recommendation-worker.yaml +++ b/infra/k8s/recommendation-worker.yaml @@ -6,6 +6,7 @@ metadata: labels: app: recommendation-worker app.kubernetes.io/part-of: stonks-oracle + stonks-oracle/tier: processing spec: replicas: 1 selector: @@ -15,16 +16,30 @@ spec: metadata: labels: app: recommendation-worker + stonks-oracle/tier: processing spec: + automountServiceAccountToken: false + securityContext: + runAsNonRoot: true + runAsUser: 1000 + runAsGroup: 1000 + fsGroup: 1000 + seccompProfile: + type: RuntimeDefault containers: - name: recommendation-worker image: ghcr.io/celesrenata/stonks-oracle/recommendation:latest imagePullPolicy: Always + securityContext: + allowPrivilegeEscalation: false + readOnlyRootFilesystem: true + capabilities: + drop: ["ALL"] envFrom: - configMapRef: name: stonks-config - secretRef: - name: stonks-secrets + name: stonks-core-secrets resources: requests: cpu: 100m @@ -32,3 +47,10 @@ spec: limits: cpu: 500m memory: 256Mi + volumeMounts: + - name: tmp + mountPath: /tmp + volumes: + - name: tmp + emptyDir: + sizeLimit: 10Mi diff --git a/infra/k8s/risk-engine.yaml b/infra/k8s/risk-engine.yaml index a95e1a5..1a979fc 100644 --- a/infra/k8s/risk-engine.yaml +++ b/infra/k8s/risk-engine.yaml @@ -6,6 +6,7 @@ metadata: labels: app: risk-engine 
app.kubernetes.io/part-of: stonks-oracle + stonks-oracle/tier: trading spec: replicas: 1 selector: @@ -15,18 +16,34 @@ spec: metadata: labels: app: risk-engine + stonks-oracle/tier: trading spec: + automountServiceAccountToken: false + securityContext: + runAsNonRoot: true + runAsUser: 1000 + runAsGroup: 1000 + fsGroup: 1000 + seccompProfile: + type: RuntimeDefault containers: - name: risk-engine image: ghcr.io/celesrenata/stonks-oracle/risk:latest imagePullPolicy: Always ports: - containerPort: 8000 + securityContext: + allowPrivilegeEscalation: false + readOnlyRootFilesystem: true + capabilities: + drop: ["ALL"] envFrom: - configMapRef: name: stonks-config - secretRef: - name: stonks-secrets + name: stonks-core-secrets + - secretRef: + name: stonks-broker-secrets resources: requests: cpu: 100m @@ -34,6 +51,13 @@ spec: limits: cpu: 500m memory: 256Mi + volumeMounts: + - name: tmp + mountPath: /tmp + volumes: + - name: tmp + emptyDir: + sizeLimit: 10Mi --- apiVersion: v1 kind: Service diff --git a/infra/k8s/scheduler.yaml b/infra/k8s/scheduler.yaml index 4e47b6a..b8a2d72 100644 --- a/infra/k8s/scheduler.yaml +++ b/infra/k8s/scheduler.yaml @@ -6,6 +6,7 @@ metadata: labels: app: scheduler app.kubernetes.io/part-of: stonks-oracle + stonks-oracle/tier: orchestration spec: replicas: 1 selector: @@ -15,16 +16,30 @@ spec: metadata: labels: app: scheduler + stonks-oracle/tier: orchestration spec: + automountServiceAccountToken: false + securityContext: + runAsNonRoot: true + runAsUser: 1000 + runAsGroup: 1000 + fsGroup: 1000 + seccompProfile: + type: RuntimeDefault containers: - name: scheduler image: ghcr.io/celesrenata/stonks-oracle/scheduler:latest imagePullPolicy: Always + securityContext: + allowPrivilegeEscalation: false + readOnlyRootFilesystem: true + capabilities: + drop: ["ALL"] envFrom: - configMapRef: name: stonks-config - secretRef: - name: stonks-secrets + name: stonks-core-secrets resources: requests: cpu: 50m @@ -32,3 +47,10 @@ spec: limits: cpu: 200m 
memory: 128Mi + volumeMounts: + - name: tmp + mountPath: /tmp + volumes: + - name: tmp + emptyDir: + sizeLimit: 10Mi diff --git a/infra/k8s/secrets.yaml b/infra/k8s/secrets.yaml index dd40921..4fc0a11 100644 --- a/infra/k8s/secrets.yaml +++ b/infra/k8s/secrets.yaml @@ -1,17 +1,63 @@ +## +## Stonks Oracle — Scoped Secrets +## +## Secrets are split by concern so that only the services that need +## broker or market-data credentials actually receive them. +## Replace placeholder values before deploying. +## +## Requirements: 8.2 (broker credential isolation) +## + +# ── Core infrastructure secrets (DB, object store, cache) ────────────── apiVersion: v1 kind: Secret metadata: - name: stonks-secrets + name: stonks-core-secrets namespace: stonks-oracle labels: app.kubernetes.io/part-of: stonks-oracle type: Opaque stringData: - POSTGRES_PASSWORD: "changeme" - MINIO_ACCESS_KEY: "changeme" - MINIO_SECRET_KEY: "changeme" + POSTGRES_PASSWORD: "REPLACE_ME" + MINIO_ACCESS_KEY: "REPLACE_ME" + MINIO_SECRET_KEY: "REPLACE_ME" REDIS_PASSWORD: "" - BROKER_API_KEY: "" - BROKER_API_SECRET: "" - BROKER_BASE_URL: "" - SUPERSET_SECRET_KEY: "stonks-superset-secret-change-me" +--- +# ── Broker secrets — only for broker-adapter and risk-engine ─────────── +apiVersion: v1 +kind: Secret +metadata: + name: stonks-broker-secrets + namespace: stonks-oracle + labels: + app.kubernetes.io/part-of: stonks-oracle +type: Opaque +stringData: + BROKER_API_KEY: "REPLACE_ME" + BROKER_API_SECRET: "REPLACE_ME" + BROKER_BASE_URL: "https://paper-api.alpaca.markets" +--- +# ── Market data secrets — only for ingestion and adapters ────────────── +apiVersion: v1 +kind: Secret +metadata: + name: stonks-market-secrets + namespace: stonks-oracle + labels: + app.kubernetes.io/part-of: stonks-oracle +type: Opaque +stringData: + MARKET_DATA_API_KEY: "REPLACE_ME" +--- +# ── Dashboard secrets — only for Superset ────────────────────────────── +apiVersion: v1 +kind: Secret +metadata: + name: stonks-dashboard-secrets + 
namespace: stonks-oracle + labels: + app.kubernetes.io/part-of: stonks-oracle +type: Opaque +stringData: + SUPERSET_SECRET_KEY: "REPLACE_ME" + SUPERSET_ADMIN_PASSWORD: "REPLACE_ME" diff --git a/infra/k8s/superset.yaml b/infra/k8s/superset.yaml index 8c1ecf7..0054e86 100644 --- a/infra/k8s/superset.yaml +++ b/infra/k8s/superset.yaml @@ -6,6 +6,7 @@ metadata: labels: app: superset app.kubernetes.io/part-of: stonks-oracle + stonks-oracle/tier: dashboard spec: replicas: 1 selector: @@ -15,22 +16,38 @@ spec: metadata: labels: app: superset + stonks-oracle/tier: dashboard spec: + automountServiceAccountToken: false + securityContext: + runAsNonRoot: true + runAsUser: 1000 + runAsGroup: 1000 + fsGroup: 1000 + seccompProfile: + type: RuntimeDefault containers: - name: superset image: apache/superset:latest ports: - containerPort: 8088 + securityContext: + allowPrivilegeEscalation: false + capabilities: + drop: ["ALL"] env: - name: SUPERSET_SECRET_KEY valueFrom: secretKeyRef: - name: stonks-secrets + name: stonks-dashboard-secrets key: SUPERSET_SECRET_KEY - name: ADMIN_USERNAME value: admin - name: ADMIN_PASSWORD - value: admin + valueFrom: + secretKeyRef: + name: stonks-dashboard-secrets + key: SUPERSET_ADMIN_PASSWORD - name: ADMIN_EMAIL value: admin@stonks.local volumeMounts: @@ -94,12 +111,39 @@ data: import os SECRET_KEY = os.getenv("SUPERSET_SECRET_KEY", "stonks-dev-secret-key-change-me") SQLALCHEMY_DATABASE_URI = "trino://trino@trino.stonks-oracle.svc.cluster.local:8080/lakehouse/stonks" + # Additional database connections available in Superset UI: + # Hive catalog: trino://trino@trino.stonks-oracle.svc.cluster.local:8080/lakehouse/stonks + # Iceberg catalog: trino://trino@trino.stonks-oracle.svc.cluster.local:8080/iceberg/stonks FEATURE_FLAGS = {"ENABLE_TEMPLATE_PROCESSING": True} CACHE_CONFIG = { "CACHE_TYPE": "RedisCache", "CACHE_DEFAULT_TIMEOUT": 300, "CACHE_KEY_PREFIX": "superset_", - "CACHE_REDIS_HOST": os.getenv("REDIS_HOST", 
"redis.redis-service.svc.cluster.local"), + "CACHE_REDIS_HOST": os.getenv("REDIS_HOST", "redis-master.redis-service.svc.cluster.local"), "CACHE_REDIS_PORT": int(os.getenv("REDIS_PORT", "6379")), "CACHE_REDIS_DB": 1, } + + # --- Security hardening --- + # Disable public user role (require login) + PUBLIC_ROLE_LIKE = None + # Session cookie security + SESSION_COOKIE_HTTPONLY = True + SESSION_COOKIE_SECURE = True + SESSION_COOKIE_SAMESITE = "Lax" + # Talisman CSP headers + TALISMAN_ENABLED = True + TALISMAN_CONFIG = { + "content_security_policy": { + "default-src": ["'self'"], + "img-src": ["'self'", "data:"], + "style-src": ["'self'", "'unsafe-inline'"], + "script-src": ["'self'", "'unsafe-inline'", "'unsafe-eval'"], + }, + "force_https": False, # TLS terminated at ingress + } + # Prevent Superset from allowing arbitrary SQL database connections + PREVENT_UNSAFE_DB_CONNECTIONS = True + # Row limit for queries + ROW_LIMIT = 50000 + SQL_MAX_ROW = 100000 diff --git a/infra/k8s/symbol-registry.yaml b/infra/k8s/symbol-registry.yaml index 24735ab..aeed4a2 100644 --- a/infra/k8s/symbol-registry.yaml +++ b/infra/k8s/symbol-registry.yaml @@ -6,6 +6,7 @@ metadata: labels: app: symbol-registry-api app.kubernetes.io/part-of: stonks-oracle + stonks-oracle/tier: api spec: replicas: 1 selector: @@ -15,18 +16,32 @@ spec: metadata: labels: app: symbol-registry-api + stonks-oracle/tier: api spec: + automountServiceAccountToken: false + securityContext: + runAsNonRoot: true + runAsUser: 1000 + runAsGroup: 1000 + fsGroup: 1000 + seccompProfile: + type: RuntimeDefault containers: - name: symbol-registry-api image: ghcr.io/celesrenata/stonks-oracle/symbol-registry:latest imagePullPolicy: Always ports: - containerPort: 8000 + securityContext: + allowPrivilegeEscalation: false + readOnlyRootFilesystem: true + capabilities: + drop: ["ALL"] envFrom: - configMapRef: name: stonks-config - secretRef: - name: stonks-secrets + name: stonks-core-secrets resources: requests: cpu: 100m @@ -46,6 
+61,13 @@ spec: port: 8000 initialDelaySeconds: 10 periodSeconds: 30 + volumeMounts: + - name: tmp + mountPath: /tmp + volumes: + - name: tmp + emptyDir: + sizeLimit: 10Mi --- apiVersion: v1 kind: Service diff --git a/infra/k8s/trino.yaml b/infra/k8s/trino.yaml index 5a66c4f..c8173af 100644 --- a/infra/k8s/trino.yaml +++ b/infra/k8s/trino.yaml @@ -6,6 +6,7 @@ metadata: labels: app: trino app.kubernetes.io/part-of: stonks-oracle + stonks-oracle/tier: analytics spec: replicas: 1 selector: @@ -15,12 +16,73 @@ spec: metadata: labels: app: trino + stonks-oracle/tier: analytics spec: + automountServiceAccountToken: false + securityContext: + runAsNonRoot: true + runAsUser: 1000 + runAsGroup: 1000 + fsGroup: 1000 + seccompProfile: + type: RuntimeDefault + initContainers: + - name: catalog-init + image: busybox:1.36 + securityContext: + allowPrivilegeEscalation: false + capabilities: + drop: ["ALL"] + command: ["sh", "-c"] + args: + - | + cat > /catalog/iceberg.properties < /catalog/lakehouse.properties <>'recommendation_id', data->>'order_id') +CREATE INDEX IF NOT EXISTS idx_audit_events_data_gin + ON audit_events USING gin (data); + +-- Index for chronological audit trail queries by entity +CREATE INDEX IF NOT EXISTS idx_audit_events_entity_created + ON audit_events (entity_id, created_at ASC); + +-- Index for filtering by event_type + entity_type +CREATE INDEX IF NOT EXISTS idx_audit_events_type_entity + ON audit_events (event_type, entity_type); diff --git a/infra/migrations/013_operator_approval_workflow.sql b/infra/migrations/013_operator_approval_workflow.sql new file mode 100644 index 0000000..a3377bc --- /dev/null +++ b/infra/migrations/013_operator_approval_workflow.sql @@ -0,0 +1,29 @@ +-- Stonks Oracle - Operator approval workflow for live trading mode +-- Tracks pending, approved, rejected, and expired approval requests +-- for orders that require operator sign-off before broker submission. 
+-- Requirements: 8.2 + +CREATE TABLE operator_approvals ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + order_job JSONB NOT NULL DEFAULT '{}', + recommendation_id UUID REFERENCES recommendations(id), + ticker VARCHAR(20) NOT NULL, + side VARCHAR(10) NOT NULL DEFAULT 'buy', + quantity NUMERIC NOT NULL DEFAULT 0, + estimated_value NUMERIC NOT NULL DEFAULT 0, + status VARCHAR(20) NOT NULL DEFAULT 'pending', + risk_evaluation_id UUID, + requested_by VARCHAR(200) NOT NULL DEFAULT 'system', + reviewed_by VARCHAR(200), + review_note TEXT, + expires_at TIMESTAMPTZ NOT NULL, + requested_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + reviewed_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX idx_operator_approvals_status ON operator_approvals(status); +CREATE INDEX idx_operator_approvals_ticker ON operator_approvals(ticker); +CREATE INDEX idx_operator_approvals_expires ON operator_approvals(expires_at) + WHERE status = 'pending'; diff --git a/infra/migrations/014_retention_policies.sql b/infra/migrations/014_retention_policies.sql new file mode 100644 index 0000000..eaec6af --- /dev/null +++ b/infra/migrations/014_retention_policies.sql @@ -0,0 +1,43 @@ +-- Stonks Oracle - Data retention and lifecycle policies +-- Tracks per-bucket and per-artifact-class retention rules. 
+-- Requirements: N3 + +CREATE TABLE retention_policies ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + bucket_name VARCHAR(200) NOT NULL, + artifact_class VARCHAR(100) NOT NULL DEFAULT 'default', + retention_days INTEGER NOT NULL DEFAULT 365, + archive_before_delete BOOLEAN NOT NULL DEFAULT FALSE, + active BOOLEAN NOT NULL DEFAULT TRUE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + UNIQUE(bucket_name, artifact_class) +); + +-- Seed default retention policies per bucket +INSERT INTO retention_policies (bucket_name, artifact_class, retention_days, archive_before_delete) VALUES + ('stonks-raw-market', 'default', 90, FALSE), + ('stonks-raw-news', 'default', 180, FALSE), + ('stonks-raw-filings', 'default', 365, FALSE), + ('stonks-normalized', 'default', 180, FALSE), + ('stonks-llm-prompts', 'default', 365, FALSE), + ('stonks-llm-results', 'default', 365, FALSE), + ('stonks-lakehouse', 'default', 730, FALSE), + ('stonks-audit', 'default', 730, FALSE); + +-- Track retention cleanup runs for observability +CREATE TABLE retention_runs ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + bucket_name VARCHAR(200) NOT NULL, + objects_scanned INTEGER NOT NULL DEFAULT 0, + objects_deleted INTEGER NOT NULL DEFAULT 0, + bytes_freed BIGINT NOT NULL DEFAULT 0, + db_rows_deleted INTEGER NOT NULL DEFAULT 0, + started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + completed_at TIMESTAMPTZ, + status VARCHAR(20) NOT NULL DEFAULT 'running', + error_message TEXT +); + +CREATE INDEX idx_retention_runs_bucket ON retention_runs(bucket_name, started_at DESC); +CREATE INDEX idx_retention_runs_status ON retention_runs(status); diff --git a/infra/minio/lifecycle.json b/infra/minio/lifecycle.json index 82a6932..f93e386 100644 --- a/infra/minio/lifecycle.json +++ b/infra/minio/lifecycle.json @@ -1,14 +1,84 @@ { - "Rules": [ - { - "ID": "raw-retention-365d", - "Status": "Enabled", - "Filter": { - "Prefix": "" - }, - "Expiration": { - "Days": 
365 - } + "buckets": { + "stonks-raw-market": { + "Rules": [ + { + "ID": "raw-market-retention-90d", + "Status": "Enabled", + "Filter": { "Prefix": "" }, + "Expiration": { "Days": 90 } + } + ] + }, + "stonks-raw-news": { + "Rules": [ + { + "ID": "raw-news-retention-180d", + "Status": "Enabled", + "Filter": { "Prefix": "" }, + "Expiration": { "Days": 180 } + } + ] + }, + "stonks-raw-filings": { + "Rules": [ + { + "ID": "raw-filings-retention-365d", + "Status": "Enabled", + "Filter": { "Prefix": "" }, + "Expiration": { "Days": 365 } + } + ] + }, + "stonks-normalized": { + "Rules": [ + { + "ID": "normalized-retention-180d", + "Status": "Enabled", + "Filter": { "Prefix": "" }, + "Expiration": { "Days": 180 } + } + ] + }, + "stonks-llm-prompts": { + "Rules": [ + { + "ID": "llm-prompts-retention-365d", + "Status": "Enabled", + "Filter": { "Prefix": "" }, + "Expiration": { "Days": 365 } + } + ] + }, + "stonks-llm-results": { + "Rules": [ + { + "ID": "llm-results-retention-365d", + "Status": "Enabled", + "Filter": { "Prefix": "" }, + "Expiration": { "Days": 365 } + } + ] + }, + "stonks-lakehouse": { + "Rules": [ + { + "ID": "lakehouse-retention-730d", + "Status": "Enabled", + "Filter": { "Prefix": "" }, + "Expiration": { "Days": 730 } + } + ] + }, + "stonks-audit": { + "Rules": [ + { + "ID": "audit-retention-730d", + "Status": "Enabled", + "Filter": { "Prefix": "" }, + "Expiration": { "Days": 730 } + } + ] } - ] + } } diff --git a/infra/superset/superset_config.py b/infra/superset/superset_config.py index ed16b21..fc13877 100644 --- a/infra/superset/superset_config.py +++ b/infra/superset/superset_config.py @@ -1,10 +1,18 @@ -"""Apache Superset configuration for Stonks Oracle.""" +"""Apache Superset configuration for Stonks Oracle. 
+ +Security hardening applied: +- Session cookies: HttpOnly, Secure, SameSite=Lax +- Talisman CSP headers enabled +- Public role disabled (login required) +- Unsafe DB connections blocked +- Row limits enforced +""" import os -# Superset secret key +# Superset secret key — read from SUPERSET_SECRET_KEY; the fallback default is for local development only and must be overridden in production SECRET_KEY = os.getenv("SUPERSET_SECRET_KEY", "stonks-dev-secret-key-change-me") -# Trino datasource +# Default Trino datasource (Hive catalog for backward compatibility) SQLALCHEMY_DATABASE_URI = "trino://trino@trino:8080/lakehouse/stonks" # Feature flags @@ -12,6 +20,10 @@ FEATURE_FLAGS = { "ENABLE_TEMPLATE_PROCESSING": True, } +# Additional database connections available in Superset UI: +# Hive catalog: trino://trino@trino:8080/lakehouse/stonks +# Iceberg catalog: trino://trino@trino:8080/iceberg/stonks + # Cache config (Redis-backed) CACHE_CONFIG = { "CACHE_TYPE": "RedisCache", @@ -21,3 +33,31 @@ CACHE_CONFIG = { "CACHE_REDIS_PORT": int(os.getenv("REDIS_PORT", "6379")), "CACHE_REDIS_DB": 1, } + +# --- Security hardening --- +# Disable public user role (require login) +PUBLIC_ROLE_LIKE = None + +# Session cookie security +SESSION_COOKIE_HTTPONLY = True +SESSION_COOKIE_SECURE = True +SESSION_COOKIE_SAMESITE = "Lax" + +# Talisman CSP headers +TALISMAN_ENABLED = True +TALISMAN_CONFIG = { + "content_security_policy": { + "default-src": ["'self'"], + "img-src": ["'self'", "data:"], + "style-src": ["'self'", "'unsafe-inline'"], + "script-src": ["'self'", "'unsafe-inline'", "'unsafe-eval'"], + }, + "force_https": False, # TLS terminated at ingress +} + +# Prevent Superset from allowing arbitrary SQL database connections +PREVENT_UNSAFE_DB_CONNECTIONS = True + +# Row limit for queries +ROW_LIMIT = 50000 +SQL_MAX_ROW = 100000 diff --git a/infra/trino/catalog/iceberg.properties b/infra/trino/catalog/iceberg.properties index 219ab92..10faec8 100644 --- a/infra/trino/catalog/iceberg.properties +++ b/infra/trino/catalog/iceberg.properties @@ -5,3 +5,8 @@
hive.s3.endpoint=http://minio:9000 hive.s3.path-style-access=true hive.s3.aws-access-key=minioadmin hive.s3.aws-secret-key=minioadmin +fs.native-s3.enabled=true +s3.endpoint=http://minio:9000 +s3.path-style-access=true +s3.aws-access-key=minioadmin +s3.aws-secret-key=minioadmin diff --git a/lakehouse/schemas/README.md b/lakehouse/schemas/README.md index 6c8f720..ac037b4 100644 --- a/lakehouse/schemas/README.md +++ b/lakehouse/schemas/README.md @@ -2,15 +2,31 @@ Analytical fact table definitions for MinIO-backed datasets queried via Trino. +All tables use Hive-compatible partition layouts on MinIO (`s3a://stonks-lakehouse/warehouse/`) +and are defined in the `lakehouse.stonks` schema. Parquet is the storage format. + ## Fact Tables -- `lake.market_bars` — OHLCV bar data -- `lake.market_quotes` — quote snapshots -- `lake.company_events` — corporate actions and events -- `lake.documents` — ingested document metadata -- `lake.document_extractions` — AI extraction outputs -- `lake.trade_signals` — aggregated trend signals -- `lake.trade_orders` — order submission records -- `lake.trade_fills` — fill and execution records +- `lake.market_bars` — OHLCV bar data per symbol per interval +- `lake.market_quotes` — bid/ask quote snapshots +- `lake.company_events` — corporate actions, earnings, filings, and issuer events +- `lake.documents` — ingested document metadata (articles, filings, transcripts) +- `lake.document_extractions` — AI extraction outputs per document per company +- `lake.trade_signals` — aggregated trend signals and recommendation actions +- `lake.trade_orders` — order submission records (paper and live) +- `lake.trade_fills` — fill and execution records from broker - `lake.positions_daily` — end-of-day position snapshots -- `lake.pnl_daily` — daily PnL records +- `lake.pnl_daily` — daily PnL records per symbol per account - `lake.prediction_vs_outcome` — prediction accuracy tracking +- `lake.model_performance` — extraction model performance metrics + +## 
Partitioning +- Most tables partition by `dt` (date) +- `document_extractions`, `prediction_vs_outcome`, and `model_performance` also partition by `model_version` + +## Trino Catalogs +- `lakehouse` catalog (Hive connector) for external Hive-compatible tables +- `iceberg` catalog (Iceberg connector) for managed Iceberg tables + +## Views +Example SQL views for dashboards and ad hoc analysis are in `lakehouse/views/`. +See `lakehouse/views/README.md` for details. diff --git a/lakehouse/schemas/company_events.sql b/lakehouse/schemas/company_events.sql new file mode 100644 index 0000000..c99d2e4 --- /dev/null +++ b/lakehouse/schemas/company_events.sql @@ -0,0 +1,24 @@ +-- Analytical fact table: company_events +-- Corporate actions, earnings, filings, and other issuer events. +-- Partitioned by dt (date) on MinIO. +-- Path: s3://stonks-lakehouse/warehouse/company_events/dt={yyyy-mm-dd}/part-*.parquet +-- Requirements: 2.3, 9.4, 9.5, 10.1 +-- Design ref: Section 7 (lake.company_events) + +CREATE TABLE IF NOT EXISTS lakehouse.stonks.company_events ( + event_id VARCHAR, + ticker VARCHAR, + event_type VARCHAR, + event_subtype VARCHAR, + title VARCHAR, + description VARCHAR, + source VARCHAR, + source_url VARCHAR, + event_at TIMESTAMP(6) WITH TIME ZONE, + ingested_at TIMESTAMP(6) WITH TIME ZONE, + dt DATE +) WITH ( + format = 'PARQUET', + partitioned_by = ARRAY['dt'], + external_location = 's3a://stonks-lakehouse/warehouse/company_events/' +); diff --git a/lakehouse/schemas/document_extractions.sql b/lakehouse/schemas/document_extractions.sql index 1504e82..4719f8f 100644 --- a/lakehouse/schemas/document_extractions.sql +++ b/lakehouse/schemas/document_extractions.sql @@ -1,16 +1,28 @@ -- Analytical fact table: document_extractions --- Partitioned by dt and model_version on MinIO +-- AI extraction outputs per document per company. +-- Partitioned by dt and model_version on MinIO. 
+-- Path: s3://stonks-lakehouse/warehouse/document_extractions/dt={yyyy-mm-dd}/model_version={ver}/part-*.parquet +-- Requirements: 5.3, 5.5, 9.4, 9.5, 10.1, 10.4 +-- Design ref: Section 6.3, Section 7 (lake.document_extractions) CREATE TABLE IF NOT EXISTS lakehouse.stonks.document_extractions ( document_id VARCHAR, ticker VARCHAR, + company_name VARCHAR, + relevance DOUBLE, sentiment VARCHAR, impact_score DOUBLE, + impact_horizon VARCHAR, catalyst_type VARCHAR, confidence DOUBLE, novelty_score DOUBLE, + source_credibility DOUBLE, + key_facts VARCHAR, + risks VARCHAR, + macro_themes VARCHAR, model_name VARCHAR, prompt_version VARCHAR, + schema_version VARCHAR, extraction_at TIMESTAMP(6) WITH TIME ZONE, dt DATE, model_version VARCHAR diff --git a/lakehouse/schemas/documents.sql b/lakehouse/schemas/documents.sql index 24e90e3..3cc7bf3 100644 --- a/lakehouse/schemas/documents.sql +++ b/lakehouse/schemas/documents.sql @@ -1,6 +1,9 @@ -- Analytical fact table: documents --- Partitioned by dt and source_type on MinIO --- Path: s3://stonks-lakehouse/warehouse/documents/dt={yyyy-mm-dd}/source_type={type}/part-*.parquet +-- Ingested document metadata for articles, filings, transcripts, and press releases. +-- Partitioned by dt on MinIO. 
+-- Path: s3://stonks-lakehouse/warehouse/documents/dt={yyyy-mm-dd}/part-*.parquet +-- Requirements: 3.1, 3.3, 9.4, 9.5, 10.1, 10.4 +-- Design ref: Section 6.2, Section 7 (lake.documents) CREATE TABLE IF NOT EXISTS lakehouse.stonks.documents ( document_id VARCHAR, @@ -9,7 +12,11 @@ CREATE TABLE IF NOT EXISTS lakehouse.stonks.documents ( ticker VARCHAR, publisher VARCHAR, title VARCHAR, + url VARCHAR, + canonical_url VARCHAR, + language VARCHAR, published_at TIMESTAMP(6) WITH TIME ZONE, + retrieved_at TIMESTAMP(6) WITH TIME ZONE, content_hash VARCHAR, confidence DOUBLE, dt DATE diff --git a/lakehouse/schemas/market_bars.sql b/lakehouse/schemas/market_bars.sql index 71f63b5..dcfb757 100644 --- a/lakehouse/schemas/market_bars.sql +++ b/lakehouse/schemas/market_bars.sql @@ -1,6 +1,9 @@ -- Analytical fact table: market_bars --- Partitioned by dt (date) on MinIO +-- OHLCV bar data for tracked symbols. +-- Partitioned by dt (date) on MinIO. -- Path: s3://stonks-lakehouse/warehouse/market_bars/dt={yyyy-mm-dd}/part-*.parquet +-- Requirements: 2.1, 9.4, 9.5, 10.1 +-- Design ref: Section 7 (lake.market_bars) CREATE TABLE IF NOT EXISTS lakehouse.stonks.market_bars ( ticker VARCHAR, @@ -10,7 +13,9 @@ CREATE TABLE IF NOT EXISTS lakehouse.stonks.market_bars ( close_price DOUBLE, volume BIGINT, vwap DOUBLE, + trade_count BIGINT, bar_timestamp TIMESTAMP(6) WITH TIME ZONE, + bar_interval VARCHAR, source VARCHAR, dt DATE ) WITH ( diff --git a/lakehouse/schemas/market_quotes.sql b/lakehouse/schemas/market_quotes.sql new file mode 100644 index 0000000..04062cf --- /dev/null +++ b/lakehouse/schemas/market_quotes.sql @@ -0,0 +1,23 @@ +-- Analytical fact table: market_quotes +-- Quote snapshots for tracked symbols. +-- Partitioned by dt (date) on MinIO. 
+-- Path: s3://stonks-lakehouse/warehouse/market_quotes/dt={yyyy-mm-dd}/part-*.parquet +-- Requirements: 2.1, 9.4, 9.5, 10.1 +-- Design ref: Section 7 (lake.market_quotes) + +CREATE TABLE IF NOT EXISTS lakehouse.stonks.market_quotes ( + ticker VARCHAR, + bid_price DOUBLE, + ask_price DOUBLE, + bid_size BIGINT, + ask_size BIGINT, + last_price DOUBLE, + last_size BIGINT, + source VARCHAR, + quote_at TIMESTAMP(6) WITH TIME ZONE, + dt DATE +) WITH ( + format = 'PARQUET', + partitioned_by = ARRAY['dt'], + external_location = 's3a://stonks-lakehouse/warehouse/market_quotes/' +); diff --git a/lakehouse/schemas/model_performance.sql b/lakehouse/schemas/model_performance.sql new file mode 100644 index 0000000..1bfff6a --- /dev/null +++ b/lakehouse/schemas/model_performance.sql @@ -0,0 +1,33 @@ +-- Analytical fact table: model_performance +-- Tracks extraction model performance for Trino/Superset dashboards. +-- Partitioned by dt and model_version on MinIO. +-- Path: s3://stonks-lakehouse/warehouse/model_performance/dt={yyyy-mm-dd}/model_version={ver}/part-*.parquet +-- Requirements: 12.1, 12.2 + +CREATE TABLE IF NOT EXISTS lakehouse.stonks.model_performance ( + document_id VARCHAR, + ticker VARCHAR, + model_name VARCHAR, + prompt_version VARCHAR, + schema_version VARCHAR, + success BOOLEAN, + attempt_count INTEGER, + total_duration_ms INTEGER, + first_attempt_duration_ms INTEGER, + final_attempt_duration_ms INTEGER, + confidence DOUBLE, + validation_status VARCHAR, + validation_error_count INTEGER, + validation_warning_count INTEGER, + retry_count INTEGER, + input_token_estimate INTEGER, + output_token_estimate INTEGER, + company_count INTEGER, + recorded_at TIMESTAMP(6) WITH TIME ZONE, + dt DATE, + model_version VARCHAR +) WITH ( + format = 'PARQUET', + partitioned_by = ARRAY['dt', 'model_version'], + external_location = 's3a://stonks-lakehouse/warehouse/model_performance/' +); diff --git a/lakehouse/schemas/pnl_daily.sql b/lakehouse/schemas/pnl_daily.sql index 805ba78..eb65724
100644 --- a/lakehouse/schemas/pnl_daily.sql +++ b/lakehouse/schemas/pnl_daily.sql @@ -1,12 +1,19 @@ -- Analytical fact table: pnl_daily --- Partitioned by dt on MinIO +-- Daily profit and loss records per symbol per account. +-- Partitioned by dt on MinIO. +-- Path: s3://stonks-lakehouse/warehouse/pnl_daily/dt={yyyy-mm-dd}/part-*.parquet +-- Requirements: 9.4, 9.5, 10.1, 10.3 +-- Design ref: Section 7 (lake.pnl_daily) CREATE TABLE IF NOT EXISTS lakehouse.stonks.pnl_daily ( ticker VARCHAR, realized_pnl DOUBLE, unrealized_pnl DOUBLE, total_pnl DOUBLE, + fees DOUBLE, + net_pnl DOUBLE, broker_account VARCHAR, + execution_mode VARCHAR, dt DATE ) WITH ( format = 'PARQUET', diff --git a/lakehouse/schemas/positions_daily.sql b/lakehouse/schemas/positions_daily.sql index 17691a1..482ca69 100644 --- a/lakehouse/schemas/positions_daily.sql +++ b/lakehouse/schemas/positions_daily.sql @@ -1,13 +1,19 @@ -- Analytical fact table: positions_daily --- Partitioned by dt on MinIO +-- End-of-day position snapshots. +-- Partitioned by dt on MinIO. +-- Path: s3://stonks-lakehouse/warehouse/positions_daily/dt={yyyy-mm-dd}/part-*.parquet +-- Requirements: 9.4, 9.5, 10.1, 10.3 +-- Design ref: Section 7 (lake.positions_daily) CREATE TABLE IF NOT EXISTS lakehouse.stonks.positions_daily ( ticker VARCHAR, quantity DOUBLE, avg_entry_price DOUBLE, close_price DOUBLE, + market_value DOUBLE, unrealized_pnl DOUBLE, broker_account VARCHAR, + execution_mode VARCHAR, snapshot_at TIMESTAMP(6) WITH TIME ZONE, dt DATE ) WITH ( diff --git a/lakehouse/schemas/prediction_vs_outcome.sql b/lakehouse/schemas/prediction_vs_outcome.sql index 5a8fd45..f8ae219 100644 --- a/lakehouse/schemas/prediction_vs_outcome.sql +++ b/lakehouse/schemas/prediction_vs_outcome.sql @@ -1,19 +1,24 @@ -- Analytical fact table: prediction_vs_outcome --- Partitioned by dt on MinIO +-- Prediction accuracy tracking: predicted signals vs realized market moves. +-- Partitioned by dt and model_version on MinIO. 
+-- Path: s3://stonks-lakehouse/warehouse/prediction_vs_outcome/dt={yyyy-mm-dd}/model_version={ver}/part-*.parquet +-- Requirements: 9.4, 9.5, 10.1, 10.3 +-- Design ref: Section 7 (lake.prediction_vs_outcome) CREATE TABLE IF NOT EXISTS lakehouse.stonks.prediction_vs_outcome ( - recommendation_id VARCHAR, - ticker VARCHAR, - predicted_action VARCHAR, + recommendation_id VARCHAR, + ticker VARCHAR, + predicted_action VARCHAR, predicted_confidence DOUBLE, - actual_move_pct DOUBLE, - outcome VARCHAR, - horizon_days INTEGER, - predicted_at TIMESTAMP(6) WITH TIME ZONE, - evaluated_at TIMESTAMP(6) WITH TIME ZONE, - dt DATE + actual_move_pct DOUBLE, + outcome VARCHAR, + horizon_days INTEGER, + predicted_at TIMESTAMP(6) WITH TIME ZONE, + evaluated_at TIMESTAMP(6) WITH TIME ZONE, + dt DATE, -- partition keys must come last, in partitioned_by order + model_version VARCHAR ) WITH ( format = 'PARQUET', - partitioned_by = ARRAY['dt'], + partitioned_by = ARRAY['dt', 'model_version'], external_location = 's3a://stonks-lakehouse/warehouse/prediction_vs_outcome/' ); diff --git a/lakehouse/schemas/trade_fills.sql b/lakehouse/schemas/trade_fills.sql index 5577972..765121b 100644 --- a/lakehouse/schemas/trade_fills.sql +++ b/lakehouse/schemas/trade_fills.sql @@ -1,5 +1,9 @@ -- Analytical fact table: trade_fills --- Partitioned by dt on MinIO +-- Fill and execution records from broker. +-- Partitioned by dt on MinIO.
+-- Path: s3://stonks-lakehouse/warehouse/trade_fills/dt={yyyy-mm-dd}/part-*.parquet +-- Requirements: 9.4, 9.5, 10.1, 10.3 +-- Design ref: Section 7 (lake.trade_fills) CREATE TABLE IF NOT EXISTS lakehouse.stonks.trade_fills ( fill_id VARCHAR, @@ -8,6 +12,7 @@ CREATE TABLE IF NOT EXISTS lakehouse.stonks.trade_fills ( side VARCHAR, fill_price DOUBLE, fill_quantity DOUBLE, + commission DOUBLE, broker_account VARCHAR, filled_at TIMESTAMP(6) WITH TIME ZONE, dt DATE diff --git a/lakehouse/schemas/trade_orders.sql b/lakehouse/schemas/trade_orders.sql index 002bece..836fc25 100644 --- a/lakehouse/schemas/trade_orders.sql +++ b/lakehouse/schemas/trade_orders.sql @@ -1,14 +1,20 @@ -- Analytical fact table: trade_orders --- Partitioned by dt on MinIO +-- Order submission records for paper and live trading. +-- Partitioned by dt on MinIO. +-- Path: s3://stonks-lakehouse/warehouse/trade_orders/dt={yyyy-mm-dd}/part-*.parquet +-- Requirements: 8.3, 9.4, 9.5, 10.1, 10.3 +-- Design ref: Section 7 (lake.trade_orders) CREATE TABLE IF NOT EXISTS lakehouse.stonks.trade_orders ( order_id VARCHAR, + recommendation_id VARCHAR, ticker VARCHAR, side VARCHAR, order_type VARCHAR, quantity DOUBLE, limit_price DOUBLE, status VARCHAR, + execution_mode VARCHAR, broker_account VARCHAR, submitted_at TIMESTAMP(6) WITH TIME ZONE, dt DATE diff --git a/lakehouse/schemas/trade_signals.sql b/lakehouse/schemas/trade_signals.sql index a76eab0..d8600bd 100644 --- a/lakehouse/schemas/trade_signals.sql +++ b/lakehouse/schemas/trade_signals.sql @@ -1,16 +1,24 @@ -- Analytical fact table: trade_signals --- Partitioned by dt on MinIO +-- Aggregated trend signals and recommendation actions. +-- Partitioned by dt on MinIO. 
+-- Path: s3://stonks-lakehouse/warehouse/trade_signals/dt={yyyy-mm-dd}/part-*.parquet +-- Requirements: 6.1, 6.2, 6.4, 6.5, 7.1, 9.4, 9.5, 10.1 +-- Design ref: Section 6.4, Section 6.5, Section 7 (lake.trade_signals) CREATE TABLE IF NOT EXISTS lakehouse.stonks.trade_signals ( - signal_id VARCHAR, - ticker VARCHAR, - trend_direction VARCHAR, - trend_strength DOUBLE, - confidence DOUBLE, - action VARCHAR, - time_horizon VARCHAR, - generated_at TIMESTAMP(6) WITH TIME ZONE, - dt DATE + signal_id VARCHAR, + ticker VARCHAR, + trend_direction VARCHAR, + trend_strength DOUBLE, + confidence DOUBLE, + contradiction_score DOUBLE, + dominant_catalysts VARCHAR, + material_risks VARCHAR, + action VARCHAR, + time_horizon VARCHAR, + recommendation_id VARCHAR, + generated_at TIMESTAMP(6) WITH TIME ZONE, + dt DATE ) WITH ( format = 'PARQUET', partitioned_by = ARRAY['dt'], diff --git a/lakehouse/views/README.md b/lakehouse/views/README.md new file mode 100644 index 0000000..f23c0e9 --- /dev/null +++ b/lakehouse/views/README.md @@ -0,0 +1,23 @@ +# Lakehouse Views + +Example SQL views for Trino over MinIO-backed analytical fact tables. + +These views are designed to be created in the `lakehouse.stonks` schema and +can be used directly in Superset dashboards or ad hoc Trino queries. + +## Views + +- `prediction_accuracy` — Joins predicted signals with realized market moves to score prediction quality +- `paper_trade_scorecard` — Aggregates paper trading performance by symbol with win rates and PnL +- `paper_trade_detail` — Per-order paper trade detail with fill prices and realized outcomes +- `signal_hit_rate` — Daily signal accuracy summary across all symbols + +## Usage + +Connect to Trino and run each `.sql` file to create the view: + +```bash +trino --catalog lakehouse --schema stonks < lakehouse/views/prediction_accuracy.sql +``` + +Or paste into the Superset SQL Lab to explore interactively. 
diff --git a/lakehouse/views/paper_trade_detail.sql b/lakehouse/views/paper_trade_detail.sql new file mode 100644 index 0000000..988fed0 --- /dev/null +++ b/lakehouse/views/paper_trade_detail.sql @@ -0,0 +1,47 @@ +-- View: paper_trade_detail +-- Per-order paper trade detail joining orders, fills, and the originating +-- recommendation's prediction outcome. Useful for drill-down from the scorecard. +-- Requirements: 10.1, 10.3, 10.4 +-- Design ref: Section 9.2 (evidence-to-outcome drill-down) + +CREATE OR REPLACE VIEW lakehouse.stonks.paper_trade_detail AS +SELECT + o.order_id, + o.recommendation_id, + o.ticker, + o.side, + o.order_type, + o.quantity, + o.limit_price, + o.status AS order_status, + o.submitted_at, + f.fill_id, + f.fill_price, + f.fill_quantity, + f.commission, + f.filled_at, + -- Slippage: difference between limit and fill price (buys positive = worse) + CASE + WHEN o.limit_price IS NOT NULL AND o.limit_price > 0 THEN + (f.fill_price - o.limit_price) / o.limit_price * 100 + ELSE NULL + END AS slippage_pct, + -- Link back to prediction outcome + pvo.predicted_action, + pvo.predicted_confidence, + pvo.actual_move_pct, + pvo.outcome AS prediction_outcome, + o.broker_account, + o.dt +FROM + lakehouse.stonks.trade_orders o +LEFT JOIN + lakehouse.stonks.trade_fills f + ON o.order_id = f.order_id + AND o.dt = f.dt +LEFT JOIN + lakehouse.stonks.prediction_vs_outcome pvo + ON o.recommendation_id = pvo.recommendation_id + AND o.dt = pvo.dt +WHERE + o.execution_mode = 'paper'; diff --git a/lakehouse/views/paper_trade_scorecard.sql b/lakehouse/views/paper_trade_scorecard.sql new file mode 100644 index 0000000..4f91090 --- /dev/null +++ b/lakehouse/views/paper_trade_scorecard.sql @@ -0,0 +1,42 @@ +-- View: paper_trade_scorecard +-- Aggregates paper trading performance per symbol with win rates, PnL, and +-- average fill quality. Filters to paper execution mode only. 
+-- Requirements: 10.1, 10.2, 10.3 +-- Design ref: Section 9.2 (paper trading PnL scorecard) + +CREATE OR REPLACE VIEW lakehouse.stonks.paper_trade_scorecard AS +SELECT + pnl.ticker, + pnl.broker_account, + COUNT(DISTINCT pnl.dt) AS trading_days, + SUM(pnl.realized_pnl) AS total_realized_pnl, + SUM(pnl.unrealized_pnl) AS total_unrealized_pnl, + SUM(pnl.net_pnl) AS total_net_pnl, + SUM(pnl.fees) AS total_fees, + AVG(pnl.net_pnl) AS avg_daily_pnl, + -- Win rate: fraction of days with positive net PnL + CAST( + COUNT(CASE WHEN pnl.net_pnl > 0 THEN 1 END) AS DOUBLE + ) / NULLIF(COUNT(*), 0) AS win_rate, + -- Worst and best single-day PnL + MIN(pnl.net_pnl) AS worst_day_pnl, + MAX(pnl.net_pnl) AS best_day_pnl, + -- Order counts are pre-aggregated per day so the join cannot duplicate PnL rows + COALESCE(SUM(o.order_count), 0) AS total_orders, + COALESCE(SUM(o.filled_count), 0) + AS filled_orders, + MIN(pnl.dt) AS first_trade_date, + MAX(pnl.dt) AS last_trade_date +FROM + lakehouse.stonks.pnl_daily pnl +LEFT JOIN + (SELECT ticker, broker_account, dt, COUNT(DISTINCT order_id) AS order_count, COUNT(DISTINCT CASE WHEN status = 'filled' THEN order_id END) AS filled_count FROM lakehouse.stonks.trade_orders + WHERE execution_mode = 'paper' GROUP BY ticker, broker_account, dt) o + ON pnl.ticker = o.ticker + AND pnl.broker_account = o.broker_account + AND pnl.dt = o.dt +WHERE + pnl.execution_mode = 'paper' +GROUP BY + pnl.ticker, + pnl.broker_account; diff --git a/lakehouse/views/prediction_accuracy.sql b/lakehouse/views/prediction_accuracy.sql new file mode 100644 index 0000000..2dc63cc --- /dev/null +++ b/lakehouse/views/prediction_accuracy.sql @@ -0,0 +1,44 @@ +-- View: prediction_accuracy +-- Joins prediction_vs_outcome with trade_signals to provide +-- a comprehensive prediction accuracy scorecard.
+-- Requirements: 10.1, 10.2, 10.3, 10.4 +-- Design ref: Section 9.2 (prediction confidence vs realized move) + +CREATE OR REPLACE VIEW lakehouse.stonks.prediction_accuracy AS +SELECT + pvo.recommendation_id, + pvo.ticker, + pvo.predicted_action, + pvo.predicted_confidence, + pvo.actual_move_pct, + pvo.outcome, + pvo.horizon_days, + pvo.predicted_at, + pvo.evaluated_at, + pvo.model_version, + ts.trend_direction, + ts.trend_strength, + ts.contradiction_score, + ts.dominant_catalysts, + -- Confidence bucket for dashboard grouping + CASE + WHEN pvo.predicted_confidence >= 0.8 THEN 'high' + WHEN pvo.predicted_confidence >= 0.5 THEN 'medium' + ELSE 'low' + END AS confidence_bucket, + -- Direction correctness: did the predicted action match the actual move? + CASE + WHEN pvo.predicted_action = 'buy' AND pvo.actual_move_pct > 0 THEN true + WHEN pvo.predicted_action = 'sell' AND pvo.actual_move_pct < 0 THEN true + WHEN pvo.predicted_action IN ('hold', 'watch') THEN NULL + ELSE false + END AS direction_correct, + -- Magnitude of the realized move (absolute value of actual_move_pct) + ABS(pvo.actual_move_pct) AS abs_move_pct, + pvo.dt +FROM + lakehouse.stonks.prediction_vs_outcome pvo +LEFT JOIN + lakehouse.stonks.trade_signals ts + ON pvo.recommendation_id = ts.recommendation_id + AND pvo.dt = ts.dt; diff --git a/lakehouse/views/signal_hit_rate.sql b/lakehouse/views/signal_hit_rate.sql new file mode 100644 index 0000000..15608f6 --- /dev/null +++ b/lakehouse/views/signal_hit_rate.sql @@ -0,0 +1,31 @@ +-- View: signal_hit_rate +-- Daily summary of signal accuracy across all symbols and model versions. +-- Designed for the Superset prediction accuracy dashboard.
+-- Requirements: 10.1, 10.2, 10.3 +-- Design ref: Section 9.2 (prediction confidence vs realized move) + +CREATE OR REPLACE VIEW lakehouse.stonks.signal_hit_rate AS +SELECT + pvo.dt, + pvo.model_version, + COUNT(*) AS total_predictions, + COUNT(CASE WHEN pvo.outcome = 'correct' THEN 1 END) AS correct_predictions, + COUNT(CASE WHEN pvo.outcome = 'incorrect' THEN 1 END) AS incorrect_predictions, + COUNT(CASE WHEN pvo.outcome = 'neutral' THEN 1 END) AS neutral_predictions, + -- Hit rate + CAST( + COUNT(CASE WHEN pvo.outcome = 'correct' THEN 1 END) AS DOUBLE + ) / NULLIF(COUNT(*), 0) AS hit_rate, + -- Average confidence of correct vs incorrect + AVG(CASE WHEN pvo.outcome = 'correct' THEN pvo.predicted_confidence END) + AS avg_confidence_correct, + AVG(CASE WHEN pvo.outcome = 'incorrect' THEN pvo.predicted_confidence END) + AS avg_confidence_incorrect, + -- Average realized move magnitude + AVG(ABS(pvo.actual_move_pct)) AS avg_abs_move_pct, + AVG(pvo.actual_move_pct) AS avg_move_pct +FROM + lakehouse.stonks.prediction_vs_outcome pvo +GROUP BY + pvo.dt, + pvo.model_version; diff --git a/requirements.txt b/requirements.txt index 0f988d5..ca7c876 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,6 +24,12 @@ pandas>=2.2.0 # Trino trino>=0.330.0 +# Observability +prometheus_client>=0.21.0 + +# YAML parsing (used by K8s security tests) +pyyaml>=6.0.0 + # Testing pytest>=8.0.0 pytest-asyncio>=0.24.0 diff --git a/services/adapters/__init__.py b/services/adapters/__init__.py index 909a10a..b4022bf 100644 --- a/services/adapters/__init__.py +++ b/services/adapters/__init__.py @@ -1 +1,45 @@ # Ingestion Adapters +from .base import AdapterResult, BaseAdapter +from .resilient import ResilientAdapter, RetryConfig, RetryStats, compute_delay +from .broker_adapter import ( + AccountInfo, + AlpacaBrokerAdapter, + BrokerDataAdapter, + OrderEventType, + OrderRequest, + OrderResponse, + OrderSide, + OrderStatus, + OrderType, + PositionInfo, + TradingMode, +) +from 
.filings_adapter import FilingsDataAdapter, SECEdgarAdapter +from .market_adapter import MarketDataAdapter, PolygonMarketAdapter +from .news_adapter import NewsDataAdapter, PolygonNewsAdapter + +__all__ = [ + "AccountInfo", + "AdapterResult", + "AlpacaBrokerAdapter", + "BaseAdapter", + "BrokerDataAdapter", + "FilingsDataAdapter", + "MarketDataAdapter", + "NewsDataAdapter", + "OrderEventType", + "OrderRequest", + "OrderResponse", + "OrderSide", + "OrderStatus", + "OrderType", + "PolygonMarketAdapter", + "PolygonNewsAdapter", + "PositionInfo", + "ResilientAdapter", + "RetryConfig", + "RetryStats", + "SECEdgarAdapter", + "TradingMode", + "compute_delay", +] diff --git a/services/adapters/base.py b/services/adapters/base.py index 72e604f..6ca6768 100644 --- a/services/adapters/base.py +++ b/services/adapters/base.py @@ -1,29 +1,84 @@ -"""Base adapter interface for all external API integrations.""" +"""Base adapter interface for all external API integrations. + +All ingestion adapters follow the same contract: +1. Fetch external payloads for a given ticker/source config. +2. Return a structured result with raw bytes, parsed items, and metadata. +3. The ingestion worker handles MinIO upload, PostgreSQL metadata, and downstream job emission. 
+ +Requirements: 2.1, 2.2, 2.3, 2.4, 2.5, 3.1, 3.2, 3.3, 3.4 +""" from abc import ABC, abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any @dataclass class AdapterResult: + """Result of a single adapter fetch operation.""" + source_type: str ticker: str - items: List[Dict[str, Any]] + items: list[dict[str, Any]] raw_payload: bytes content_hash: str fetched_at: datetime - error: Optional[str] = None + error: str | None = None + # HTTP metadata for observability + http_status: int | None = None + response_time_ms: float | None = None + # Additional metadata the adapter wants to pass downstream + metadata: dict[str, Any] = field(default_factory=dict) + + @property + def ok(self) -> bool: + """True if the fetch succeeded without error.""" + return self.error is None and len(self.items) > 0 + + @property + def item_count(self) -> int: + return len(self.items) class BaseAdapter(ABC): - """Interface for all ingestion adapters.""" + """Interface for all ingestion adapters. + + Subclasses implement fetch() for their specific API and source_type() + to identify the adapter class. The ingestion worker orchestrates + persistence and downstream job emission. + """ @abstractmethod - async def fetch(self, ticker: str, config: Dict[str, Any]) -> AdapterResult: - """Fetch data for a given ticker using source config.""" + async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult: + """Fetch data for a given ticker using source config. + + Args: + ticker: The company ticker symbol. + config: Source-specific configuration from the sources table. + + Returns: + AdapterResult with raw payload, parsed items, and metadata. + """ ... @abstractmethod def source_type(self) -> str: + """Return the source type identifier for this adapter (e.g. 'market_api').""" ... 
+ + def bucket_name(self) -> str: + """Return the MinIO bucket name for raw artifact storage. + + Override in subclasses if the bucket differs from the default pattern. + """ + return f"stonks-raw-{self.source_type().replace('_api', '').replace('_', '-')}" + + def artifact_path(self, ticker: str, document_id: str, now: datetime) -> str: + """Build the MinIO object path for a raw artifact. + + Pattern: /{source_type}/{ticker}/{yyyy}/{mm}/{dd}/{document_id}/raw.json + """ + return ( + f"{self.source_type()}/{ticker}/" + f"{now.strftime('%Y/%m/%d')}/{document_id}/raw.json" + ) diff --git a/services/adapters/broker_adapter.py b/services/adapters/broker_adapter.py index 8cc0357..fd8ed47 100644 --- a/services/adapters/broker_adapter.py +++ b/services/adapters/broker_adapter.py @@ -1,9 +1,19 @@ -"""Broker API adapter - paper/live trading, orders, positions, balances.""" +"""Broker API adapter interface for paper trading and order events. + +The BrokerDataAdapter is the abstract interface for all broker integrations. +AlpacaBrokerAdapter is the first concrete implementation, targeting the +Alpaca Markets REST API for paper and live trading. 
+ +Requirements: 2.4, 2.5, 8.1, 8.3, 8.5 +""" import hashlib import logging +import time import uuid -from datetime import datetime -from typing import Any, Dict, Optional +from abc import ABC, abstractmethod +from datetime import datetime, timezone +from enum import Enum +from typing import Any import httpx @@ -12,97 +22,584 @@ from .base import AdapterResult, BaseAdapter logger = logging.getLogger("broker_adapter") -class BrokerAdapter(BaseAdapter): - """Broker API adapter supporting paper and live modes.""" +# --- Broker-specific enums --- - def __init__(self, api_key: str = "", api_secret: str = "", base_url: str = "", mode: str = "paper"): - self.api_key = api_key - self.api_secret = api_secret - self.base_url = base_url - self.mode = mode # paper | live + +class OrderSide(str, Enum): + BUY = "buy" + SELL = "sell" + + +class OrderType(str, Enum): + MARKET = "market" + LIMIT = "limit" + STOP = "stop" + STOP_LIMIT = "stop_limit" + + +class OrderStatus(str, Enum): + PENDING = "pending" + SUBMITTED = "submitted" + ACCEPTED = "accepted" + PARTIALLY_FILLED = "partially_filled" + FILLED = "filled" + CANCELLED = "cancelled" + REJECTED = "rejected" + EXPIRED = "expired" + + +class TradingMode(str, Enum): + PAPER = "paper" + LIVE = "live" + + +class OrderEventType(str, Enum): + SUBMITTED = "submitted" + ACCEPTED = "accepted" + REJECTED = "rejected" + FILL = "fill" + PARTIAL_FILL = "partial_fill" + CANCELLED = "cancelled" + EXPIRED = "expired" + + +# --- Data structures --- + + +class OrderRequest: + """Represents an order to be submitted to a broker.""" + + def __init__( + self, + ticker: str, + side: OrderSide, + quantity: float, + order_type: OrderType = OrderType.MARKET, + limit_price: float | None = None, + stop_price: float | None = None, + time_in_force: str = "day", + idempotency_key: str | None = None, + ) -> None: + self.ticker = ticker + self.side = side + self.quantity = quantity + self.order_type = order_type + self.limit_price = limit_price + 
self.stop_price = stop_price + self.time_in_force = time_in_force + self.idempotency_key = idempotency_key or str(uuid.uuid4()) + + def to_dict(self) -> dict[str, Any]: + """Serialize to a dict for audit/persistence.""" + d: dict[str, Any] = { + "ticker": self.ticker, + "side": self.side.value, + "quantity": self.quantity, + "order_type": self.order_type.value, + "time_in_force": self.time_in_force, + "idempotency_key": self.idempotency_key, + } + if self.limit_price is not None: + d["limit_price"] = self.limit_price + if self.stop_price is not None: + d["stop_price"] = self.stop_price + return d + + +class OrderResponse: + """Represents a broker's response to an order submission.""" + + def __init__( + self, + broker_order_id: str, + status: OrderStatus, + ticker: str, + side: OrderSide, + quantity: float, + filled_quantity: float = 0.0, + filled_avg_price: float | None = None, + submitted_at: datetime | None = None, + raw_response: dict[str, Any] | None = None, + error: str | None = None, + ) -> None: + self.broker_order_id = broker_order_id + self.status = status + self.ticker = ticker + self.side = side + self.quantity = quantity + self.filled_quantity = filled_quantity + self.filled_avg_price = filled_avg_price + self.submitted_at = submitted_at or datetime.now(timezone.utc) + self.raw_response = raw_response or {} + self.error = error + + @property + def ok(self) -> bool: + return self.error is None and self.status not in ( + OrderStatus.REJECTED, + OrderStatus.CANCELLED, + OrderStatus.EXPIRED, + ) + + def to_dict(self) -> dict[str, Any]: + return { + "broker_order_id": self.broker_order_id, + "status": self.status.value, + "ticker": self.ticker, + "side": self.side.value, + "quantity": self.quantity, + "filled_quantity": self.filled_quantity, + "filled_avg_price": self.filled_avg_price, + "submitted_at": self.submitted_at.isoformat(), + "error": self.error, + } + + +class PositionInfo: + """Represents a current position from the broker.""" + + def __init__( 
+ self, + ticker: str, + quantity: float, + avg_entry_price: float, + current_price: float, + unrealized_pnl: float, + market_value: float, + side: str = "long", + ) -> None: + self.ticker = ticker + self.quantity = quantity + self.avg_entry_price = avg_entry_price + self.current_price = current_price + self.unrealized_pnl = unrealized_pnl + self.market_value = market_value + self.side = side + + def to_dict(self) -> dict[str, Any]: + return { + "ticker": self.ticker, + "quantity": self.quantity, + "avg_entry_price": self.avg_entry_price, + "current_price": self.current_price, + "unrealized_pnl": self.unrealized_pnl, + "market_value": self.market_value, + "side": self.side, + } + + +class AccountInfo: + """Represents broker account summary.""" + + def __init__( + self, + account_id: str, + buying_power: float, + cash: float, + portfolio_value: float, + currency: str = "USD", + mode: TradingMode = TradingMode.PAPER, + ) -> None: + self.account_id = account_id + self.buying_power = buying_power + self.cash = cash + self.portfolio_value = portfolio_value + self.currency = currency + self.mode = mode + + def to_dict(self) -> dict[str, Any]: + return { + "account_id": self.account_id, + "buying_power": self.buying_power, + "cash": self.cash, + "portfolio_value": self.portfolio_value, + "currency": self.currency, + "mode": self.mode.value, + } + + +# --- Abstract interface --- + + +class BrokerDataAdapter(BaseAdapter, ABC): + """Abstract interface for broker API integrations. 
+ + Extends BaseAdapter with broker-specific operations: + - submit_order: place an order with idempotency key + - cancel_order: cancel an existing order + - get_order_status: check order state + - get_positions: list current positions + - get_account: retrieve account summary + + All concrete adapters must enforce: + - Idempotent order submission via idempotency_key (Req 8.5) + - Paper/live mode separation (Req 8.1) + - Fail-closed on broker unavailability (Req 8.5) + """ + + def __init__(self, mode: TradingMode = TradingMode.PAPER) -> None: + self._mode = mode + + @property + def mode(self) -> TradingMode: + return self._mode def source_type(self) -> str: return "broker" - def _headers(self) -> Dict[str, str]: + @abstractmethod + async def submit_order(self, order: OrderRequest) -> OrderResponse: + """Submit an order to the broker. + + Must use order.idempotency_key to prevent duplicate submissions. + Must fail closed if the broker is unavailable or returns ambiguous state. + """ + ... + + @abstractmethod + async def cancel_order(self, broker_order_id: str) -> OrderResponse: + """Cancel an existing order by broker order ID.""" + ... + + @abstractmethod + async def get_order_status(self, broker_order_id: str) -> OrderResponse: + """Get the current status of an order.""" + ... + + @abstractmethod + async def get_positions(self) -> list[PositionInfo]: + """Get all current positions.""" + ... + + @abstractmethod + async def get_account(self) -> AccountInfo: + """Get account summary (balance, buying power, etc.).""" + ... + + +# --- Concrete Alpaca implementation --- + + +class AlpacaBrokerAdapter(BrokerDataAdapter): + """Concrete broker adapter for the Alpaca Markets REST API. 
+ + Supports: + - Paper trading via paper-api.alpaca.markets + - Live trading via api.alpaca.markets + - Order submission, cancellation, and status + - Position and account queries + + Config options for fetch(): + endpoint: One of "positions", "orders", "account" (default "positions") + """ + + PAPER_BASE_URL: str = "https://paper-api.alpaca.markets" + LIVE_BASE_URL: str = "https://api.alpaca.markets" + + def __init__( + self, + api_key: str, + api_secret: str, + mode: TradingMode = TradingMode.PAPER, + base_url: str | None = None, + ) -> None: + super().__init__(mode=mode) + self.api_key = api_key + self.api_secret = api_secret + if base_url: + self.base_url = base_url.rstrip("/") + elif mode == TradingMode.LIVE: + self.base_url = self.LIVE_BASE_URL + else: + self.base_url = self.PAPER_BASE_URL + + def _headers(self) -> dict[str, str]: return { - "Authorization": f"Bearer {self.api_key}", + "APCA-API-KEY-ID": self.api_key, + "APCA-API-SECRET-KEY": self.api_secret, "Content-Type": "application/json", } - async def fetch(self, ticker: str, config: Dict[str, Any]) -> AdapterResult: - """Fetch positions and recent orders for a ticker.""" + async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult: + """Fetch positions or recent orders for a ticker from Alpaca. + + This satisfies the BaseAdapter contract for the ingestion pipeline. + The broker adapter uses fetch() to pull position/order snapshots + that get persisted as raw artifacts. 
+ """ + endpoint = config.get("endpoint", "positions") + url = self._build_fetch_url(ticker, endpoint) + async with httpx.AsyncClient(timeout=30) as client: + t0 = time.monotonic() try: - resp = await client.get( - f"{self.base_url}/v2/positions/{ticker}", - headers=self._headers(), - ) + resp = await client.get(url, headers=self._headers()) + elapsed_ms = (time.monotonic() - t0) * 1000 + resp.raise_for_status() + raw = resp.content - data = resp.json() if resp.status_code == 200 else {} + data = resp.json() content_hash = hashlib.sha256(raw).hexdigest() + items = [data] if isinstance(data, dict) else data if isinstance(data, list) else [] return AdapterResult( source_type="broker", ticker=ticker, - items=[data] if data else [], + items=items, raw_payload=raw, content_hash=content_hash, - fetched_at=datetime.utcnow(), + fetched_at=datetime.now(timezone.utc), + http_status=resp.status_code, + response_time_ms=round(elapsed_ms, 1), + metadata={ + "provider": "alpaca", + "mode": self._mode.value, + "endpoint": endpoint, + }, + ) + except httpx.HTTPStatusError as e: + elapsed_ms = (time.monotonic() - t0) * 1000 + logger.error("Alpaca HTTP error for %s: %s", ticker, e) + return self._error_result( + ticker, str(e), elapsed_ms, + http_status=e.response.status_code if e.response else None, + raw=e.response.content if e.response else b"", ) except Exception as e: - logger.error(f"Broker fetch failed for {ticker}: {e}") - return AdapterResult( - source_type="broker", - ticker=ticker, - items=[], - raw_payload=b"", - content_hash="", - fetched_at=datetime.utcnow(), - error=str(e), - ) + elapsed_ms = (time.monotonic() - t0) * 1000 + logger.error("Alpaca fetch failed for %s: %s", ticker, e) + return self._error_result(ticker, str(e), elapsed_ms) - async def submit_order( - self, - ticker: str, - side: str, - qty: float, - order_type: str = "market", - limit_price: Optional[float] = None, - idempotency_key: Optional[str] = None, - ) -> Dict[str, Any]: - """Submit an order to 
the broker. Returns broker response.""" - if self.mode == "live": - logger.warning("LIVE order submission") + def _build_fetch_url(self, ticker: str, endpoint: str) -> str: + """Build the URL for a fetch operation.""" + if endpoint == "orders": + return f"{self.base_url}/v2/orders?symbols={ticker}&status=all&limit=50" + if endpoint == "account": + return f"{self.base_url}/v2/account" + # Default: positions for ticker + return f"{self.base_url}/v2/positions/{ticker}" - idem_key = idempotency_key or str(uuid.uuid4()) - payload = { - "symbol": ticker, - "qty": str(qty), - "side": side, - "type": order_type, - "time_in_force": "day", + async def submit_order(self, order: OrderRequest) -> OrderResponse: + """Submit an order to Alpaca with idempotency key. + + Fails closed: any network error or ambiguous response returns + a rejected OrderResponse rather than risking duplicate orders. + """ + if self._mode == TradingMode.LIVE: + logger.warning("LIVE order submission: %s %s %s", order.side.value, order.quantity, order.ticker) + + payload: dict[str, Any] = { + "symbol": order.ticker, + "qty": str(order.quantity), + "side": order.side.value, + "type": order.order_type.value, + "time_in_force": order.time_in_force, } - if limit_price and order_type == "limit": - payload["limit_price"] = str(limit_price) + if order.limit_price is not None and order.order_type in (OrderType.LIMIT, OrderType.STOP_LIMIT): + payload["limit_price"] = str(order.limit_price) + if order.stop_price is not None and order.order_type in (OrderType.STOP, OrderType.STOP_LIMIT): + payload["stop_price"] = str(order.stop_price) + + headers = {**self._headers(), "Idempotency-Key": order.idempotency_key} async with httpx.AsyncClient(timeout=30) as client: try: resp = await client.post( f"{self.base_url}/v2/orders", - headers={**self._headers(), "Idempotency-Key": idem_key}, + headers=headers, json=payload, ) resp.raise_for_status() - return resp.json() + data = resp.json() + return 
self._parse_order_response(data) except httpx.HTTPStatusError as e: - logger.error(f"Order rejected: {e.response.text}") - return {"error": e.response.text, "status": e.response.status_code} + error_body = e.response.text if e.response else "unknown" + logger.error("Order rejected by Alpaca: %s", error_body) + return OrderResponse( + broker_order_id="", + status=OrderStatus.REJECTED, + ticker=order.ticker, + side=order.side, + quantity=order.quantity, + error=f"HTTP {e.response.status_code}: {error_body}" if e.response else str(e), + raw_response={"error": error_body}, + ) except Exception as e: - logger.error(f"Order submission failed: {e}") - return {"error": str(e)} + # Fail closed: treat any unexpected error as rejection + logger.error("Order submission failed (fail-closed): %s", e) + return OrderResponse( + broker_order_id="", + status=OrderStatus.REJECTED, + ticker=order.ticker, + side=order.side, + quantity=order.quantity, + error=f"fail-closed: {e}", + ) - async def get_account(self) -> Dict[str, Any]: + async def cancel_order(self, broker_order_id: str) -> OrderResponse: + """Cancel an order on Alpaca.""" async with httpx.AsyncClient(timeout=30) as client: - resp = await client.get(f"{self.base_url}/v2/account", headers=self._headers()) - return resp.json() + try: + resp = await client.delete( + f"{self.base_url}/v2/orders/{broker_order_id}", + headers=self._headers(), + ) + if resp.status_code == 204: + return OrderResponse( + broker_order_id=broker_order_id, + status=OrderStatus.CANCELLED, + ticker="", + side=OrderSide.BUY, + quantity=0, + ) + resp.raise_for_status() + data = resp.json() + return self._parse_order_response(data) + except Exception as e: + logger.error("Cancel failed for %s: %s", broker_order_id, e) + return OrderResponse( + broker_order_id=broker_order_id, + status=OrderStatus.REJECTED, + ticker="", + side=OrderSide.BUY, + quantity=0, + error=str(e), + ) + + async def get_order_status(self, broker_order_id: str) -> OrderResponse: + 
"""Get order status from Alpaca.""" + async with httpx.AsyncClient(timeout=30) as client: + try: + resp = await client.get( + f"{self.base_url}/v2/orders/{broker_order_id}", + headers=self._headers(), + ) + resp.raise_for_status() + data = resp.json() + return self._parse_order_response(data) + except Exception as e: + logger.error("Get order status failed for %s: %s", broker_order_id, e) + return OrderResponse( + broker_order_id=broker_order_id, + status=OrderStatus.REJECTED, + ticker="", + side=OrderSide.BUY, + quantity=0, + error=str(e), + ) + + async def get_positions(self) -> list[PositionInfo]: + """Get all current positions from Alpaca.""" + async with httpx.AsyncClient(timeout=30) as client: + try: + resp = await client.get( + f"{self.base_url}/v2/positions", + headers=self._headers(), + ) + resp.raise_for_status() + data = resp.json() + if not isinstance(data, list): + return [] + return [self._parse_position(p) for p in data if isinstance(p, dict)] + except Exception as e: + logger.error("Get positions failed: %s", e) + return [] + + async def get_account(self) -> AccountInfo: + """Get account summary from Alpaca.""" + async with httpx.AsyncClient(timeout=30) as client: + try: + resp = await client.get( + f"{self.base_url}/v2/account", + headers=self._headers(), + ) + resp.raise_for_status() + data = resp.json() + return AccountInfo( + account_id=str(data.get("id", "")), + buying_power=float(data.get("buying_power", 0)), + cash=float(data.get("cash", 0)), + portfolio_value=float(data.get("portfolio_value", 0)), + currency=str(data.get("currency", "USD")), + mode=self._mode, + ) + except Exception as e: + logger.error("Get account failed: %s", e) + return AccountInfo( + account_id="", + buying_power=0, + cash=0, + portfolio_value=0, + mode=self._mode, + ) + + def _parse_order_response(self, data: dict[str, Any]) -> OrderResponse: + """Parse an Alpaca order response into an OrderResponse.""" + status_map: dict[str, OrderStatus] = { + "new": 
OrderStatus.SUBMITTED, + "accepted": OrderStatus.ACCEPTED, + "partially_filled": OrderStatus.PARTIALLY_FILLED, + "filled": OrderStatus.FILLED, + "done_for_day": OrderStatus.FILLED, + "canceled": OrderStatus.CANCELLED, + "expired": OrderStatus.EXPIRED, + "replaced": OrderStatus.SUBMITTED, + "pending_new": OrderStatus.PENDING, + "pending_cancel": OrderStatus.PENDING, + "pending_replace": OrderStatus.PENDING, + "rejected": OrderStatus.REJECTED, + } + raw_status = str(data.get("status", "pending")) + status = status_map.get(raw_status, OrderStatus.PENDING) + + side_str = str(data.get("side", "buy")) + side = OrderSide.SELL if side_str == "sell" else OrderSide.BUY + + filled_qty = float(data.get("filled_qty", 0) or 0) + filled_avg = data.get("filled_avg_price") + filled_avg_price = float(filled_avg) if filled_avg else None + + return OrderResponse( + broker_order_id=str(data.get("id", "")), + status=status, + ticker=str(data.get("symbol", "")), + side=side, + quantity=float(data.get("qty", 0) or 0), + filled_quantity=filled_qty, + filled_avg_price=filled_avg_price, + raw_response=data, + ) + + def _parse_position(self, data: dict[str, Any]) -> PositionInfo: + """Parse an Alpaca position response into a PositionInfo.""" + return PositionInfo( + ticker=str(data.get("symbol", "")), + quantity=float(data.get("qty", 0) or 0), + avg_entry_price=float(data.get("avg_entry_price", 0) or 0), + current_price=float(data.get("current_price", 0) or 0), + unrealized_pnl=float(data.get("unrealized_pl", 0) or 0), + market_value=float(data.get("market_value", 0) or 0), + side=str(data.get("side", "long")), + ) + + def _error_result( + self, + ticker: str, + error: str, + elapsed_ms: float, + http_status: int | None = None, + raw: bytes = b"", + ) -> AdapterResult: + """Build an error AdapterResult for broker fetches.""" + return AdapterResult( + source_type="broker", + ticker=ticker, + items=[], + raw_payload=raw, + content_hash="", + fetched_at=datetime.now(timezone.utc), + error=error, 
+ http_status=http_status, + response_time_ms=round(elapsed_ms, 1), + metadata={"provider": "alpaca", "mode": self._mode.value}, + ) diff --git a/services/adapters/broker_service.py b/services/adapters/broker_service.py new file mode 100644 index 0000000..6695377 --- /dev/null +++ b/services/adapters/broker_service.py @@ -0,0 +1,832 @@ +"""Broker adapter service - standalone worker for sandbox order execution. + +Runs the Alpaca broker adapter in sandbox (paper) mode, processing order +requests from the broker queue, evaluating them through the risk engine, +submitting to Alpaca's paper trading API, and persisting the full audit trail. + +Also periodically syncs positions and account state from Alpaca. + +Implements idempotent order submission keys and duplicate prevention: +- Deterministic idempotency key generation from job attributes +- Redis-based fast-path duplicate detection before broker submission +- PostgreSQL UNIQUE constraint on idempotency_key as durable fallback + +Requirements: 2.4, 8.1, 8.3, 8.5 +Design: Section 4.9 - Broker Adapter +""" +from __future__ import annotations + +import asyncio +import hashlib +import json +import logging +import uuid +from datetime import datetime, timezone +from typing import Any + +import asyncpg +import redis.asyncio as aioredis + +from services.adapters.broker_adapter import ( + AlpacaBrokerAdapter, + OrderRequest, + OrderResponse, + OrderSide, + OrderStatus, + OrderType, + TradingMode, +) +from services.risk.engine import ( + AccountRiskState, + PortfolioRiskConfig, + ProposedOrder, + evaluate_order, +) +from services.risk.approval import ( + ApprovalRequest, + ApprovalStatus, + compute_expiry, + create_approval_request, + requires_approval, +) +from services.shared.audit import ( + audit_approval_requested, + audit_duplicate_prevented, + audit_order_filled, + audit_order_rejected, + audit_order_submitted, + audit_risk_evaluated, +) +from services.lake_publisher.worker import ( + publish_trade_order, + 
publish_trade_fill, + publish_positions_daily_batch, + LAKEHOUSE_BUCKET, +) +from services.shared.config import load_config +from services.shared.db import get_pg_pool, get_redis +from services.shared.logging import Span, new_trace_id, set_trace_context, setup_logging +from services.shared.metrics import ( + ORDERS_DUPLICATES_PREVENTED, + ORDERS_FILLED, + ORDERS_REJECTED, + ORDERS_SUBMITTED, + POSITIONS_SYNCED, + RISK_CHECK_FAILURES, + RISK_EVALUATIONS_TOTAL, +) +from services.shared.redis_keys import QUEUE_BROKER, queue_key + +logger = logging.getLogger("broker_service") + +POSITION_SYNC_INTERVAL = 60 # seconds + +# Redis TTL for idempotency markers (24 hours) +ORDER_IDEMPOTENCY_TTL = 86400 +ORDER_IDEMPOTENCY_PREFIX = "stonks:order_idempotency" + + +# --------------------------------------------------------------------------- +# DB persistence helpers +# --------------------------------------------------------------------------- + +_UPSERT_BROKER_ACCOUNT = """ +INSERT INTO broker_accounts (id, provider, account_id, mode, config, active) +VALUES ($1::uuid, $2, $3, $4, $5::jsonb, TRUE) +ON CONFLICT (id) DO UPDATE SET + config = EXCLUDED.config, + mode = EXCLUDED.mode, + active = TRUE +""" + +_INSERT_ORDER = """ +INSERT INTO orders ( + id, recommendation_id, broker_account_id, ticker, side, order_type, + quantity, limit_price, stop_price, status, idempotency_key, + broker_order_id, decision_trace, submitted_at, filled_at, + fill_price, fill_quantity +) VALUES ( + $1::uuid, $2, $3::uuid, $4, $5, $6, + $7, $8, $9, $10, $11, + $12, $13::jsonb, $14, $15, + $16, $17 +) +ON CONFLICT (idempotency_key) DO UPDATE SET + status = EXCLUDED.status, + broker_order_id = EXCLUDED.broker_order_id, + filled_at = EXCLUDED.filled_at, + fill_price = EXCLUDED.fill_price, + fill_quantity = EXCLUDED.fill_quantity, + updated_at = NOW() +""" + +_INSERT_ORDER_EVENT = """ +INSERT INTO order_events (order_id, event_type, data, broker_timestamp) +VALUES ($1::uuid, $2, $3::jsonb, $4) +""" + 
+_INSERT_RISK_EVALUATION = """ +INSERT INTO risk_evaluations (id, recommendation_id, eligible, allowed_mode, rejection_reasons, risk_checks, evaluated_at) +VALUES ($1::uuid, $2::uuid, $3, $4, $5::jsonb, $6::jsonb, $7) +""" + +_UPSERT_POSITION = """ +INSERT INTO positions (broker_account_id, ticker, quantity, avg_entry_price, current_price, unrealized_pnl, updated_at) +VALUES ($1::uuid, $2, $3, $4, $5, $6, $7) +ON CONFLICT (broker_account_id, ticker) + DO UPDATE SET + quantity = EXCLUDED.quantity, + avg_entry_price = EXCLUDED.avg_entry_price, + current_price = EXCLUDED.current_price, + unrealized_pnl = EXCLUDED.unrealized_pnl, + updated_at = EXCLUDED.updated_at +""" + +_LOAD_RISK_CONFIG = """ +SELECT config FROM risk_configs WHERE active = TRUE ORDER BY updated_at DESC LIMIT 1 +""" + +_LOAD_DAILY_SNAPSHOT = """ +SELECT portfolio_value, daily_pnl, daily_trade_count, positions_by_sector +FROM daily_risk_snapshots +WHERE account_id = $1 AND snapshot_date = CURRENT_DATE +LIMIT 1 +""" + +_CHECK_ORDER_BY_IDEMPOTENCY_KEY = """ +SELECT id, status, broker_order_id FROM orders +WHERE idempotency_key = $1 +LIMIT 1 +""" + + +# --------------------------------------------------------------------------- +# Idempotency helpers (Requirement 8.5) +# --------------------------------------------------------------------------- + + +def generate_idempotency_key(job: dict[str, Any]) -> str: + """Generate a deterministic idempotency key from job attributes. + + If the job already carries an explicit idempotency_key, use it. + Otherwise, derive a stable key from the combination of + recommendation_id, ticker, side, quantity, and order_type so that + replayed queue messages produce the same key and are detected as + duplicates. 
+ """ + explicit = job.get("idempotency_key") + if explicit: + return str(explicit) + + # Build a deterministic key from job content + parts = [ + str(job.get("recommendation_id", "")), + str(job.get("ticker", "")), + str(job.get("side", "buy")), + str(job.get("quantity", 0)), + str(job.get("order_type", "market")), + str(job.get("limit_price", "")), + str(job.get("stop_price", "")), + ] + raw = "|".join(parts) + return hashlib.sha256(raw.encode()).hexdigest()[:40] + + +def _redis_idempotency_key(idempotency_key: str) -> str: + """Build the Redis key for an order idempotency marker.""" + return f"{ORDER_IDEMPOTENCY_PREFIX}:{idempotency_key}" + + +async def check_idempotency_redis( + rds: aioredis.Redis, + idempotency_key: str, +) -> str | None: + """Fast-path: check Redis for a previously processed idempotency key. + + Returns the existing order_id if found, None otherwise. + """ + redis_key = _redis_idempotency_key(idempotency_key) + cached = await rds.get(redis_key) + if cached: + return str(cached) + return None + + +async def check_idempotency_db( + pool: asyncpg.Pool, + idempotency_key: str, +) -> dict[str, Any] | None: + """Durable fallback: check PostgreSQL for an existing order with this key. + + Returns a dict with id, status, broker_order_id if found, None otherwise. 
+ """ + row = await pool.fetchrow(_CHECK_ORDER_BY_IDEMPOTENCY_KEY, idempotency_key) + if row: + return { + "id": str(row["id"]), + "status": str(row["status"]), + "broker_order_id": str(row["broker_order_id"] or ""), + } + return None + + +async def mark_idempotency_redis( + rds: aioredis.Redis, + idempotency_key: str, + order_id: str, +) -> None: + """Set the Redis idempotency marker after an order is processed.""" + redis_key = _redis_idempotency_key(idempotency_key) + await rds.set(redis_key, order_id, ex=ORDER_IDEMPOTENCY_TTL) + + +# --------------------------------------------------------------------------- +# Core service logic +# --------------------------------------------------------------------------- + + +def build_order_request(job: dict[str, Any]) -> OrderRequest: + """Build an OrderRequest from a broker queue job payload.""" + side = OrderSide.SELL if job.get("side", "buy") == "sell" else OrderSide.BUY + order_type_str = job.get("order_type", "market") + order_type_map = { + "market": OrderType.MARKET, + "limit": OrderType.LIMIT, + "stop": OrderType.STOP, + "stop_limit": OrderType.STOP_LIMIT, + } + return OrderRequest( + ticker=job["ticker"], + side=side, + quantity=float(job.get("quantity", 0)), + order_type=order_type_map.get(order_type_str, OrderType.MARKET), + limit_price=job.get("limit_price"), + stop_price=job.get("stop_price"), + time_in_force=job.get("time_in_force", "day"), + idempotency_key=generate_idempotency_key(job), + ) + + +def build_proposed_order(job: dict[str, Any]) -> ProposedOrder: + """Build a ProposedOrder for risk evaluation from a broker queue job.""" + return ProposedOrder( + recommendation_id=job.get("recommendation_id"), + ticker=job["ticker"], + sector=job.get("sector", ""), + action=job.get("side", "buy"), + quantity=float(job.get("quantity", 0)), + estimated_value=float(job.get("estimated_value", 0)), + confidence=float(job.get("confidence", 0)), + ) + + +async def load_risk_config(pool: asyncpg.Pool) -> 
PortfolioRiskConfig: + """Load the active risk configuration from the database.""" + row = await pool.fetchrow(_LOAD_RISK_CONFIG) + if row and row["config"]: + data = row["config"] if isinstance(row["config"], dict) else json.loads(row["config"]) + return PortfolioRiskConfig.from_db_json(data) + return PortfolioRiskConfig() + + +async def load_account_risk_state( + pool: asyncpg.Pool, + adapter: AlpacaBrokerAdapter, + account_uuid: str, +) -> AccountRiskState: + """Build an AccountRiskState from the broker and daily snapshot.""" + state = AccountRiskState(account_id=account_uuid) + + # Get live account info from Alpaca + try: + acct = await adapter.get_account() + state.portfolio_value = acct.portfolio_value + state.cash = acct.cash + state.buying_power = acct.buying_power + except Exception as e: + logger.warning("Failed to fetch account from Alpaca: %s", e) + + # Get positions from Alpaca + try: + positions = await adapter.get_positions() + for pos in positions: + state.positions_by_symbol[pos.ticker] = pos.market_value + state.open_position_count = len(positions) + except Exception as e: + logger.warning("Failed to fetch positions from Alpaca: %s", e) + + # Overlay daily snapshot from DB + row = await pool.fetchrow(_LOAD_DAILY_SNAPSHOT, account_uuid) + if row: + state.daily_pnl = float(row["daily_pnl"] or 0) + state.daily_trade_count = int(row["daily_trade_count"] or 0) + sector_data = row["positions_by_sector"] + if sector_data: + state.positions_by_sector = ( + sector_data if isinstance(sector_data, dict) else json.loads(sector_data) + ) + + return state + + +async def persist_order( + pool: asyncpg.Pool, + order_id: str, + order: OrderRequest, + resp: OrderResponse, + account_uuid: str, + risk_eval: dict[str, Any], + recommendation_id: str | None = None, +) -> None: + """Persist order, events, and risk evaluation to PostgreSQL.""" + now = datetime.now(timezone.utc) + filled_at = now if resp.status == OrderStatus.FILLED else None + + decision_trace = { + 
"risk_evaluation": risk_eval, + "order_request": order.to_dict(), + "broker_response": resp.to_dict(), + } + + async with pool.acquire() as conn: + async with conn.transaction(): + await conn.execute( + _INSERT_ORDER, + order_id, + recommendation_id, + account_uuid, + order.ticker, + order.side.value, + order.order_type.value, + order.quantity, + order.limit_price, + order.stop_price, + resp.status.value, + order.idempotency_key, + resp.broker_order_id, + json.dumps(decision_trace), + resp.submitted_at or now, + filled_at, + resp.filled_avg_price, + resp.filled_quantity, + ) + + # Record order events + for event_type in ["submitted"]: + await conn.execute( + _INSERT_ORDER_EVENT, + order_id, + event_type, + json.dumps({"ticker": order.ticker, "side": order.side.value}), + now, + ) + + if resp.status == OrderStatus.FILLED: + await conn.execute( + _INSERT_ORDER_EVENT, + order_id, + "fill", + json.dumps({ + "fill_price": resp.filled_avg_price, + "fill_qty": resp.filled_quantity, + }), + now, + ) + elif resp.status == OrderStatus.REJECTED: + await conn.execute( + _INSERT_ORDER_EVENT, + order_id, + "rejected", + json.dumps({"error": resp.error}), + now, + ) + + +async def sync_positions( + adapter: AlpacaBrokerAdapter, + pool: asyncpg.Pool, + account_uuid: str, + minio_client: Any | None = None, +) -> None: + """Sync current positions from Alpaca to PostgreSQL and publish to lake.""" + now = datetime.now(timezone.utc) + try: + positions = await adapter.get_positions() + async with pool.acquire() as conn: + for pos in positions: + await conn.execute( + _UPSERT_POSITION, + account_uuid, + pos.ticker, + pos.quantity, + pos.avg_entry_price, + pos.current_price, + pos.unrealized_pnl, + now, + ) + logger.info("Synced %d positions from Alpaca", len(positions)) + POSITIONS_SYNCED.inc() + + # Publish positions snapshot to analytical lake + if minio_client is not None and positions: + try: + pos_dicts = [ + { + "ticker": p.ticker, + "quantity": p.quantity, + "avg_entry_price": 
p.avg_entry_price, + "close_price": p.current_price, + "unrealized_pnl": p.unrealized_pnl, + } + for p in positions + ] + publish_positions_daily_batch( + minio_client, pos_dicts, account_uuid, now, + ) + except Exception as e: + logger.warning("Failed to publish positions to lake: %s", e) + except Exception as e: + logger.error("Position sync failed: %s", e) + + +async def register_broker_account( + pool: asyncpg.Pool, + account_uuid: str, + adapter: AlpacaBrokerAdapter, +) -> None: + """Register or update the broker account in PostgreSQL.""" + try: + acct = await adapter.get_account() + config_json = json.dumps({ + "provider": "alpaca", + "buying_power": acct.buying_power, + "cash": acct.cash, + "portfolio_value": acct.portfolio_value, + }) + await pool.execute( + _UPSERT_BROKER_ACCOUNT, + account_uuid, + "alpaca", + acct.account_id or account_uuid, + adapter.mode.value, + config_json, + ) + logger.info( + "Registered Alpaca account: id=%s mode=%s portfolio=%.2f", + acct.account_id, adapter.mode.value, acct.portfolio_value, + ) + except Exception as e: + logger.error("Failed to register broker account: %s", e) + + +async def process_order_job( + job: dict[str, Any], + adapter: AlpacaBrokerAdapter, + pool: asyncpg.Pool, + account_uuid: str, + rds: aioredis.Redis | None = None, + minio_client: Any | None = None, +) -> None: + """Process a single order job from the broker queue. + + 1. Generate deterministic idempotency key + 2. Check Redis + DB for duplicate (Req 8.5) + 3. Build proposed order and run risk evaluation + 4. If risk passes, submit to Alpaca + 5. Persist order, events, and risk evaluation + 6. 
Set Redis idempotency marker + """ + ticker = job.get("ticker", "???") + order_id = str(uuid.uuid4()) + idempotency_key = generate_idempotency_key(job) + + # --- Duplicate prevention (Requirement 8.5) --- + # Fast path: Redis check + if rds is not None: + existing_order_id = await check_idempotency_redis(rds, idempotency_key) + if existing_order_id: + logger.info( + "Duplicate order detected (redis) for %s key=%s existing=%s", + ticker, idempotency_key[:16], existing_order_id, + ) + ORDERS_DUPLICATES_PREVENTED.labels(detected_via="redis").inc() + await audit_duplicate_prevented( + pool, existing_order_id, ticker, idempotency_key, detected_via="redis", + ) + return + + # Durable fallback: DB check + existing = await check_idempotency_db(pool, idempotency_key) + if existing: + logger.info( + "Duplicate order detected (db) for %s key=%s existing=%s status=%s", + ticker, idempotency_key[:16], existing["id"], existing["status"], + ) + ORDERS_DUPLICATES_PREVENTED.labels(detected_via="db").inc() + await audit_duplicate_prevented( + pool, existing["id"], ticker, idempotency_key, detected_via="db", + ) + # Warm Redis cache for future fast-path hits + if rds is not None: + await mark_idempotency_redis(rds, idempotency_key, existing["id"]) + return + + # Risk evaluation + risk_config = await load_risk_config(pool) + risk_state = await load_account_risk_state(pool, adapter, account_uuid) + proposed = build_proposed_order(job) + evaluation = evaluate_order(proposed, risk_config, risk_state) + + risk_eval_dict = { + "evaluation_id": evaluation.evaluation_id, + "eligible": evaluation.eligible, + "allowed_mode": evaluation.allowed_mode.value, + "rejection_reasons": evaluation.rejection_reasons, + "checks": [c.model_dump(mode="json") for c in evaluation.checks], + } + + # Persist risk evaluation + rec_id = job.get("recommendation_id") + try: + await pool.execute( + _INSERT_RISK_EVALUATION, + evaluation.evaluation_id, + rec_id, + evaluation.eligible, + evaluation.allowed_mode.value, 
+ json.dumps(evaluation.rejection_reasons), + json.dumps(risk_eval_dict["checks"]), + evaluation.evaluated_at, + ) + except Exception as e: + logger.warning("Failed to persist risk evaluation: %s", e) + + # Audit: risk evaluation result + await audit_risk_evaluated( + pool, + evaluation_id=evaluation.evaluation_id, + recommendation_id=rec_id, + ticker=ticker, + eligible=evaluation.eligible, + allowed_mode=evaluation.allowed_mode.value, + rejection_reasons=evaluation.rejection_reasons, + check_count=len(evaluation.checks), + ) + + if not evaluation.eligible: + RISK_EVALUATIONS_TOTAL.labels(result="rejected").inc() + for check in evaluation.checks: + if check.result.value == "fail": + RISK_CHECK_FAILURES.labels(check_name=check.check_name).inc() + ORDERS_REJECTED.labels(reason_category="risk_engine").inc() + logger.info( + "Order rejected by risk engine for %s: %s", + ticker, evaluation.rejection_reasons, + ) + # Persist the rejected order for audit + order_req = build_order_request(job) + rejected_resp = OrderResponse( + broker_order_id="", + status=OrderStatus.REJECTED, + ticker=ticker, + side=OrderSide.SELL if job.get("side") == "sell" else OrderSide.BUY, + quantity=float(job.get("quantity", 0)), + error=f"Risk rejected: {'; '.join(evaluation.rejection_reasons)}", + ) + await persist_order( + pool, order_id, order_req, rejected_resp, + account_uuid, risk_eval_dict, rec_id, + ) + # Publish rejected order fact to analytical lake + if minio_client is not None: + try: + publish_trade_order( + minio_client, order_id, ticker, + side=job.get("side", "buy"), + order_type=job.get("order_type", "market"), + quantity=float(job.get("quantity", 0)), + limit_price=job.get("limit_price"), + status="rejected", + broker_account=account_uuid, + submitted_at=datetime.now(timezone.utc), + ) + except Exception as e: + logger.warning("Failed to publish rejected order to lake: %s", e) + # Audit: order rejected by risk engine + await audit_order_rejected( + pool, order_id, ticker, + 
reason=f"Risk rejected: {'; '.join(evaluation.rejection_reasons)}", + source="risk_engine", + ) + # Mark idempotency even for rejected orders to prevent reprocessing + if rds is not None: + await mark_idempotency_redis(rds, idempotency_key, order_id) + return + + # --- Operator approval gate (Requirement 8.2) --- + if requires_approval(risk_config, evaluation.allowed_mode): + expiry = compute_expiry(risk_config) + approval_req = ApprovalRequest( + order_job=job, + recommendation_id=rec_id, + ticker=ticker, + side=job.get("side", "buy"), + quantity=float(job.get("quantity", 0)), + estimated_value=float(job.get("estimated_value", 0)), + risk_evaluation_id=evaluation.evaluation_id, + expires_at=expiry, + ) + try: + await create_approval_request(pool, approval_req) + logger.info( + "Order for %s held for operator approval (id=%s, expires=%s)", + ticker, approval_req.approval_id, expiry.isoformat(), + ) + await audit_approval_requested( + pool, + approval_id=approval_req.approval_id, + ticker=ticker, + side=approval_req.side, + quantity=approval_req.quantity, + estimated_value=approval_req.estimated_value, + recommendation_id=rec_id, + expires_at=expiry.isoformat(), + ) + except Exception as e: + logger.error("Failed to create approval request for %s: %s", ticker, e) + # Do NOT mark idempotency — the job will be re-submitted after approval + return + + # Submit to Alpaca + order_req = build_order_request(job) + RISK_EVALUATIONS_TOTAL.labels(result="passed").inc() + + # Audit: order submitted to broker + await audit_order_submitted( + pool, + order_id=order_id, + ticker=ticker, + side=order_req.side.value, + quantity=order_req.quantity, + order_type=order_req.order_type.value, + idempotency_key=order_req.idempotency_key, + recommendation_id=rec_id, + evaluation_id=evaluation.evaluation_id, + ) + + resp = await adapter.submit_order(order_req) + + await persist_order( + pool, order_id, order_req, resp, + account_uuid, risk_eval_dict, rec_id, + ) + + # Publish order fact to 
analytical lake + if minio_client is not None: + try: + publish_trade_order( + minio_client, order_id, ticker, + side=order_req.side.value, + order_type=order_req.order_type.value, + quantity=order_req.quantity, + limit_price=order_req.limit_price, + status=resp.status.value, + broker_account=account_uuid, + submitted_at=resp.submitted_at or datetime.now(timezone.utc), + ) + except Exception as e: + logger.warning("Failed to publish order to lake: %s", e) + + # Publish fill fact if the order was filled + if resp.status == OrderStatus.FILLED and resp.filled_avg_price is not None: + try: + fill_id = str(uuid.uuid4()) + publish_trade_fill( + minio_client, fill_id, order_id, ticker, + side=order_req.side.value, + fill_price=resp.filled_avg_price, + fill_quantity=resp.filled_quantity, + broker_account=account_uuid, + filled_at=datetime.now(timezone.utc), + ) + except Exception as e: + logger.warning("Failed to publish fill to lake: %s", e) + + # Mark idempotency after successful persistence + if rds is not None: + await mark_idempotency_redis(rds, idempotency_key, order_id) + + if resp.ok: + mode = "paper" if adapter.mode == TradingMode.PAPER else "live" + ORDERS_SUBMITTED.labels( + side=order_req.side.value, + order_type=order_req.order_type.value, + mode=mode, + ).inc() + logger.info( + "Order submitted to Alpaca: %s %s %.0f %s @ %s | broker_id=%s", + resp.status.value, order_req.side.value, order_req.quantity, + ticker, resp.filled_avg_price, resp.broker_order_id, + ) + # Audit: order filled + if resp.status == OrderStatus.FILLED: + ORDERS_FILLED.labels(side=order_req.side.value).inc() + await audit_order_filled( + pool, order_id, ticker, + side=order_req.side.value, + fill_quantity=resp.filled_quantity, + fill_price=resp.filled_avg_price, + broker_order_id=resp.broker_order_id, + ) + else: + ORDERS_REJECTED.labels(reason_category="broker").inc() + logger.warning( + "Order failed for %s: %s (status=%s)", + ticker, resp.error, resp.status.value, + ) + # Audit: order 
rejected by broker + await audit_order_rejected( + pool, order_id, ticker, + reason=resp.error or f"Broker status: {resp.status.value}", + source="broker", + ) + + + +async def position_sync_loop( + adapter: AlpacaBrokerAdapter, + pool: asyncpg.Pool, + account_uuid: str, + minio_client: Any | None = None, +) -> None: + """Periodically sync positions from Alpaca to PostgreSQL and lake.""" + while True: + await sync_positions(adapter, pool, account_uuid, minio_client) + await asyncio.sleep(POSITION_SYNC_INTERVAL) + + +async def main() -> None: + config = load_config() + setup_logging("broker_service", level=config.log_level, json_output=config.json_logs) + + pool = await get_pg_pool(config) + rds = get_redis(config) + + # Initialize MinIO client for lake publishing + from minio import Minio + minio_client = Minio( + config.minio.endpoint, + access_key=config.minio.access_key, + secret_key=config.minio.secret_key, + secure=config.minio.secure, + ) + # Ensure lakehouse bucket exists + if not minio_client.bucket_exists(LAKEHOUSE_BUCKET): + minio_client.make_bucket(LAKEHOUSE_BUCKET) + + # Determine mode — default to paper for safety (Req 8.1) + mode = TradingMode.LIVE if config.broker.mode == "live" else TradingMode.PAPER + if mode == TradingMode.LIVE: + logger.warning("LIVE trading mode enabled — orders will be submitted to real broker") + + adapter = AlpacaBrokerAdapter( + api_key=config.broker.api_key or "", + api_secret=config.broker.api_secret or "", + mode=mode, + base_url=config.broker.base_url, + ) + + # Generate a stable account UUID from the API key + account_uuid = str(uuid.uuid5(uuid.NAMESPACE_DNS, f"alpaca-{config.broker.api_key or 'default'}")) + + # Register broker account on startup + await register_broker_account(pool, account_uuid, adapter) + + # Start position sync in background + sync_task = asyncio.create_task( + position_sync_loop(adapter, pool, account_uuid, minio_client) + ) + + queue = queue_key(QUEUE_BROKER) + logger.info("Broker service started 
(mode=%s)", mode.value) + + try: + while True: + result = await rds.lpop(queue) + raw = str(result) if result else None + if raw: + try: + job = json.loads(raw) + await process_order_job(job, adapter, pool, account_uuid, rds, minio_client) + except Exception: + logger.exception("Error processing broker job") + else: + await asyncio.sleep(2) + finally: + sync_task.cancel() + await pool.close() + await rds.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/services/adapters/filings_adapter.py b/services/adapters/filings_adapter.py index ede9df7..0c67461 100644 --- a/services/adapters/filings_adapter.py +++ b/services/adapters/filings_adapter.py @@ -1,8 +1,17 @@ -"""Filings / Regulatory API adapter - fetches SEC-style submissions.""" +"""Filings / Regulatory API adapter interface and concrete SEC EDGAR provider. + +The FilingsDataAdapter is the abstract interface for all filings data providers. +SECEdgarAdapter is the first concrete implementation, targeting the SEC EDGAR +full-text search system (EFTS) for company filings discovery. + +Requirements: 2.3, 2.5, 3.1, 3.2, 3.3 +""" import hashlib import logging -from datetime import datetime -from typing import Any, Dict +import time +from abc import ABC +from datetime import datetime, timezone +from typing import Any import httpx @@ -11,48 +20,182 @@ from .base import AdapterResult, BaseAdapter logger = logging.getLogger("filings_adapter") -class FilingsAdapter(BaseAdapter): - """Concrete adapter for SEC EDGAR or similar filings API.""" +class FilingsDataAdapter(BaseAdapter, ABC): + """Abstract interface for filings / regulatory data providers. - def __init__(self, base_url: str = "https://efts.sec.gov", user_agent: str = "StonksOracle/1.0"): - self.base_url = base_url - self.user_agent = user_agent + Subclasses implement fetch() for their specific filings API. + source_type() is concrete here since all filings adapters share the same type. 
+ """ def source_type(self) -> str: return "filings_api" - async def fetch(self, ticker: str, config: Dict[str, Any]) -> AdapterResult: - _cik = config.get("cik", "") - endpoint = config.get("endpoint", f"/LATEST/search-index?q=%22{ticker}%22&dateRange=custom&startdt=2026-01-01&forms=8-K,10-Q,10-K") - url = f"{self.base_url}{endpoint}" - headers = {"User-Agent": self.user_agent} +class SECEdgarAdapter(FilingsDataAdapter): + """Concrete adapter for the SEC EDGAR full-text search system (EFTS). + + Supports: + - Full-text search (/LATEST/search-index) for 8-K, 10-Q, 10-K, and other forms + - Filtering by date range, form type, and entity + + The SEC EDGAR EFTS API is public and does not require an API key, + but requires a descriptive User-Agent header per SEC fair-access policy. + + Config options: + cik: Company CIK number (optional, narrows search) + forms: Comma-separated form types to search (default "8-K,10-Q,10-K") + start_date: Only filings on or after this date, YYYY-MM-DD (optional) + end_date: Only filings on or before this date, YYYY-MM-DD (optional) + query: Custom search query override (optional, replaces ticker-based query) + """ + + SEARCH_ENDPOINT: str = "/LATEST/search-index" + + def __init__( + self, + base_url: str = "https://efts.sec.gov", + user_agent: str = "StonksOracle/1.0 ([email])", + ) -> None: + self.base_url: str = base_url.rstrip("/") + self.user_agent: str = user_agent + + async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult: + """Fetch filings from SEC EDGAR EFTS for a given ticker. + + Args: + ticker: The company ticker symbol. + config: Source-specific configuration from the sources table. + + Returns: + AdapterResult with raw payload, parsed filing items, and metadata. 
+ """ + url, params, headers = self._build_request(ticker, config) async with httpx.AsyncClient(timeout=30) as client: + t0 = time.monotonic() try: - resp = await client.get(url, headers=headers) + resp = await client.get(url, params=params, headers=headers) + elapsed_ms = (time.monotonic() - t0) * 1000 resp.raise_for_status() + raw = resp.content data = resp.json() content_hash = hashlib.sha256(raw).hexdigest() + items = self._extract_items(data) - hits = data.get("hits", {}).get("hits", []) return AdapterResult( source_type="filings_api", ticker=ticker, - items=hits, + items=items, raw_payload=raw, content_hash=content_hash, - fetched_at=datetime.utcnow(), + fetched_at=datetime.now(timezone.utc), + http_status=resp.status_code, + response_time_ms=round(elapsed_ms, 1), + metadata={ + "provider": "sec_edgar", + "results_count": len(items), + "total_hits": self._total_hits(data), + "query": params.get("q", ""), + "forms": params.get("forms", ""), + }, ) + except httpx.HTTPStatusError as e: + elapsed_ms = (time.monotonic() - t0) * 1000 + logger.error("SEC EDGAR HTTP error for %s: %s", ticker, e) + return self._error_result( + ticker, str(e), elapsed_ms, + http_status=e.response.status_code if e.response else None, + raw=e.response.content if e.response else b"", + ) + except httpx.TimeoutException as e: + elapsed_ms = (time.monotonic() - t0) * 1000 + logger.error("SEC EDGAR timeout for %s: %s", ticker, e) + return self._error_result(ticker, f"timeout: {e}", elapsed_ms) except Exception as e: - logger.error(f"Filings fetch failed for {ticker}: {e}") - return AdapterResult( - source_type="filings_api", - ticker=ticker, - items=[], - raw_payload=b"", - content_hash="", - fetched_at=datetime.utcnow(), - error=str(e), - ) + elapsed_ms = (time.monotonic() - t0) * 1000 + logger.error("SEC EDGAR fetch failed for %s: %s", ticker, e) + return self._error_result(ticker, str(e), elapsed_ms) + + def _build_request( + self, ticker: str, config: dict[str, Any] + ) -> tuple[str, 
dict[str, str], dict[str, str]]: + """Build the URL, query params, and headers for an EDGAR EFTS request.""" + params: dict[str, str] = {} + headers: dict[str, str] = {"User-Agent": self.user_agent} + + # Query: use custom override or default to ticker-based search + query = config.get("query") + if query: + params["q"] = str(query) + else: + params["q"] = f'"{ticker}"' + + # Form types filter + forms = config.get("forms", "8-K,10-Q,10-K") + params["forms"] = str(forms) + + # Date range + if config.get("start_date"): + params["dateRange"] = "custom" + params["startdt"] = str(config["start_date"]) + if config.get("end_date"): + params["dateRange"] = "custom" + params["enddt"] = str(config["end_date"]) + + # CIK filter (entity-level narrowing) + cik = config.get("cik") + if cik: + params["q"] = f'{params["q"]} AND cik:{cik}' + + url = f"{self.base_url}{self.SEARCH_ENDPOINT}" + return url, params, headers + + def _extract_items(self, data: dict[str, Any]) -> list[dict[str, Any]]: + """Extract the filing hits from an EDGAR EFTS response. + + EFTS returns results under hits.hits as a list of objects, + each containing _source with fields like file_date, form_type, + entity_name, file_num, and period_of_report. 
+ """ + hits_wrapper = data.get("hits", {}) + if not isinstance(hits_wrapper, dict): + return [] + hits = hits_wrapper.get("hits", []) + if isinstance(hits, list): + return hits + return [] + + def _total_hits(self, data: dict[str, Any]) -> int: + """Extract total hit count from EFTS response.""" + hits_wrapper = data.get("hits", {}) + if not isinstance(hits_wrapper, dict): + return 0 + total = hits_wrapper.get("total", {}) + if isinstance(total, dict): + return int(total.get("value", 0)) + if isinstance(total, int): + return total + return 0 + + def _error_result( + self, + ticker: str, + error: str, + elapsed_ms: float, + http_status: int | None = None, + raw: bytes = b"", + ) -> AdapterResult: + """Build an error AdapterResult for filings fetches.""" + return AdapterResult( + source_type="filings_api", + ticker=ticker, + items=[], + raw_payload=raw, + content_hash="", + fetched_at=datetime.now(timezone.utc), + error=error, + http_status=http_status, + response_time_ms=round(elapsed_ms, 1), + metadata={"provider": "sec_edgar"}, + ) diff --git a/services/adapters/market_adapter.py b/services/adapters/market_adapter.py index d33c2ff..441afc9 100644 --- a/services/adapters/market_adapter.py +++ b/services/adapters/market_adapter.py @@ -1,8 +1,16 @@ -"""Market data API adapter - fetches quotes, bars, and reference data.""" +"""Market data API adapter interface and concrete Polygon.io provider. + +The MarketDataAdapter is the abstract interface for all market data providers. +PolygonMarketAdapter is the first concrete implementation, targeting the +Polygon.io REST API for previous-day bars, quotes, and ticker details. 
+ +Requirements: 2.1, 2.5, 3.1, 3.2, 3.3 +""" import hashlib import logging -from datetime import datetime -from typing import Any, Dict +import time +from datetime import datetime, timezone +from typing import Any import httpx @@ -12,48 +20,158 @@ logger = logging.getLogger("market_adapter") class MarketDataAdapter(BaseAdapter): - """Concrete adapter for a market data provider (e.g., Alpha Vantage, Polygon, Yahoo).""" + """Abstract interface for market data providers. - def __init__(self, api_key: str = "", base_url: str = ""): - self.api_key = api_key - self.base_url = base_url + Subclasses implement fetch() for their specific market data API. + """ def source_type(self) -> str: return "market_api" - async def fetch(self, ticker: str, config: Dict[str, Any]) -> AdapterResult: - endpoint = config.get("endpoint", "/v2/aggs/ticker/{ticker}/prev") - url = f"{self.base_url}{endpoint.format(ticker=ticker)}" - params = config.get("params", {}) - if self.api_key: - params["apiKey"] = self.api_key + +class PolygonMarketAdapter(MarketDataAdapter): + """Concrete adapter for the Polygon.io REST API. + + Supports: + - Previous-day aggregate bars (/v2/aggs/ticker/{ticker}/prev) + - Aggregate bars over a date range (/v2/aggs/ticker/{ticker}/range/{multiplier}/{timespan}/{from_date}/{to_date}) + - Ticker details (/v3/reference/tickers/{ticker}) + + The endpoint is selected via the source config's "endpoint" field, + defaulting to previous-day bars. + """ + + PREV_BARS = "/v2/aggs/ticker/{ticker}/prev" + RANGE_BARS = "/v2/aggs/ticker/{ticker}/range/{multiplier}/{timespan}/{from_date}/{to_date}" + TICKER_DETAILS = "/v3/reference/tickers/{ticker}" + + def __init__(self, api_key: str, base_url: str = "https://api.polygon.io") -> None: + self.api_key: str = api_key + self.base_url: str = base_url.rstrip("/") + + async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult: + """Fetch market data from Polygon.io for a given ticker.
+ + Config options: + endpoint: One of "prev_bars" (default), "range_bars", "ticker_details" + multiplier: Bar multiplier for range queries (default 1) + timespan: Bar timespan for range queries (default "day") + from_date: Start date for range queries (YYYY-MM-DD) + to_date: End date for range queries (YYYY-MM-DD) + adjusted: Whether bars are adjusted for splits (default true) + """ + endpoint_key = config.get("endpoint", "prev_bars") + url, params = self._build_request(ticker, endpoint_key, config) async with httpx.AsyncClient(timeout=30) as client: + t0 = time.monotonic() try: resp = await client.get(url, params=params) + elapsed_ms = (time.monotonic() - t0) * 1000 resp.raise_for_status() + raw = resp.content data = resp.json() content_hash = hashlib.sha256(raw).hexdigest() - - items = data.get("results", [data]) if isinstance(data, dict) else data + items = self._extract_items(data, endpoint_key) return AdapterResult( source_type="market_api", ticker=ticker, - items=items if isinstance(items, list) else [items], + items=items, raw_payload=raw, content_hash=content_hash, - fetched_at=datetime.utcnow(), + fetched_at=datetime.now(timezone.utc), + http_status=resp.status_code, + response_time_ms=round(elapsed_ms, 1), + metadata={ + "provider": "polygon", + "endpoint": endpoint_key, + "results_count": data.get("resultsCount", len(items)), + "request_id": data.get("request_id", ""), + }, ) + except httpx.HTTPStatusError as e: + elapsed_ms = (time.monotonic() - t0) * 1000 + logger.error("Polygon HTTP error for %s: %s", ticker, e) + return self._error_result( + ticker, str(e), elapsed_ms, + http_status=e.response.status_code if e.response else None, + raw=e.response.content if e.response else b"", + ) + except httpx.TimeoutException as e: + elapsed_ms = (time.monotonic() - t0) * 1000 + logger.error("Polygon timeout for %s: %s", ticker, e) + return self._error_result(ticker, f"timeout: {e}", elapsed_ms) except Exception as e: - logger.error(f"Market fetch failed for 
{ticker}: {e}") - return AdapterResult( - source_type="market_api", - ticker=ticker, - items=[], - raw_payload=b"", - content_hash="", - fetched_at=datetime.utcnow(), - error=str(e), - ) + elapsed_ms = (time.monotonic() - t0) * 1000 + logger.error("Polygon fetch failed for %s: %s", ticker, e) + return self._error_result(ticker, str(e), elapsed_ms) + + def _build_request( + self, ticker: str, endpoint_key: str, config: dict[str, Any] + ) -> tuple[str, dict[str, str]]: + """Build the URL and query params for a Polygon request.""" + params: dict[str, str] = {"apiKey": self.api_key} + + if endpoint_key == "range_bars": + multiplier = str(config.get("multiplier", 1)) + timespan = config.get("timespan", "day") + from_date = config.get("from_date", "") + to_date = config.get("to_date", "") + path = self.RANGE_BARS.format( + ticker=ticker, + multiplier=multiplier, + timespan=timespan, + from_date=from_date, + to_date=to_date, + ) + if config.get("adjusted") is not None: + params["adjusted"] = str(config["adjusted"]).lower() + if config.get("sort"): + params["sort"] = config["sort"] + if config.get("limit"): + params["limit"] = str(config["limit"]) + elif endpoint_key == "ticker_details": + path = self.TICKER_DETAILS.format(ticker=ticker) + else: + # Default: previous-day bars + path = self.PREV_BARS.format(ticker=ticker) + if config.get("adjusted") is not None: + params["adjusted"] = str(config["adjusted"]).lower() + + return f"{self.base_url}{path}", params + + def _extract_items(self, data: dict[str, Any], endpoint_key: str) -> list[dict[str, Any]]: + """Extract the relevant items list from a Polygon response.""" + if endpoint_key == "ticker_details": + results = data.get("results", {}) + return [results] if isinstance(results, dict) and results else [] + + # Aggregate endpoints return results as a list + results = data.get("results", []) + if isinstance(results, list): + return results + return [results] if results else [] + + def _error_result( + self, + ticker: str, + 
error: str, + elapsed_ms: float, + http_status: int | None = None, + raw: bytes = b"", + ) -> AdapterResult: + """Build an error AdapterResult.""" + return AdapterResult( + source_type="market_api", + ticker=ticker, + items=[], + raw_payload=raw, + content_hash="", + fetched_at=datetime.now(timezone.utc), + error=error, + http_status=http_status, + response_time_ms=round(elapsed_ms, 1), + metadata={"provider": "polygon"}, + ) diff --git a/services/adapters/news_adapter.py b/services/adapters/news_adapter.py index 0a77c7d..ddca7f3 100644 --- a/services/adapters/news_adapter.py +++ b/services/adapters/news_adapter.py @@ -1,8 +1,17 @@ -"""News API adapter - fetches company-linked headlines and article metadata.""" +"""News API adapter interface and concrete Polygon.io news provider. + +The NewsDataAdapter is the abstract interface for all news data providers. +PolygonNewsAdapter is the first concrete implementation, targeting the +Polygon.io REST API for company-linked news articles and headlines. + +Requirements: 2.2, 2.5, 3.1, 3.2, 3.3 +""" import hashlib import logging -from datetime import datetime -from typing import Any, Dict +import time +from abc import ABC +from datetime import datetime, timezone +from typing import Any import httpx @@ -11,51 +20,147 @@ from .base import AdapterResult, BaseAdapter logger = logging.getLogger("news_adapter") -class NewsApiAdapter(BaseAdapter): - """Concrete adapter for a news API provider.""" +class NewsDataAdapter(BaseAdapter, ABC): + """Abstract interface for news data providers. - def __init__(self, api_key: str = "", base_url: str = ""): - self.api_key = api_key - self.base_url = base_url + Subclasses implement fetch() for their specific news API. + source_type() is concrete here since all news adapters share the same type. 
+ """ def source_type(self) -> str: return "news_api" - async def fetch(self, ticker: str, config: Dict[str, Any]) -> AdapterResult: - endpoint = config.get("endpoint", "/v2/everything") - url = f"{self.base_url}{endpoint}" - params = config.get("params", {}) - params.setdefault("q", ticker) - params.setdefault("sortBy", "publishedAt") - params.setdefault("pageSize", 20) - if self.api_key: - params["apiKey"] = self.api_key + +class PolygonNewsAdapter(NewsDataAdapter): + """Concrete adapter for the Polygon.io ticker news endpoint. + + Supports: + - Ticker news (/v2/reference/news?ticker={ticker}) + + Config options: + limit: Max articles to return per request (default 20, max 1000) + published_utc_gte: Only articles published on or after this date (YYYY-MM-DD) + published_utc_lte: Only articles published on or before this date (YYYY-MM-DD) + order: Sort order for results, "asc" or "desc" (default "desc") + """ + + NEWS_ENDPOINT = "/v2/reference/news" + + def __init__(self, api_key: str, base_url: str = "https://api.polygon.io") -> None: + self.api_key: str = api_key + self.base_url: str = base_url.rstrip("/") + + async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult: + """Fetch news articles from Polygon.io for a given ticker. + + Args: + ticker: The company ticker symbol. + config: Source-specific configuration from the sources table. + + Returns: + AdapterResult with raw payload, parsed article items, and metadata. 
+ """ + url, params = self._build_request(ticker, config) async with httpx.AsyncClient(timeout=30) as client: + t0 = time.monotonic() try: resp = await client.get(url, params=params) + elapsed_ms = (time.monotonic() - t0) * 1000 resp.raise_for_status() + raw = resp.content data = resp.json() content_hash = hashlib.sha256(raw).hexdigest() + items = self._extract_items(data) - articles = data.get("articles", []) return AdapterResult( source_type="news_api", ticker=ticker, - items=articles, + items=items, raw_payload=raw, content_hash=content_hash, - fetched_at=datetime.utcnow(), + fetched_at=datetime.now(timezone.utc), + http_status=resp.status_code, + response_time_ms=round(elapsed_ms, 1), + metadata={ + "provider": "polygon", + "results_count": data.get("count", len(items)), + "next_url": data.get("next_url", ""), + "request_id": data.get("request_id", ""), + }, ) + except httpx.HTTPStatusError as e: + elapsed_ms = (time.monotonic() - t0) * 1000 + logger.error("Polygon news HTTP error for %s: %s", ticker, e) + return self._error_result( + ticker, str(e), elapsed_ms, + http_status=e.response.status_code if e.response else None, + raw=e.response.content if e.response else b"", + ) + except httpx.TimeoutException as e: + elapsed_ms = (time.monotonic() - t0) * 1000 + logger.error("Polygon news timeout for %s: %s", ticker, e) + return self._error_result(ticker, f"timeout: {e}", elapsed_ms) except Exception as e: - logger.error(f"News fetch failed for {ticker}: {e}") - return AdapterResult( - source_type="news_api", - ticker=ticker, - items=[], - raw_payload=b"", - content_hash="", - fetched_at=datetime.utcnow(), - error=str(e), - ) + elapsed_ms = (time.monotonic() - t0) * 1000 + logger.error("Polygon news fetch failed for %s: %s", ticker, e) + return self._error_result(ticker, str(e), elapsed_ms) + + def _build_request( + self, ticker: str, config: dict[str, Any] + ) -> tuple[str, dict[str, str]]: + """Build the URL and query params for a Polygon news request.""" + 
params: dict[str, str] = { + "apiKey": self.api_key, + "ticker": ticker, + } + + limit = config.get("limit", 20) + params["limit"] = str(min(int(limit), 1000)) + + if config.get("order"): + params["order"] = config["order"] + + if config.get("published_utc_gte"): + params["published_utc.gte"] = config["published_utc_gte"] + + if config.get("published_utc_lte"): + params["published_utc.lte"] = config["published_utc_lte"] + + url = f"{self.base_url}{self.NEWS_ENDPOINT}" + return url, params + + def _extract_items(self, data: dict[str, Any]) -> list[dict[str, Any]]: + """Extract the article list from a Polygon news response. + + Polygon returns articles under the "results" key as a list of objects, + each containing fields like id, publisher, title, article_url, tickers, + published_utc, description, and keywords. + """ + results = data.get("results", []) + if isinstance(results, list): + return results + return [] + + def _error_result( + self, + ticker: str, + error: str, + elapsed_ms: float, + http_status: int | None = None, + raw: bytes = b"", + ) -> AdapterResult: + """Build an error AdapterResult for news fetches.""" + return AdapterResult( + source_type="news_api", + ticker=ticker, + items=[], + raw_payload=raw, + content_hash="", + fetched_at=datetime.now(timezone.utc), + error=error, + http_status=http_status, + response_time_ms=round(elapsed_ms, 1), + metadata={"provider": "polygon"}, + ) diff --git a/services/adapters/paper_trading.py b/services/adapters/paper_trading.py new file mode 100644 index 0000000..701ae06 --- /dev/null +++ b/services/adapters/paper_trading.py @@ -0,0 +1,603 @@ +"""Paper trading adapter - local order simulation and state sync. + +Implements a fully local paper trading engine that simulates order +execution without requiring a real broker API. Tracks positions, +account balance, fills, and order events in-memory with PostgreSQL +persistence for state sync and audit trail. 
+ +Requirements: 8.1, 8.3, 8.5, 2.4 +Design: Section 4.9 - Broker Adapter (paper mode) +""" +from __future__ import annotations + +import json +import logging +import uuid +from datetime import datetime, timezone +from typing import Any + +import asyncpg + +from services.adapters.broker_adapter import ( + AccountInfo, + BrokerDataAdapter, + OrderEventType, + OrderRequest, + OrderResponse, + OrderSide, + OrderStatus, + OrderType, + PositionInfo, + TradingMode, +) +from services.adapters.base import AdapterResult + +logger = logging.getLogger("paper_trading") + + +# --------------------------------------------------------------------------- +# In-memory paper trading state +# --------------------------------------------------------------------------- + + +class PaperPosition: + """Tracks a single paper position.""" + + def __init__( + self, + ticker: str, + quantity: float = 0.0, + avg_entry_price: float = 0.0, + realized_pnl: float = 0.0, + ) -> None: + self.ticker = ticker + self.quantity = quantity + self.avg_entry_price = avg_entry_price + self.realized_pnl = realized_pnl + + def apply_fill(self, side: OrderSide, fill_qty: float, fill_price: float) -> float: + """Apply a fill to this position. 
Returns realized PnL from the fill.""" + realized = 0.0 + + if side == OrderSide.BUY: + # Buying: average up the entry price + total_cost = self.avg_entry_price * self.quantity + fill_price * fill_qty + self.quantity += fill_qty + if self.quantity > 0: + self.avg_entry_price = total_cost / self.quantity + else: + # Selling: realize PnL on the sold shares + if self.quantity > 0: + sell_qty = min(fill_qty, self.quantity) + realized = sell_qty * (fill_price - self.avg_entry_price) + self.quantity -= sell_qty + self.realized_pnl += realized + if self.quantity <= 0: + self.quantity = 0.0 + self.avg_entry_price = 0.0 + + return realized + + @property + def is_open(self) -> bool: + return self.quantity > 0 + + def to_position_info(self, current_price: float | None = None) -> PositionInfo: + """Convert to a PositionInfo for the broker interface.""" + price = current_price if current_price is not None else self.avg_entry_price + unrealized = (price - self.avg_entry_price) * self.quantity if self.quantity > 0 else 0.0 + market_value = price * self.quantity + return PositionInfo( + ticker=self.ticker, + quantity=self.quantity, + avg_entry_price=self.avg_entry_price, + current_price=price, + unrealized_pnl=round(unrealized, 4), + market_value=round(market_value, 4), + side="long" if self.quantity > 0 else "flat", + ) + + +class PaperAccount: + """In-memory paper trading account state.""" + + def __init__( + self, + account_id: str = "paper-default", + initial_cash: float = 100_000.0, + ) -> None: + self.account_id = account_id + self.initial_cash = initial_cash + self.cash = initial_cash + self.positions: dict[str, PaperPosition] = {} + self.orders: dict[str, OrderResponse] = {} + self.order_events: list[dict[str, Any]] = [] + self._seen_idempotency_keys: dict[str, str] = {} # key -> order_id + + @property + def portfolio_value(self) -> float: + position_value = sum( + p.quantity * p.avg_entry_price for p in self.positions.values() if p.is_open + ) + return self.cash + 
position_value + + @property + def buying_power(self) -> float: + return self.cash + + def get_position(self, ticker: str) -> PaperPosition: + if ticker not in self.positions: + self.positions[ticker] = PaperPosition(ticker=ticker) + return self.positions[ticker] + + def to_account_info(self) -> AccountInfo: + return AccountInfo( + account_id=self.account_id, + buying_power=round(self.buying_power, 2), + cash=round(self.cash, 2), + portfolio_value=round(self.portfolio_value, 2), + currency="USD", + mode=TradingMode.PAPER, + ) + + +# --------------------------------------------------------------------------- +# Paper trading adapter +# --------------------------------------------------------------------------- + + +class PaperTradingAdapter(BrokerDataAdapter): + """Local paper trading adapter that simulates order execution. + + All orders are filled immediately at the estimated price (market orders) + or at the limit/stop price when applicable. No real broker API is called. + + Features: + - Idempotent order submission via idempotency_key (Req 8.5) + - Full order event trail for audit (Req 8.3) + - Position tracking with average entry price + - Cash balance management + - State sync to/from PostgreSQL + + The adapter operates in PAPER mode only and rejects any attempt + to switch to LIVE mode. 
+ """ + + def __init__( + self, + account_id: str = "paper-default", + initial_cash: float = 100_000.0, + simulated_slippage_pct: float = 0.001, + ) -> None: + super().__init__(mode=TradingMode.PAPER) + self.account = PaperAccount(account_id=account_id, initial_cash=initial_cash) + self.slippage_pct = simulated_slippage_pct + + def source_type(self) -> str: + return "broker" + + async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult: + """Fetch paper positions/account as a raw artifact snapshot.""" + endpoint = config.get("endpoint", "positions") + now = datetime.now(timezone.utc) + + if endpoint == "account": + data = self.account.to_account_info().to_dict() + items = [data] + elif endpoint == "orders": + items = [ + resp.to_dict() + for resp in self.account.orders.values() + if resp.ticker == ticker or ticker == "*" + ] + else: + pos = self.account.get_position(ticker) + data = pos.to_position_info().to_dict() + items = [data] if pos.is_open else [] + + raw = json.dumps(items).encode() + return AdapterResult( + source_type="broker", + ticker=ticker, + items=items, + raw_payload=raw, + content_hash="", + fetched_at=now, + metadata={"provider": "paper", "mode": "paper", "endpoint": endpoint}, + ) + + async def submit_order(self, order: OrderRequest) -> OrderResponse: + """Simulate order submission and immediate fill. + + Idempotency: if the same idempotency_key was already used, + return the original response (Req 8.5). 
+ """ + # Idempotency check + existing_id = self.account._seen_idempotency_keys.get(order.idempotency_key) + if existing_id and existing_id in self.account.orders: + logger.info("Duplicate order key %s — returning cached response", order.idempotency_key) + return self.account.orders[existing_id] + + now = datetime.now(timezone.utc) + order_id = str(uuid.uuid4()) + + # Determine fill price based on order type + fill_price = self._compute_fill_price(order) + + # Check if we have enough cash for buys + if order.side == OrderSide.BUY: + required_cash = fill_price * order.quantity + if required_cash > self.account.cash: + resp = OrderResponse( + broker_order_id=order_id, + status=OrderStatus.REJECTED, + ticker=order.ticker, + side=order.side, + quantity=order.quantity, + submitted_at=now, + error=f"Insufficient cash: need {required_cash:.2f}, have {self.account.cash:.2f}", + ) + self._record_event(order_id, OrderEventType.REJECTED, resp.to_dict(), now) + self.account.orders[order_id] = resp + self.account._seen_idempotency_keys[order.idempotency_key] = order_id + return resp + + # Check if we have enough shares for sells + if order.side == OrderSide.SELL: + pos = self.account.get_position(order.ticker) + if pos.quantity < order.quantity: + resp = OrderResponse( + broker_order_id=order_id, + status=OrderStatus.REJECTED, + ticker=order.ticker, + side=order.side, + quantity=order.quantity, + submitted_at=now, + error=f"Insufficient shares: need {order.quantity}, have {pos.quantity}", + ) + self._record_event(order_id, OrderEventType.REJECTED, resp.to_dict(), now) + self.account.orders[order_id] = resp + self.account._seen_idempotency_keys[order.idempotency_key] = order_id + return resp + + # Simulate immediate fill + position = self.account.get_position(order.ticker) + realized_pnl = position.apply_fill(order.side, order.quantity, fill_price) + + # Update cash + if order.side == OrderSide.BUY: + self.account.cash -= fill_price * order.quantity + else: + self.account.cash 
+= fill_price * order.quantity + + resp = OrderResponse( + broker_order_id=order_id, + status=OrderStatus.FILLED, + ticker=order.ticker, + side=order.side, + quantity=order.quantity, + filled_quantity=order.quantity, + filled_avg_price=fill_price, + submitted_at=now, + raw_response={ + "realized_pnl": round(realized_pnl, 4), + "cash_after": round(self.account.cash, 2), + "position_qty_after": position.quantity, + "simulated": True, + }, + ) + + # Record events + self._record_event(order_id, OrderEventType.SUBMITTED, {"ticker": order.ticker}, now) + self._record_event(order_id, OrderEventType.ACCEPTED, {"ticker": order.ticker}, now) + self._record_event(order_id, OrderEventType.FILL, { + "fill_price": fill_price, + "fill_qty": order.quantity, + "realized_pnl": round(realized_pnl, 4), + }, now) + + self.account.orders[order_id] = resp + self.account._seen_idempotency_keys[order.idempotency_key] = order_id + + logger.info( + "Paper fill: %s %s %.0f %s @ %.2f | cash=%.2f pnl=%.4f", + order_id[:8], order.side.value, order.quantity, + order.ticker, fill_price, self.account.cash, realized_pnl, + ) + + return resp + + async def cancel_order(self, broker_order_id: str) -> OrderResponse: + """Cancel a paper order. 
Only pending orders can be cancelled.""" + existing = self.account.orders.get(broker_order_id) + if existing is None: + return OrderResponse( + broker_order_id=broker_order_id, + status=OrderStatus.REJECTED, + ticker="", + side=OrderSide.BUY, + quantity=0, + error=f"Order {broker_order_id} not found", + ) + + # Paper orders fill immediately, so they can't be cancelled + if existing.status == OrderStatus.FILLED: + return OrderResponse( + broker_order_id=broker_order_id, + status=OrderStatus.REJECTED, + ticker=existing.ticker, + side=existing.side, + quantity=existing.quantity, + error="Cannot cancel a filled order", + ) + + now = datetime.now(timezone.utc) + cancelled = OrderResponse( + broker_order_id=broker_order_id, + status=OrderStatus.CANCELLED, + ticker=existing.ticker, + side=existing.side, + quantity=existing.quantity, + submitted_at=existing.submitted_at, + ) + self.account.orders[broker_order_id] = cancelled + self._record_event(broker_order_id, OrderEventType.CANCELLED, {}, now) + return cancelled + + async def get_order_status(self, broker_order_id: str) -> OrderResponse: + """Get the status of a paper order.""" + existing = self.account.orders.get(broker_order_id) + if existing is None: + return OrderResponse( + broker_order_id=broker_order_id, + status=OrderStatus.REJECTED, + ticker="", + side=OrderSide.BUY, + quantity=0, + error=f"Order {broker_order_id} not found", + ) + return existing + + async def get_positions(self) -> list[PositionInfo]: + """Get all open paper positions.""" + return [ + p.to_position_info() + for p in self.account.positions.values() + if p.is_open + ] + + async def get_account(self) -> AccountInfo: + """Get paper account summary.""" + return self.account.to_account_info() + + # ----------------------------------------------------------------------- + # Internal helpers + # ----------------------------------------------------------------------- + + def _compute_fill_price(self, order: OrderRequest) -> float: + """Determine the 
simulated fill price for an order. + + Market orders use the limit_price as a proxy (or 0 if not set). + Limit orders fill at the limit price. + Stop orders fill at the stop price. + A small slippage is applied to market orders. + """ + if order.order_type == OrderType.LIMIT and order.limit_price is not None: + return order.limit_price + if order.order_type == OrderType.STOP and order.stop_price is not None: + return order.stop_price + if order.order_type == OrderType.STOP_LIMIT and order.limit_price is not None: + return order.limit_price + + # Market order: use limit_price as estimate, or a default + base_price = order.limit_price if order.limit_price is not None else 100.0 + if order.side == OrderSide.BUY: + return round(base_price * (1 + self.slippage_pct), 4) + return round(base_price * (1 - self.slippage_pct), 4) + + def _record_event( + self, + order_id: str, + event_type: OrderEventType, + data: dict[str, Any], + timestamp: datetime, + ) -> None: + """Record an order event for audit trail.""" + self.account.order_events.append({ + "order_id": order_id, + "event_type": event_type.value, + "data": data, + "timestamp": timestamp.isoformat(), + }) + + +# --------------------------------------------------------------------------- +# State sync: persist and restore paper trading state to/from PostgreSQL +# --------------------------------------------------------------------------- + +# SQL for persisting paper orders to the orders table +_INSERT_PAPER_ORDER = """ +INSERT INTO orders ( + id, recommendation_id, broker_account_id, ticker, side, order_type, + quantity, limit_price, stop_price, status, idempotency_key, + broker_order_id, decision_trace, submitted_at, filled_at, + fill_price, fill_quantity +) VALUES ( + $1::uuid, $2, $3, $4, $5, $6, + $7, $8, $9, $10, $11, + $12, $13::jsonb, $14, $15, + $16, $17 +) +ON CONFLICT (idempotency_key) DO NOTHING +""" + +_INSERT_PAPER_ORDER_EVENT = """ +INSERT INTO order_events (order_id, event_type, data, broker_timestamp) 
# ---------------------------------------------------------------------------
# State sync: persist and restore paper trading state to/from PostgreSQL
# ---------------------------------------------------------------------------

# SQL for persisting paper orders to the orders table
_INSERT_PAPER_ORDER = """
INSERT INTO orders (
    id, recommendation_id, broker_account_id, ticker, side, order_type,
    quantity, limit_price, stop_price, status, idempotency_key,
    broker_order_id, decision_trace, submitted_at, filled_at,
    fill_price, fill_quantity
) VALUES (
    $1::uuid, $2, $3, $4, $5, $6,
    $7, $8, $9, $10, $11,
    $12, $13::jsonb, $14, $15,
    $16, $17
)
ON CONFLICT (idempotency_key) DO NOTHING
"""

_INSERT_PAPER_ORDER_EVENT = """
INSERT INTO order_events (order_id, event_type, data, broker_timestamp)
VALUES ($1::uuid, $2, $3::jsonb, $4)
"""

_UPSERT_PAPER_POSITION = """
INSERT INTO positions (broker_account_id, ticker, quantity, avg_entry_price, realized_pnl, updated_at)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (broker_account_id, ticker)
    DO UPDATE SET
        quantity = EXCLUDED.quantity,
        avg_entry_price = EXCLUDED.avg_entry_price,
        realized_pnl = EXCLUDED.realized_pnl,
        updated_at = EXCLUDED.updated_at
"""

_UPSERT_PAPER_ACCOUNT = """
INSERT INTO broker_accounts (id, provider, account_id, mode, config, active)
VALUES ($1::uuid, 'paper', $2, 'paper', $3::jsonb, TRUE)
ON CONFLICT (id) DO UPDATE SET
    config = EXCLUDED.config,
    active = TRUE
"""

_LOAD_PAPER_POSITIONS = """
SELECT ticker, quantity, avg_entry_price, COALESCE(realized_pnl, 0) AS realized_pnl
FROM positions
WHERE broker_account_id = $1 AND quantity > 0
"""

# The row id is selected as well: positions are keyed by the broker account
# UUID, not by the human-readable account_id.
_LOAD_PAPER_ACCOUNT_CONFIG = """
SELECT id, config FROM broker_accounts
WHERE account_id = $1 AND mode = 'paper' AND active = TRUE
LIMIT 1
"""

_LOAD_PAPER_ORDERS = """
SELECT
    id, ticker, side, order_type, quantity, status,
    idempotency_key, broker_order_id, fill_price, fill_quantity,
    submitted_at
FROM orders
WHERE broker_account_id = (
    SELECT id FROM broker_accounts WHERE account_id = $1 AND mode = 'paper' LIMIT 1
)
ORDER BY submitted_at DESC
LIMIT 500
"""


async def sync_state_to_db(
    adapter: PaperTradingAdapter,
    pool: asyncpg.Pool,
    broker_account_uuid: str | None = None,
) -> None:
    """Persist the current paper trading state to PostgreSQL.

    Writes, inside a single transaction:
    - broker_accounts row for the paper account
    - positions rows for every position held by the account object
    - orders rows for all orders (idempotent via ON CONFLICT)
    - order_events for audit trail

    This enables state recovery after restarts and provides the
    full execution audit trail (Requirement 8.3).

    Args:
        adapter: Adapter whose ``account`` is persisted.
        pool: asyncpg pool used to obtain a connection.
        broker_account_uuid: Row id for ``broker_accounts``. When omitted a
            deterministic UUIDv5 is derived from the account id, so repeated
            syncs hit the same row.
    """
    acct = adapter.account
    now = datetime.now(timezone.utc)
    acct_uuid = broker_account_uuid or str(uuid.uuid5(uuid.NAMESPACE_DNS, acct.account_id))

    async with pool.acquire() as conn:
        async with conn.transaction():
            # 1. Upsert broker account
            config_json = json.dumps({
                "initial_cash": acct.initial_cash,
                "current_cash": round(acct.cash, 2),
                "portfolio_value": round(acct.portfolio_value, 2),
                "slippage_pct": adapter.slippage_pct,
            })
            await conn.execute(_UPSERT_PAPER_ACCOUNT, acct_uuid, acct.account_id, config_json)

            # 2. Upsert positions
            for ticker, pos in acct.positions.items():
                await conn.execute(
                    _UPSERT_PAPER_POSITION,
                    acct_uuid, ticker,
                    pos.quantity, pos.avg_entry_price, pos.realized_pnl,
                    now,
                )

            # 3. Insert orders (idempotent)
            for order_id, resp in acct.orders.items():
                # Paper orders fill at submission time, so the fill timestamp
                # is the submission timestamp rather than the moment this
                # sync happens to run; fall back to "now" if it is missing.
                filled_at = None
                if resp.status == OrderStatus.FILLED:
                    filled_at = resp.submitted_at or now
                await conn.execute(
                    _INSERT_PAPER_ORDER,
                    order_id,
                    None,  # recommendation_id
                    acct_uuid,
                    resp.ticker,
                    resp.side.value,
                    "market",  # paper orders are always market-simulated
                    resp.quantity,
                    resp.filled_avg_price,  # limit_price
                    None,  # stop_price
                    resp.status.value,
                    order_id,  # use order_id as idempotency_key fallback
                    order_id,
                    json.dumps(resp.raw_response),
                    resp.submitted_at,
                    filled_at,
                    resp.filled_avg_price,
                    resp.filled_quantity,
                )

            # 4. Insert order events
            for event in acct.order_events:
                await conn.execute(
                    _INSERT_PAPER_ORDER_EVENT,
                    event["order_id"],
                    event["event_type"],
                    json.dumps(event["data"]),
                    datetime.fromisoformat(event["timestamp"]),
                )

    logger.info(
        "Synced paper state to DB: account=%s positions=%d orders=%d events=%d",
        acct.account_id, len(acct.positions), len(acct.orders), len(acct.order_events),
    )

    # Clear events after sync to avoid re-inserting. This runs only after the
    # transaction block has exited without raising.
    acct.order_events.clear()


async def load_state_from_db(
    adapter: PaperTradingAdapter,
    pool: asyncpg.Pool,
) -> bool:
    """Restore paper trading state from PostgreSQL.

    Loads positions and account config from the DB so the adapter
    can resume after a restart. Orders are not restored.

    Returns:
        True if a saved account row was found, False otherwise (in which
        case the adapter is left untouched).
    """
    acct = adapter.account

    async with pool.acquire() as conn:
        # Load account config
        row = await conn.fetchrow(_LOAD_PAPER_ACCOUNT_CONFIG, acct.account_id)
        if row is None:
            logger.info("No saved paper account state for %s", acct.account_id)
            return False

        raw_config = row["config"]
        config = json.loads(raw_config) if isinstance(raw_config, str) else raw_config
        config = config or {}  # tolerate a NULL / empty config column
        acct.cash = float(config.get("current_cash", acct.initial_cash))

        # Load positions. sync_state_to_db stores them under the broker
        # account's UUID (broker_accounts.id); filtering by the readable
        # account_id, as before, can never match a row.
        pos_rows = await conn.fetch(_LOAD_PAPER_POSITIONS, row["id"])
        for pr in pos_rows:
            ticker = pr["ticker"]
            acct.positions[ticker] = PaperPosition(
                ticker=ticker,
                quantity=float(pr["quantity"]),
                avg_entry_price=float(pr["avg_entry_price"] or 0),
                realized_pnl=float(pr["realized_pnl"]),
            )

    logger.info(
        "Loaded paper state from DB: account=%s cash=%.2f positions=%d",
        acct.account_id, acct.cash, len(acct.positions),
    )
    return True
logger = logging.getLogger("resilient_adapter")

# HTTP status codes that are safe to retry
RETRYABLE_STATUS_CODES: frozenset[int] = frozenset({429, 500, 502, 503, 504})


@dataclass
class RetryConfig:
    """Configuration for retry and rate-limit behavior."""

    # Number of retries after the first attempt.
    max_retries: int = 3
    # First backoff delay in seconds; doubled on every further attempt.
    base_delay: float = 1.0
    # Cap applied to the exponential delay (before jitter) and to Retry-After.
    max_delay: float = 60.0
    # Upper bound of the random jitter, as a fraction of the capped delay.
    jitter_factor: float = 0.5
    retryable_status_codes: frozenset[int] = RETRYABLE_STATUS_CODES
    # Rate limit: max requests per window per source type
    rate_limit_max: int = 30
    rate_limit_window_seconds: int = 60


# Sensible defaults per source type
DEFAULT_RETRY_CONFIGS: dict[str, RetryConfig] = {
    "market_api": RetryConfig(max_retries=3, rate_limit_max=30),
    "news_api": RetryConfig(max_retries=3, rate_limit_max=20),
    "filings_api": RetryConfig(max_retries=2, rate_limit_max=10, base_delay=2.0),
    "web_scrape": RetryConfig(max_retries=2, rate_limit_max=10, base_delay=2.0),
    "broker": RetryConfig(max_retries=2, rate_limit_max=60, base_delay=0.5),
}


def compute_delay(attempt: int, config: RetryConfig) -> float:
    """Compute backoff delay with jitter for a given attempt number.

    The exponential delay ``base_delay * 2**attempt`` is capped at
    ``max_delay``; a random jitter of up to ``jitter_factor`` times the
    capped value is then added on top, so the result lies in
    ``[capped, capped * (1 + jitter_factor)]``.
    """
    capped = min(config.base_delay * (2 ** attempt), config.max_delay)
    return capped + capped * config.jitter_factor * random.random()


@dataclass
class RetryStats:
    """Tracks retry statistics for observability."""

    attempts: int = 0
    total_delay: float = 0.0
    rate_limited_waits: int = 0
    last_error: str | None = None
    retryable: bool = False


class ResilientAdapter:
    """Wraps a BaseAdapter with rate-limit coordination, retries, and backoff.

    Usage:
        adapter = PolygonMarketAdapter(api_key="...")
        resilient = ResilientAdapter(adapter, redis=rds)
        result = await resilient.fetch(ticker, config)

    If redis is None, rate limiting is skipped (local dev / testing).
    """

    # Upper bound on consecutive rate-limit waits before one attempt, so a
    # permanently saturated limiter cannot stall a fetch forever.
    _MAX_RATE_LIMIT_WAITS = 3

    def __init__(
        self,
        adapter: BaseAdapter,
        redis: aioredis.Redis | None = None,
        retry_config: RetryConfig | None = None,
    ) -> None:
        """Wrap *adapter*.

        Args:
            adapter: The adapter whose ``fetch`` is being protected.
            redis: Redis client used for the shared rate-limit counters, or
                None to disable rate limiting.
            retry_config: Explicit configuration; defaults to the entry of
                ``DEFAULT_RETRY_CONFIGS`` for the adapter's source type, or
                a plain ``RetryConfig()`` when there is none.
        """
        self._adapter = adapter
        self._redis = redis
        self._config = retry_config or DEFAULT_RETRY_CONFIGS.get(
            adapter.source_type(), RetryConfig()
        )

    @property
    def adapter(self) -> BaseAdapter:
        """Access the underlying adapter."""
        return self._adapter

    @property
    def config(self) -> RetryConfig:
        """The retry / rate-limit configuration in effect."""
        return self._config

    def source_type(self) -> str:
        """Return the wrapped adapter's source type."""
        return self._adapter.source_type()

    async def _check_rate_limit(self) -> float:
        """Check distributed rate limit via Redis.

        Every call increments the counter of the current time bucket, i.e.
        it consumes one slot of the window.

        Returns 0.0 if allowed, or the number of seconds to wait.
        """
        if self._redis is None:
            return 0.0

        window_sec = self._config.rate_limit_window_seconds
        # Use a time-bucketed key so counters auto-expire
        bucket = int(time.time()) // window_sec
        key = rate_limit_key(self._adapter.source_type(), str(bucket))

        try:
            count = await self._redis.incr(key)
            if count == 1:
                await self._redis.expire(key, window_sec * 2)
            if count > self._config.rate_limit_max:
                # Over limit — compute how long until the window rolls over
                elapsed_in_window = time.time() % window_sec
                return max(window_sec - elapsed_in_window, 0.5)
        except Exception:
            # Redis unavailable — degrade gracefully, allow the request
            logger.warning("Redis rate-limit check failed, allowing request", exc_info=True)
        return 0.0

    async def _await_rate_limit(self, ticker: str, stats: RetryStats) -> None:
        """Wait until the rate limiter admits a request.

        After sleeping, the limiter is consulted again so the request is
        counted in the window it actually runs in. Proceeding straight after
        the sleep would let every waiter through uncounted at the start of
        the next window. Gives up after ``_MAX_RATE_LIMIT_WAITS`` sleeps and
        lets the request through (best effort).
        """
        for _ in range(self._MAX_RATE_LIMIT_WAITS):
            wait = await self._check_rate_limit()
            if wait <= 0:
                return
            stats.rate_limited_waits += 1
            logger.info(
                "Rate limited for %s/%s, waiting %.1fs",
                self.source_type(), ticker, wait,
            )
            stats.total_delay += wait
            await asyncio.sleep(wait)

    def _is_retryable(self, result: AdapterResult) -> bool:
        """Determine if a failed result is worth retrying.

        A failure is retried when its HTTP status is in
        ``retryable_status_codes`` or its error text mentions a timeout or
        a connection problem.
        """
        if result.ok:
            return False
        # Retry on known retryable HTTP status codes
        if result.http_status and result.http_status in self._config.retryable_status_codes:
            return True
        if not result.error:
            return False
        message = result.error.lower()
        # Retry on timeouts and connection errors
        return "timeout" in message or any(
            kw in message for kw in ("connection", "connect", "reset", "refused")
        )

    def _extract_retry_after(self, result: AdapterResult) -> float | None:
        """Extract Retry-After hint from result metadata if present.

        Returns None when the hint is missing, unparsable, negative, NaN or
        infinite, so callers fall back to the computed backoff instead of
        handing a meaningless value to ``asyncio.sleep``.
        """
        raw = result.metadata.get("retry_after")
        if raw is None:
            return None
        try:
            seconds = float(raw)
        except (ValueError, TypeError):
            return None
        # NaN fails both comparisons; inf fails the second.
        return seconds if 0.0 <= seconds < float("inf") else None

    @staticmethod
    def _stats_payload(stats: RetryStats, *, exhausted: bool = False) -> dict[str, Any]:
        """Render *stats* as the dict stored under ``metadata["retry_stats"]``."""
        payload: dict[str, Any] = {
            "attempts": stats.attempts,
            "total_delay": round(stats.total_delay, 2),
            "rate_limited_waits": stats.rate_limited_waits,
        }
        if exhausted:
            payload["exhausted"] = True
            payload["last_error"] = stats.last_error
        return payload

    async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult:
        """Fetch with rate-limit coordination, retries, and exponential backoff.

        Returns the AdapterResult from the underlying adapter. On retryable
        failures, retries up to max_retries times with exponential backoff
        and jitter. Rate-limit waits are applied before each attempt.

        The returned result's metadata includes retry stats under the
        "retry_stats" key. For a failed result (retries exhausted or error
        not retryable) the stats also carry ``"exhausted": True`` and
        ``"last_error"``.
        """
        stats = RetryStats()
        max_retries = max(self._config.max_retries, 0)  # always try at least once
        result: AdapterResult | None = None

        for attempt in range(max_retries + 1):
            stats.attempts = attempt + 1

            await self._await_rate_limit(ticker, stats)

            # Execute the fetch
            result = await self._adapter.fetch(ticker, config)

            # Success — attach stats and return
            if result.ok:
                result.metadata["retry_stats"] = self._stats_payload(stats)
                return result

            stats.last_error = result.error
            stats.retryable = self._is_retryable(result)
            if not stats.retryable:
                break

            # Don't sleep after the last attempt
            if attempt >= max_retries:
                break

            # Respect Retry-After header for 429s
            retry_after = self._extract_retry_after(result)
            if result.http_status == 429 and retry_after is not None:
                delay = min(retry_after, self._config.max_delay)
            else:
                delay = compute_delay(attempt, self._config)

            logger.info(
                "Retrying %s/%s (attempt %d/%d) after %.1fs: %s",
                self.source_type(), ticker, attempt + 1,
                max_retries + 1, delay, result.error,
            )
            stats.total_delay += delay
            await asyncio.sleep(delay)

        # All retries exhausted (or error not retryable) — return last result
        # with stats. The loop body runs at least once, so result is set.
        result.metadata["retry_stats"] = self._stats_payload(stats, exhausted=True)
        return result
logger = logging.getLogger("web_scrape_adapter")

# Default request settings
DEFAULT_TIMEOUT = 30
DEFAULT_USER_AGENT = "StonksOracle/1.0 (+https://stonks-oracle.celestium.life)"
MAX_CONTENT_LENGTH = 10 * 1024 * 1024  # 10MB cap


def _meta_content(soup: BeautifulSoup, **attrs: str) -> str | None:
    """Return the stripped ``content`` of the first ``<meta>`` tag matching *attrs*.

    Returns None when there is no such tag or its ``content`` is missing or
    empty, and ``""`` when the attribute value is not a plain string.
    """
    tag = soup.find("meta", attrs=attrs)
    if tag is None:
        return None
    content = tag.get("content")
    if not content:
        return None
    return content.strip() if isinstance(content, str) else ""


def _json_ld_date_published(soup: BeautifulSoup) -> str | None:
    """Return the first ``datePublished`` found in the page's JSON-LD scripts.

    Both a single JSON-LD object and a top-level list of objects are
    understood; scripts that are not valid JSON are skipped.
    """
    for script in soup.find_all("script", type="application/ld+json"):
        raw = script.string
        if not raw or "datePublished" not in raw:
            continue
        try:
            ld = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            continue
        for node in ld if isinstance(ld, list) else [ld]:
            if isinstance(node, dict) and "datePublished" in node:
                # Return immediately: the first match wins, later scripts
                # must not overwrite it.
                return str(node["datePublished"])
    return None


def extract_metadata_from_html(html: str, url: str) -> dict[str, str | None]:
    """Extract title, author, publisher, published date, and links from HTML.

    The returned dict always has the keys ``title``, ``author``,
    ``publisher``, ``published_at``, ``canonical_url``, ``language`` and
    ``description``. Only ``published_at`` may be None; missing text fields
    are empty strings.
    """
    soup = BeautifulSoup(html, "html.parser")
    meta: dict[str, str | None] = {}

    # Title: prefer og:title, then <title>
    title = _meta_content(soup, property="og:title")
    if title is None and soup.title and soup.title.string:
        title = soup.title.string.strip()
    meta["title"] = title or ""

    # Author
    meta["author"] = _meta_content(soup, name="author") or ""

    # Publisher: og:site_name, falling back to the URL's host name
    publisher = _meta_content(soup, property="og:site_name")
    if publisher is None:
        publisher = urlparse(url).hostname or ""
    meta["publisher"] = publisher

    # Published date: article:published_time, then JSON-LD datePublished
    published = _meta_content(soup, property="article:published_time")
    if published is None:
        published = _json_ld_date_published(soup)
    meta["published_at"] = published or None

    # Canonical URL: <link rel="canonical">, then og:url, then the request URL
    canonical = soup.find("link", rel="canonical")
    og_url = soup.find("meta", property="og:url")
    if canonical and canonical.get("href"):
        meta["canonical_url"] = str(canonical["href"])
    elif og_url and og_url.get("content"):
        meta["canonical_url"] = str(og_url["content"])
    else:
        meta["canonical_url"] = normalize_url(url)

    # Language: first five characters of <html lang="...">, default "en"
    html_tag = soup.find("html")
    lang = html_tag.get("lang") if html_tag else None
    meta["language"] = str(lang)[:5] if lang else "en"

    # Description for summary: og:description, then the plain description tag
    description = _meta_content(soup, property="og:description")
    if description is None:
        description = _meta_content(soup, name="description")
    meta["description"] = description or ""

    return meta


def extract_body_text(html: str) -> str:
    """Extract main body text from HTML, stripping nav/footer/ads.

    Looks for an ``<article>`` element, then for a ``<div>`` whose class
    contains one of the usual article-body markers, and finally falls back
    to ``<body>`` (or the whole document). Blank lines are dropped.
    """
    soup = BeautifulSoup(html, "html.parser")

    # Remove non-content elements
    for tag in soup.find_all(
        ["script", "style", "nav", "footer", "header", "aside", "iframe", "noscript"]
    ):
        tag.decompose()

    body_markers = ("article-body", "post-content", "entry-content", "story-body")

    # Try to find article body
    container = soup.find("article")
    if not container:
        for div in soup.find_all("div"):
            cls = div.get("class", [])
            cls_str = " ".join(cls) if isinstance(cls, list) else str(cls) if cls else ""
            if any(marker in cls_str for marker in body_markers):
                container = div
                break

    if not container:
        # Fallback: use body, or the whole document when there is none
        container = soup.find("body") or soup

    text = container.get_text(separator="\n", strip=True)

    # Collapse whitespace
    return "\n".join(line.strip() for line in text.splitlines() if line.strip())


class WebScrapeAdapter(BaseAdapter):
    """Adapter for fetching curated web pages and article URLs.

    Config options (from source config):
        urls: List of URLs to scrape for this company
        url: Single URL to scrape (alternative to urls)
        timeout: Request timeout in seconds (default 30)
        user_agent: Custom user agent string
        follow_links: Whether to follow article links from index pages (default False)
        max_pages: Max pages to fetch per cycle (default 5, clamped to 1..20)
    """

    def __init__(self) -> None:
        pass

    def source_type(self) -> str:
        """Return the source-type label of this adapter (``"web_scrape"``)."""
        return "web_scrape"

    def bucket_name(self) -> str:
        """Web scrape artifacts go to the news raw bucket."""
        return "stonks-raw-news"

    async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult:
        """Fetch HTML from curated URLs for a given ticker.

        Supports both single URL and multi-URL configs. Each URL is fetched
        and its metadata and body text are extracted into one item. Pages
        that fail are reported in ``metadata["errors"]``.

        When no page could be fetched an error result is returned; it
        carries the last HTTP error status seen (if any) so that retry
        wrappers can recognise retryable statuses such as 429 or 503.
        """
        urls = config.get("urls") or []
        if isinstance(urls, str):
            # A bare string would otherwise be iterated character by character.
            urls = [urls]
        if not urls and config.get("url"):
            urls = [config["url"]]

        if not urls:
            return self._error_result(ticker, "No URLs configured for web_scrape source", 0)

        timeout = config.get("timeout", DEFAULT_TIMEOUT)
        user_agent = config.get("user_agent", DEFAULT_USER_AGENT)
        # Clamp to 1..20: zero or a negative number would turn the slice
        # below into "nothing" or "all but the last N".
        max_pages = max(1, min(config.get("max_pages", 5), 20))

        items: list[dict[str, Any]] = []
        errors: list[str] = []
        total_elapsed = 0.0
        last_http_status: int | None = None

        async with httpx.AsyncClient(
            timeout=timeout,
            follow_redirects=True,
            headers={"User-Agent": user_agent},
        ) as client:
            for url in urls[:max_pages]:
                t0 = time.monotonic()
                elapsed_ms: float | None = None
                try:
                    resp = await client.get(url)
                    elapsed_ms = (time.monotonic() - t0) * 1000
                    resp.raise_for_status()

                    # Content length guard (applied after download: it keeps
                    # oversized pages out of parsing, not off the wire)
                    raw_bytes = resp.content
                    if len(raw_bytes) > MAX_CONTENT_LENGTH:
                        errors.append(f"Content too large for {url}: {len(raw_bytes)} bytes")
                        continue

                    html = resp.text
                    meta = extract_metadata_from_html(html, url)
                    body_text = extract_body_text(html)

                    items.append({
                        "url": url,
                        "canonical_url": meta.get("canonical_url", normalize_url(url)),
                        "title": meta.get("title", ""),
                        "author": meta.get("author", ""),
                        "publisher": meta.get("publisher", ""),
                        "published_at": meta.get("published_at"),
                        "language": meta.get("language", "en"),
                        "description": meta.get("description", ""),
                        "content_hash": content_hash(raw_bytes),
                        "body_text": body_text,
                        "body_length": len(body_text),
                        "html_length": len(html),
                        "http_status": resp.status_code,
                        "response_time_ms": round(elapsed_ms, 1),
                    })

                except httpx.HTTPStatusError as e:
                    last_http_status = e.response.status_code if e.response else None
                    errors.append(f"HTTP {last_http_status} for {url}: {e}")
                    logger.warning("Scrape HTTP error for %s/%s: %s", ticker, url, e)

                except httpx.TimeoutException as e:
                    errors.append(f"Timeout for {url}: {e}")
                    logger.warning("Scrape timeout for %s/%s: %s", ticker, url, e)

                except Exception as e:
                    errors.append(f"Error for {url}: {e}")
                    logger.warning("Scrape error for %s/%s: %s", ticker, url, e)

                finally:
                    # Account for each request exactly once, whether it
                    # succeeded, failed after the response arrived, or never
                    # got a response at all.
                    if elapsed_ms is None:
                        elapsed_ms = (time.monotonic() - t0) * 1000
                    total_elapsed += elapsed_ms

        if not items:
            error_msg = "; ".join(errors) if errors else "No pages fetched"
            return self._error_result(
                ticker, error_msg, total_elapsed, http_status=last_http_status
            )

        fetched_at = datetime.now(timezone.utc)

        # The raw artifact is a JSON manifest of the fetched pages; the page
        # text itself travels in ``items``.
        combined_raw = json.dumps({
            "ticker": ticker,
            "fetched_at": fetched_at.isoformat(),
            "pages": [
                {
                    "url": item["url"],
                    "content_hash": item["content_hash"],
                    "html_length": item["html_length"],
                    "body_length": item["body_length"],
                }
                for item in items
            ],
            "errors": errors,
        }).encode("utf-8")

        # Hash of the concatenated per-page hashes
        combined_hash = content_hash(
            b"".join(item["content_hash"].encode() for item in items)
        )

        return AdapterResult(
            source_type="web_scrape",
            ticker=ticker,
            items=items,
            raw_payload=combined_raw,
            content_hash=combined_hash,
            fetched_at=fetched_at,
            http_status=200,
            response_time_ms=round(total_elapsed, 1),
            metadata={
                "provider": "web_scrape",
                "pages_fetched": len(items),
                "pages_failed": len(errors),
                "errors": errors,
            },
        )

    def _error_result(
        self,
        ticker: str,
        error: str,
        elapsed_ms: float,
        http_status: int | None = None,
    ) -> AdapterResult:
        """Build an error AdapterResult for scrape fetches.

        Args:
            ticker: Ticker the fetch was for.
            error: Human-readable error description.
            elapsed_ms: Total time spent on requests, in milliseconds.
            http_status: HTTP status of the failing response, when known.
        """
        return AdapterResult(
            source_type="web_scrape",
            ticker=ticker,
            items=[],
            raw_payload=b"",
            content_hash="",
            fetched_at=datetime.now(timezone.utc),
            error=error,
            http_status=http_status,
            response_time_ms=round(elapsed_ms, 1),
            metadata={"provider": "web_scrape"},
        )
@dataclass
class CatalystEntry:
    """Lightweight carrier for per-document catalyst info needed by
    contradiction detection. Avoids importing ImpactRow and creating
    a circular dependency with worker.py."""

    document_id: str
    catalyst_type: str


@dataclass
class ContradictionResult:
    """Full contradiction analysis output."""

    score: float  # 0-1, same semantics as existing compute_contradiction_score
    details: list[DisagreementDetail]


def detect_contradictions(
    signals: list[WeightedSignal],
    catalyst_entries: list[CatalystEntry] | None = None,
) -> ContradictionResult:
    """Run contradiction detection across multiple dimensions.

    Analyses:
    1. Sentiment disagreement — the core positive-vs-negative split
    2. Catalyst disagreement — same catalyst type with opposing sentiment
       (only when *catalyst_entries* is given and non-empty)

    Returns a ContradictionResult with an overall score and per-dimension
    disagreement details.
    """
    details: list[DisagreementDetail] = []

    sentiment_detail = _detect_sentiment_disagreement(signals)
    if sentiment_detail is not None:
        details.append(sentiment_detail)

    if catalyst_entries:
        details.extend(_detect_catalyst_disagreement(signals, catalyst_entries))

    return ContradictionResult(score=_compute_overall_score(signals), details=details)


def _signal_weight(sig: WeightedSignal) -> float:
    """Effective weight of a signal: combined weight times impact score."""
    return sig.weight.combined * sig.impact_score


def _split_by_sentiment(
    signals: list[WeightedSignal],
) -> tuple[list[str], list[str], float, float]:
    """Partition signals into positive and negative camps.

    Neutral signals and signals whose effective weight is not strictly
    positive are ignored.

    Returns:
        ``(positive_doc_ids, negative_doc_ids, positive_weight, negative_weight)``.
    """
    pos_ids: list[str] = []
    neg_ids: list[str] = []
    pos_weight = 0.0
    neg_weight = 0.0

    for sig in signals:
        w = _signal_weight(sig)
        if w <= 0:
            continue
        if sig.sentiment_value > 0:
            pos_ids.append(sig.document_id)
            pos_weight += w
        elif sig.sentiment_value < 0:
            neg_ids.append(sig.document_id)
            neg_weight += w

    return pos_ids, neg_ids, pos_weight, neg_weight


def _compute_overall_score(signals: list[WeightedSignal]) -> float:
    """Minority/majority weight ratio — backward-compatible formula.

    Returns ``min(pos, neg) / (pos + neg)`` rounded to four decimals, i.e.
    0.0 for unanimous (or no) evidence and 0.5 for an even split. Signals
    with a non-positive effective weight are ignored, exactly as in the
    sentiment-disagreement detail, so the score cannot leave the 0-1 range.
    """
    _, _, pos_weight, neg_weight = _split_by_sentiment(signals)

    total = pos_weight + neg_weight
    if total == 0.0:
        return 0.0
    return round(min(pos_weight, neg_weight) / total, 4)


def _detect_sentiment_disagreement(
    signals: list[WeightedSignal],
) -> DisagreementDetail | None:
    """Detect when both positive and negative sentiment signals exist.

    Returns None unless at least one positively and one negatively
    weighted signal are present.
    """
    pos_ids, neg_ids, pos_weight, neg_weight = _split_by_sentiment(signals)
    if not pos_ids or not neg_ids:
        return None

    total = pos_weight + neg_weight
    minority_pct = min(pos_weight, neg_weight) / total if total > 0 else 0.0

    return DisagreementDetail(
        dimension="sentiment",
        positive_doc_ids=pos_ids,
        negative_doc_ids=neg_ids,
        positive_weight=round(pos_weight, 4),
        negative_weight=round(neg_weight, 4),
        description=(
            f"Sentiment split: {len(pos_ids)} positive vs {len(neg_ids)} negative signals "
            f"(minority weight ratio {minority_pct:.0%})"
        ),
    )


def _detect_catalyst_disagreement(
    signals: list[WeightedSignal],
    catalyst_entries: list[CatalystEntry],
) -> list[DisagreementDetail]:
    """Detect when the same catalyst type has both positive and negative signals.

    Returns one DisagreementDetail per catalyst type that has documents on
    both sides, with ``dimension`` set to ``"catalyst:<type>"``.
    """
    # Build lookup: document_id → (sentiment_value, combined_weight).
    # Signals without a strictly positive weight carry no evidence.
    sig_lookup: dict[str, tuple[float, float]] = {}
    for sig in signals:
        w = _signal_weight(sig)
        if w > 0:
            sig_lookup[sig.document_id] = (sig.sentiment_value, w)

    # Group by catalyst type, dropping neutral documents
    catalyst_groups: dict[str, list[tuple[str, float, float]]] = {}
    for entry in catalyst_entries:
        found = sig_lookup.get(entry.document_id)
        if found is None:
            continue
        sent_val, weight = found
        if sent_val != 0.0:
            catalyst_groups.setdefault(entry.catalyst_type, []).append(
                (entry.document_id, sent_val, weight)
            )

    details: list[DisagreementDetail] = []
    for catalyst, entries in catalyst_groups.items():
        positives = [(doc_id, w) for doc_id, sv, w in entries if sv > 0]
        negatives = [(doc_id, w) for doc_id, sv, w in entries if sv < 0]
        if not positives or not negatives:
            continue

        details.append(DisagreementDetail(
            dimension=f"catalyst:{catalyst}",
            positive_doc_ids=[doc_id for doc_id, _ in positives],
            negative_doc_ids=[doc_id for doc_id, _ in negatives],
            positive_weight=round(sum(w for _, w in positives), 4),
            negative_weight=round(sum(w for _, w in negatives), 4),
            description=(
                f"Catalyst '{catalyst}' has {len(positives)} positive and "
                f"{len(negatives)} negative signals"
            ),
        ))

    return details
+ details.append(DisagreementDetail( + dimension=f"catalyst:{catalyst}", + positive_doc_ids=pos_ids, + negative_doc_ids=neg_ids, + positive_weight=round(pos_w, 4), + negative_weight=round(neg_w, 4), + description=( + f"Catalyst '{catalyst}' has {len(pos_ids)} positive and " + f"{len(neg_ids)} negative signals" + ), + )) + + return details diff --git a/services/aggregation/evidence.py b/services/aggregation/evidence.py new file mode 100644 index 0000000..0f16f5f --- /dev/null +++ b/services/aggregation/evidence.py @@ -0,0 +1,141 @@ +"""Evidence ranking for supporting and opposing documents. + +Ranks document signals by a composite score that considers multiple +factors beyond raw weight, producing explainable evidence lists for +trend summaries. + +Requirements: 6.5 +""" +from __future__ import annotations + +from dataclasses import dataclass + +from services.aggregation.scoring import WeightedSignal + + +@dataclass(frozen=True) +class EvidenceRankConfig: + """Weights for the composite evidence ranking score.""" + + # How much the combined signal weight matters (recency * credibility * novelty * market) + weight_factor: float = 0.40 + # How much the document's impact score matters + impact_factor: float = 0.30 + # How much recency alone matters (favours fresh evidence in the ranking) + recency_factor: float = 0.20 + # How much extraction confidence matters + confidence_factor: float = 0.10 + # Maximum evidence refs per side (supporting / opposing) + max_refs: int = 10 + + +DEFAULT_RANK_CONFIG = EvidenceRankConfig() + + +@dataclass +class RankedEvidence: + """A document with its composite ranking score and breakdown.""" + + document_id: str + rank_score: float + weight_component: float + impact_component: float + recency_component: float + confidence_component: float + sentiment_value: float # +1 / -1 / 0 + + +def compute_evidence_rank( + signal: WeightedSignal, + config: EvidenceRankConfig = DEFAULT_RANK_CONFIG, +) -> RankedEvidence: + """Compute a composite ranking 
score for a single signal. + + The score blends: + - combined signal weight (captures recency decay, credibility, novelty, market ctx) + - raw impact score + - recency weight alone (extra boost for freshness in the ranking) + - extraction confidence (via the credibility component of the weight) + + All components are in [0, 1] so the composite is bounded by the sum + of the factor weights. + """ + w = signal.weight + + weight_component = w.combined * config.weight_factor + impact_component = signal.impact_score * config.impact_factor + recency_component = w.recency * config.recency_factor + confidence_component = w.credibility * config.confidence_factor + + rank_score = weight_component + impact_component + recency_component + confidence_component + + return RankedEvidence( + document_id=signal.document_id, + rank_score=round(rank_score, 6), + weight_component=round(weight_component, 6), + impact_component=round(impact_component, 6), + recency_component=round(recency_component, 6), + confidence_component=round(confidence_component, 6), + sentiment_value=signal.sentiment_value, + ) + + +def rank_evidence( + signals: list[WeightedSignal], + config: EvidenceRankConfig = DEFAULT_RANK_CONFIG, +) -> tuple[list[str], list[str]]: + """Rank signals into top supporting and opposing document ID lists. + + Supporting = positive sentiment, Opposing = negative sentiment. + Neutral/mixed signals are excluded. + + Returns (supporting_ids, opposing_ids) each capped at config.max_refs. 
async def main() -> None:
    """Run the aggregation worker: poll the Redis queue and aggregate each job.

    Each queue entry is expected to be a JSON object with a ``ticker`` key.
    Entries that are not valid JSON, are not JSON objects, or carry no ticker
    are logged and discarded so that one bad message cannot stop the worker.
    A failure inside ``aggregate_company`` is logged and the loop continues.

    The PostgreSQL pool and the Redis client are closed when the loop exits,
    including when it exits through an exception or cancellation.
    """
    config = load_config()
    setup_logging("aggregation", level=config.log_level, json_output=config.json_logs)

    pool = await asyncpg.create_pool(dsn=config.postgres.dsn, min_size=2, max_size=8)
    try:
        import redis.asyncio as aioredis

        redis_client = aioredis.from_url(config.redis.url)
        try:
            queue = queue_key(QUEUE_AGGREGATION)
            logger.info("Aggregation worker started, polling %s", queue)

            while True:
                raw = await redis_client.lpop(queue)
                if raw is None:
                    # Queue is empty: back off briefly instead of spinning.
                    await asyncio.sleep(1)
                    continue

                # Parse defensively; JSONDecodeError and UnicodeDecodeError
                # are both ValueError subclasses.
                try:
                    job = json.loads(raw)
                except (TypeError, ValueError):
                    logger.exception("Discarding malformed aggregation job: %r", raw)
                    continue

                ticker = job.get("ticker", "") if isinstance(job, dict) else ""
                if not isinstance(ticker, str) or not ticker:
                    logger.error("Discarding aggregation job without a ticker: %r", raw)
                    continue

                logger.info("Processing aggregation job for %s", ticker)

                try:
                    summaries = await aggregate_company(pool, ticker)
                    logger.info(
                        "Aggregation complete for %s: %d windows",
                        ticker, len(summaries),
                    )
                except Exception:
                    # Top-level worker boundary: log and keep serving the queue.
                    logger.exception("Aggregation failed for %s", ticker)
        finally:
            await redis_client.close()
    finally:
        # Runs even if creating or closing the Redis client fails.
        await pool.close()
+WINDOW_LOOKBACK_DAYS: dict[str, int] = { + TrendWindow.INTRADAY.value: 1, + TrendWindow.ONE_DAY.value: 2, + TrendWindow.SEVEN_DAY.value: 8, + TrendWindow.THIRTY_DAY.value: 35, + TrendWindow.NINETY_DAY.value: 95, +} + + +async def fetch_market_context( + pool: asyncpg.Pool, + ticker: str, + window: str, + reference_time: datetime | None = None, +) -> MarketContext: + """Build a MarketContext for *ticker* over the given trend *window*. + + Queries the ``market_snapshots`` table for recent bars and computes: + - price_change_pct: (last_close - first_close) / first_close + - avg_volume: mean volume across bars + - volume_change_pct: second-half avg volume vs first-half avg volume + - volatility: std-dev of close prices + - latest_close / latest_bar_at + + Returns a MarketContext with ``bars_available == 0`` when no data exists. + """ + if reference_time is None: + reference_time = datetime.now(timezone.utc) + + lookback_days = WINDOW_LOOKBACK_DAYS.get(window, 8) + start = reference_time - timedelta(days=lookback_days) + + rows = await pool.fetch( + """ + SELECT data, captured_at + FROM market_snapshots + WHERE ticker = $1 + AND captured_at >= $2 + AND captured_at <= $3 + ORDER BY captured_at ASC + """, + ticker, + start, + reference_time, + ) + + if not rows: + return MarketContext(ticker=ticker) + + bars = _extract_bars(rows) + if not bars: + return MarketContext(ticker=ticker) + + return _compute_context(ticker, bars) + + +def _extract_bars(rows: list[Any]) -> list[dict[str, Any]]: + """Extract OHLCV bar dicts from market_snapshot rows. + + The ``data`` column is JSONB. Polygon prev-day bars store fields like + ``o``, ``h``, ``l``, ``c``, ``v``, ``t``. We normalise to a common + dict with ``close``, ``volume``, ``captured_at``. 
+ """ + bars: list[dict[str, Any]] = [] + for row in rows: + data = row["data"] + if isinstance(data, str): + import json + data = json.loads(data) + + # Polygon-style single bar or list of bars + items = data if isinstance(data, list) else [data] + for item in items: + close = item.get("c") or item.get("close") + volume = item.get("v") or item.get("volume") + if close is not None: + bars.append({ + "close": float(close), + "volume": float(volume) if volume is not None else 0.0, + "captured_at": row["captured_at"], + }) + return bars + + +def _compute_context(ticker: str, bars: list[dict[str, Any]]) -> MarketContext: + """Derive market context features from a sorted list of bar dicts.""" + closes = [b["close"] for b in bars] + volumes = [b["volume"] for b in bars] + + first_close = closes[0] + last_close = closes[-1] + + price_change_pct = ( + ((last_close - first_close) / first_close * 100.0) + if first_close != 0 + else 0.0 + ) + + avg_volume = sum(volumes) / len(volumes) if volumes else 0.0 + + # Volume trend: compare second half to first half + mid = len(volumes) // 2 + if mid > 0: + first_half_avg = sum(volumes[:mid]) / mid + second_half_avg = sum(volumes[mid:]) / len(volumes[mid:]) + volume_change_pct = ( + ((second_half_avg - first_half_avg) / first_half_avg * 100.0) + if first_half_avg > 0 + else 0.0 + ) + else: + volume_change_pct = 0.0 + + # Volatility: std dev of closes + if len(closes) > 1: + mean_close = sum(closes) / len(closes) + variance = sum((c - mean_close) ** 2 for c in closes) / len(closes) + volatility = math.sqrt(variance) + else: + volatility = 0.0 + + return MarketContext( + ticker=ticker, + price_change_pct=round(price_change_pct, 4), + avg_volume=round(avg_volume, 2), + volume_change_pct=round(volume_change_pct, 4), + volatility=round(volatility, 6), + latest_close=last_close, + latest_bar_at=bars[-1]["captured_at"], + bars_available=len(bars), + ) diff --git a/services/aggregation/rollups.py b/services/aggregation/rollups.py new file 
mode 100644 index 0000000..b983244 --- /dev/null +++ b/services/aggregation/rollups.py @@ -0,0 +1,439 @@ +"""Sector and market-level rollup aggregation. + +Aggregates company-level trend summaries into sector and market-level +summaries, enabling top-down views of sentiment and risk across the +portfolio. + +Requirements: 6.3, 6.4, 6.5 +""" +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone + +import asyncpg + +from services.shared.schemas import ( + DisagreementDetail, + TrendDirection, + TrendSummary, + TrendWindow, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class CompanyTrendRow: + """A company-level trend summary fetched from the DB for rollup.""" + + entity_id: str # ticker + sector: str + window: str + trend_direction: str + trend_strength: float + confidence: float + contradiction_score: float + dominant_catalysts: list[str] + material_risks: list[str] + top_supporting_evidence: list[str] + top_opposing_evidence: list[str] + + +# --------------------------------------------------------------------------- +# Fetch latest company trends for a given window +# --------------------------------------------------------------------------- + +_LATEST_COMPANY_TRENDS_QUERY = """ +SELECT DISTINCT ON (tw.entity_id) + tw.entity_id, + c.sector, + tw.window, + tw.trend_direction, + tw.trend_strength, + tw.confidence, + tw.contradiction_score, + tw.dominant_catalysts, + tw.material_risks, + tw.top_supporting_evidence, + tw.top_opposing_evidence +FROM trend_windows tw +JOIN companies c ON c.ticker = tw.entity_id AND c.active = TRUE +WHERE tw.entity_type = 'company' + AND tw.window = $1 + AND tw.generated_at >= $2 +ORDER BY tw.entity_id, tw.generated_at DESC +""" + + +def _parse_jsonb_list(val: object) -> list[str]: + """Safely parse a JSONB column that should be a list of strings.""" + if isinstance(val, list): + return [str(v) for v in val] + if 
isinstance(val, str): + parsed = json.loads(val) + if isinstance(parsed, list): + return [str(v) for v in parsed] + return [] + + +def _parse_company_trend_row(row: object) -> CompanyTrendRow: + """Convert an asyncpg Record to a CompanyTrendRow.""" + # asyncpg Records support dict() but aren't typed; use getattr-style access + get = getattr(row, "__getitem__", None) + if get is None: + raise TypeError(f"Expected a mapping-like row, got {type(row)}") + + def _str(key: str, default: str = "") -> str: + val = get(key) + return str(val) if val is not None else default + + def _float(key: str) -> float: + val = get(key) + return float(val) if val is not None else 0.0 + + return CompanyTrendRow( + entity_id=_str("entity_id"), + sector=_str("sector", "Unknown") or "Unknown", + window=_str("window"), + trend_direction=_str("trend_direction"), + trend_strength=_float("trend_strength"), + confidence=_float("confidence"), + contradiction_score=_float("contradiction_score"), + dominant_catalysts=_parse_jsonb_list(get("dominant_catalysts")), + material_risks=_parse_jsonb_list(get("material_risks")), + top_supporting_evidence=_parse_jsonb_list(get("top_supporting_evidence")), + top_opposing_evidence=_parse_jsonb_list(get("top_opposing_evidence")), + ) + + +async def fetch_latest_company_trends( + pool: asyncpg.Pool, + window: str, + since: datetime, +) -> list[CompanyTrendRow]: + """Fetch the most recent company-level trend for each ticker in a window.""" + rows = await pool.fetch(_LATEST_COMPANY_TRENDS_QUERY, window, since) + return [_parse_company_trend_row(r) for r in rows] + + +# --------------------------------------------------------------------------- +# Pure rollup logic +# --------------------------------------------------------------------------- + +# Direction mapping for numeric aggregation +_DIRECTION_VALUES = { + TrendDirection.BULLISH.value: 1.0, + TrendDirection.BEARISH.value: -1.0, + TrendDirection.MIXED.value: 0.0, + TrendDirection.NEUTRAL.value: 0.0, +} + 
+BULLISH_THRESHOLD = 0.15 +BEARISH_THRESHOLD = -0.15 + + +def rollup_trends( + trends: list[CompanyTrendRow], + entity_type: str, + entity_id: str, + window: str, + reference_time: datetime, +) -> TrendSummary: + """Aggregate a list of company-level trends into a single rollup summary. + + Each company trend is weighted by its confidence to produce a + confidence-weighted average of direction, strength, and contradiction. + """ + if not trends: + return TrendSummary( + entity_type=entity_type, + entity_id=entity_id, + window=TrendWindow(window), + trend_direction=TrendDirection.NEUTRAL, + trend_strength=0.0, + confidence=0.0, + generated_at=reference_time, + ) + + total_weight = 0.0 + weighted_direction = 0.0 + weighted_strength = 0.0 + weighted_contradiction = 0.0 + catalyst_weights: dict[str, float] = {} + risk_set: dict[str, float] = {} + all_supporting: list[str] = [] + all_opposing: list[str] = [] + + for t in trends: + w = t.confidence + total_weight += w + dir_val = _DIRECTION_VALUES.get(t.trend_direction, 0.0) + weighted_direction += w * dir_val + weighted_strength += w * t.trend_strength + weighted_contradiction += w * t.contradiction_score + + for cat in t.dominant_catalysts: + catalyst_weights[cat] = catalyst_weights.get(cat, 0.0) + w + + for risk in t.material_risks: + norm = risk.strip().lower() + if norm not in risk_set: + risk_set[norm] = w + else: + risk_set[norm] = max(risk_set[norm], w) + + all_supporting.extend(t.top_supporting_evidence) + all_opposing.extend(t.top_opposing_evidence) + + if total_weight == 0.0: + return TrendSummary( + entity_type=entity_type, + entity_id=entity_id, + window=TrendWindow(window), + trend_direction=TrendDirection.NEUTRAL, + trend_strength=0.0, + confidence=0.0, + generated_at=reference_time, + ) + + avg_direction = weighted_direction / total_weight + avg_strength = weighted_strength / total_weight + avg_contradiction = weighted_contradiction / total_weight + avg_confidence = total_weight / len(trends) + + # Derive 
def _derive_rollup_direction(
    avg_direction: float,
    avg_contradiction: float,
) -> TrendDirection:
    """Classify a confidence-weighted average direction as a trend label.

    The "mixed" check takes precedence: when contradiction exceeds 0.10 and
    the average lean is weaker than 0.3 in magnitude, the result is MIXED
    even if the lean would otherwise clear a bullish or bearish threshold.
    Otherwise the lean is compared against ``BULLISH_THRESHOLD`` and
    ``BEARISH_THRESHOLD``, falling back to NEUTRAL.
    """
    contested = avg_contradiction > 0.10
    weak_lean = abs(avg_direction) < 0.3

    if contested and weak_lean:
        verdict = TrendDirection.MIXED
    elif avg_direction >= BULLISH_THRESHOLD:
        verdict = TrendDirection.BULLISH
    elif avg_direction <= BEARISH_THRESHOLD:
        verdict = TrendDirection.BEARISH
    else:
        verdict = TrendDirection.NEUTRAL
    return verdict
async def persist_rollup(
    pool: asyncpg.Pool,
    summary: TrendSummary,
) -> str:
    """Insert one rollup summary into ``trend_windows`` and return its id.

    List-valued fields are serialised to JSON text to match the ``::jsonb``
    casts in the statement.  The ``market_context`` column always receives an
    empty JSON object from this function.

    Args:
        pool: Connection pool used to run the insert.
        summary: The rollup to store.

    Returns:
        The ``id`` returned by the insert, converted to ``str``.
    """
    disagreement_payload = [detail.model_dump() for detail in summary.disagreement_details]

    # Positional parameters $1..$14, in statement order.
    params = (
        summary.entity_type,
        summary.entity_id,
        summary.window.value,
        summary.trend_direction.value,
        summary.trend_strength,
        summary.confidence,
        json.dumps(summary.top_supporting_evidence),
        json.dumps(summary.top_opposing_evidence),
        json.dumps(summary.dominant_catalysts),
        json.dumps(summary.material_risks),
        summary.contradiction_score,
        json.dumps(disagreement_payload),
        json.dumps({}),
        summary.generated_at,
    )

    inserted = await pool.fetchrow(_UPSERT_TREND, *params)
    return str(inserted["id"])  # type: ignore[index]
async def aggregate_sector(
    pool: asyncpg.Pool,
    sector: str,
    window: str,
    reference_time: datetime | None = None,
    since: datetime | None = None,
) -> TrendSummary:
    """Build the rollup for one sector and window, persisting it when non-empty.

    The latest company trends for *window* are fetched, narrowed to companies
    whose sector equals *sector*, and rolled up.  The summary is written to
    the database only when at least one company matched; an empty sector
    still returns the summary produced by ``rollup_trends``.

    Args:
        pool: Connection pool for reads and the optional insert.
        sector: Sector name to match exactly.
        window: Trend window identifier.
        reference_time: Anchor time; defaults to the current UTC time.
        since: Oldest ``generated_at`` to consider; defaults to
            ``reference_time`` minus the window's lookback.

    Returns:
        The sector-level summary.
    """
    anchor = datetime.now(timezone.utc) if reference_time is None else reference_time
    cutoff = (anchor - _window_lookback(window)) if since is None else since

    members: list[CompanyTrendRow] = []
    for trend in await fetch_latest_company_trends(pool, window, cutoff):
        if trend.sector == sector:
            members.append(trend)

    summary = rollup_trends(members, "sector", sector, window, anchor)
    if not members:
        return summary

    rollup_id = await persist_rollup(pool, summary)
    logger.info(
        "Persisted sector rollup %s for %s/%s: direction=%s strength=%.3f companies=%d",
        rollup_id, sector, window, summary.trend_direction.value,
        summary.trend_strength, len(members),
    )
    return summary
async def aggregate_all_sectors(
    pool: asyncpg.Pool,
    window: str,
    reference_time: datetime | None = None,
    since: datetime | None = None,
) -> list[TrendSummary]:
    """Compute and persist a rollup for every sector that has company trends.

    Company trends are fetched once and grouped by sector, so each group is
    non-empty by construction and every computed rollup is persisted.  Each
    insert is logged in the same format as ``aggregate_sector``.

    Args:
        pool: Connection pool for the read and the inserts.
        window: Trend window identifier.
        reference_time: Anchor time; defaults to the current UTC time.
        since: Oldest ``generated_at`` to consider; defaults to
            ``reference_time`` minus the window's lookback.

    Returns:
        One summary per sector, in first-seen order of the sectors.
    """
    if reference_time is None:
        reference_time = datetime.now(timezone.utc)
    if since is None:
        since = reference_time - _window_lookback(window)

    all_trends = await fetch_latest_company_trends(pool, window, since)

    # Group by sector.
    sectors: dict[str, list[CompanyTrendRow]] = {}
    for t in all_trends:
        sectors.setdefault(t.sector, []).append(t)

    summaries: list[TrendSummary] = []
    for sector, trends in sectors.items():
        summary = rollup_trends(trends, "sector", sector, window, reference_time)
        rollup_id = await persist_rollup(pool, summary)
        logger.info(
            "Persisted sector rollup %s for %s/%s: direction=%s strength=%.3f companies=%d",
            rollup_id, sector, window, summary.trend_direction.value,
            summary.trend_strength, len(trends),
        )
        summaries.append(summary)

    return summaries
100644 index 0000000..b66818b --- /dev/null +++ b/services/aggregation/scoring.py @@ -0,0 +1,285 @@ +"""Recency decay, source credibility weighting, and market context +integration for aggregation. + +Provides scoring functions used by the aggregation engine to weight +document intelligence signals when computing trend summaries. + +Requirements: 6.1, 6.2, 6.5 +""" +from __future__ import annotations + +import math +from dataclasses import dataclass, field +from datetime import datetime, timezone + +from services.shared.schemas import MarketContext + + + +@dataclass(frozen=True) +class ScoringConfig: + """Tunable parameters for signal scoring.""" + + # Recency decay: exponential half-life in hours per window. + # After one half-life, a document's recency weight drops to 0.5. + half_life_hours: dict[str, float] = field(default_factory=lambda: { + "intraday": 2.0, + "1d": 12.0, + "7d": 72.0, + "30d": 240.0, + "90d": 720.0, + }) + + # Minimum recency weight — prevents very old docs from being zeroed out + # entirely so they can still contribute trace-level signal. + min_recency_weight: float = 0.01 + + # Source credibility bounds — credibility scores outside this range + # are clamped before weighting. + credibility_floor: float = 0.1 + credibility_ceiling: float = 1.0 + + # Exponent applied to credibility score. >1 penalises low-credibility + # sources more aggressively; <1 flattens the curve. + credibility_exponent: float = 1.0 + + # Novelty bonus: multiplier range applied on top of base weight. + # A novelty_score of 1.0 gets the full bonus; 0.0 gets none. + novelty_bonus_max: float = 0.25 + + # Confidence floor — documents below this extraction confidence + # receive zero weight (they are too unreliable to aggregate). + confidence_floor: float = 0.2 + + # Market context modulation --- + # When volatility exceeds this threshold (in price units), recency + # signals are amplified because fast-moving markets make fresh data + # more important. 
def recency_weight(
    published_at: datetime,
    reference_time: datetime,
    window: str,
    config: ScoringConfig = DEFAULT_CONFIG,
) -> float:
    """Compute an exponential recency decay weight for a document.

    Uses the formula: w = 2^(-age_hours / half_life)

    Args:
        published_at: When the document was published.  Naive values are
            treated as UTC.
        reference_time: The "now" anchor for the aggregation window.  Naive
            values are treated as UTC.
        window: One of the TrendWindow values (e.g. "7d").  A window missing
            from ``config.half_life_hours`` falls back to a 72-hour half-life.
        config: Scoring parameters.

    Returns:
        1.0 for documents published at or after ``reference_time``; otherwise
        the decayed weight, never below ``config.min_recency_weight``.
    """
    # Ensure both are tz-aware; treat naive as UTC.
    if published_at.tzinfo is None:
        published_at = published_at.replace(tzinfo=timezone.utc)
    if reference_time.tzinfo is None:
        reference_time = reference_time.replace(tzinfo=timezone.utc)

    age_seconds = (reference_time - published_at).total_seconds()
    if age_seconds <= 0:
        return 1.0

    half_life = config.half_life_hours.get(window, 72.0)
    if half_life <= 0:
        # A non-positive half-life cannot describe decay: zero would divide
        # by zero and a negative value would push the weight above 1.0.
        # Treat any aged document as fully decayed.
        return config.min_recency_weight

    age_hours = age_seconds / 3600.0
    weight = math.pow(2.0, -age_hours / half_life)
    return max(weight, config.min_recency_weight)
@dataclass
class SignalWeight:
    """Breakdown of a document's aggregation weight.

    ``compute_signal_weight`` fills this in; ``combined`` is the product
    ``confidence_gate * recency * credibility * (1 + novelty_bonus)
    * market_ctx_multiplier``.
    """

    # Exponential recency-decay weight returned by recency_weight().
    recency: float
    # Clamped, exponentiated source credibility returned by credibility_weight().
    credibility: float
    # novelty_score * config.novelty_bonus_max.
    novelty_bonus: float
    confidence_gate: float  # 0.0 or 1.0
    market_ctx_multiplier: float  # >= 1.0
    # Final weight after applying every factor above.
    combined: float
def sentiment_to_numeric(sentiment: str) -> float:
    """Map a sentiment label to a signed numeric value.

    Matching ignores case and surrounding whitespace.  ``"positive"`` maps
    to ``1.0`` and ``"negative"`` to ``-1.0``; ``"neutral"``, ``"mixed"``,
    unrecognised labels, and non-string values such as ``None`` all map to
    ``0.0``.

    Args:
        sentiment: The sentiment label.

    Returns:
        ``1.0``, ``-1.0`` or ``0.0``.
    """
    mapping = {
        "positive": 1.0,
        "negative": -1.0,
        "neutral": 0.0,
        "mixed": 0.0,
    }
    if not isinstance(sentiment, str):
        # A missing label is handled like an unknown one rather than raising.
        return 0.0
    return mapping.get(sentiment.strip().lower(), 0.0)
+ """ + total_weight = 0.0 + weighted_sum = 0.0 + for sig in signals: + w = sig.weight.combined * sig.impact_score + weighted_sum += w * sig.sentiment_value + total_weight += w + if total_weight == 0.0: + return 0.0 + return weighted_sum / total_weight diff --git a/services/aggregation/worker.py b/services/aggregation/worker.py index fdc81b0..e56823c 100644 --- a/services/aggregation/worker.py +++ b/services/aggregation/worker.py @@ -1 +1,650 @@ -"""Aggregation worker - rolling trend summaries, contradiction detection, evidence ranking.""" +"""Aggregation worker - company-level rolling window trend summaries. + +Queries document intelligence and market context for a given ticker, +computes weighted signal scores, and produces TrendSummary objects +persisted to the trend_windows table. + +Requirements: 6.1, 6.2, 6.5 +""" +from __future__ import annotations + +import json +import logging +import time +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from typing import Any + +import asyncpg + +from services.aggregation.contradiction import CatalystEntry, detect_contradictions +from services.aggregation.evidence import ( + EvidenceRankConfig, + RankedEvidence, + rank_evidence as _rank_evidence_composite, + rank_evidence_detailed, +) +from services.aggregation.market_context import fetch_market_context +from services.aggregation.scoring import ( + ScoringConfig, + WeightedSignal, + compute_signal_weight, + sentiment_to_numeric, + weighted_sentiment_average, +) +from services.shared.schemas import TrendDirection, TrendSummary, TrendWindow +from services.shared.metrics import ( + AGGREGATION_CONTRADICTION_SCORE, + AGGREGATION_DURATION, + AGGREGATION_SIGNALS_PROCESSED, + AGGREGATION_WINDOWS_COMPUTED, +) + +logger = logging.getLogger(__name__) + +# Map TrendWindow values to lookback durations. 
@dataclass
class ImpactRow:
    """Parsed row from the impact query."""

    document_id: str
    confidence: float
    novelty_score: float
    source_credibility: float
    sentiment: str
    impact_score: float
    catalyst_type: str
    key_facts: list[str]
    risks: list[str]
    published_at: datetime


def _parse_impact_row(row: Any) -> ImpactRow:
    """Convert an asyncpg Record (or any mapping) to an ImpactRow.

    NULL numeric columns fall back to a default (0.5 for confidence,
    novelty and credibility; 0.0 for impact score). A stored value of 0 is
    kept as 0 so that a zero-confidence extraction is not promoted to 0.5.

    Args:
        row: Mapping with the columns selected by the impact query.

    Returns:
        The populated ``ImpactRow``.
    """

    def _float_or(value: Any, default: float) -> float:
        # Only a missing value gets the default; ``value or default`` would
        # also replace a legitimate 0.0.
        return default if value is None else float(value)

    def _json_list(value: Any) -> list[str]:
        # JSONB columns may arrive as encoded text or already decoded.
        if isinstance(value, str):
            value = json.loads(value)
        return value if isinstance(value, list) else []

    return ImpactRow(
        document_id=str(row["document_id"]),
        confidence=_float_or(row["confidence"], 0.5),
        novelty_score=_float_or(row["novelty_score"], 0.5),
        source_credibility=_float_or(row["source_credibility"], 0.5),
        sentiment=row["sentiment"] or "neutral",
        impact_score=_float_or(row["impact_score"], 0.0),
        catalyst_type=row["catalyst_type"] or "other",
        key_facts=_json_list(row["key_facts"]),
        risks=_json_list(row["risks"]),
        published_at=row["published_at"],
    )
def compute_contradiction_score(signals: list[WeightedSignal]) -> float:
    """Measure how much disagreement exists among weighted signals.

    Each signal contributes ``weight.combined * impact_score`` to the
    positive or negative side according to the sign of its sentiment;
    neutral signals (sentiment 0) are ignored.

    Args:
        signals: Weighted signals to compare.

    Returns:
        The ratio of the minority-side total weight to the majority-side
        total weight, rounded to 4 decimals. The value lies in [0, 1]:
        0 means full agreement (or no directional signals) and 1 means
        equal-weight positive and negative signals.
    """
    pos_weight = 0.0
    neg_weight = 0.0
    for sig in signals:
        w = sig.weight.combined * sig.impact_score
        if sig.sentiment_value > 0:
            pos_weight += w
        elif sig.sentiment_value < 0:
            neg_weight += w

    majority = max(pos_weight, neg_weight)
    if majority <= 0.0:
        # No directional weight at all: nothing to disagree about.
        return 0.0

    # Divide by the majority side (not the total) so that perfectly balanced
    # evidence scores 1.0, as documented. Clamp at 0 in case a side sums
    # to a negative weight.
    minority = max(0.0, min(pos_weight, neg_weight))
    return round(minority / majority, 4)
def compute_trend_confidence(
    signals: list[WeightedSignal],
    contradiction_score: float,
) -> float:
    """Derive an overall confidence for the trend summary.

    Only signals with a positive combined weight count; signals removed by
    the confidence gate have a combined weight of 0 and are ignored.

    Confidence is based on:
    - Number of contributing signals (saturates at 20 signals)
    - Average credibility weight of the contributing signals. The weight
      breakdown does not carry the raw extraction confidence, so source
      credibility is the quality term used here.
    - Contradiction penalty (high contradiction lowers confidence)

    Args:
        signals: Weighted signals considered for the trend.
        contradiction_score: Disagreement score for the same signals.

    Returns:
        A value in [0, 1], rounded to 4 decimals. 0.0 when no signal has a
        positive combined weight.
    """
    active = [s for s in signals if s.weight.combined > 0]
    if not active:
        return 0.0

    # Base confidence from signal count (diminishing returns)
    count_factor = min(len(active) / 20.0, 1.0)

    # Mean credibility component of the signals that survived gating.
    avg_credibility = sum(s.weight.credibility for s in active) / len(active)

    # Contradiction penalty
    contradiction_penalty = contradiction_score * 0.4

    confidence = (0.4 * count_factor + 0.6 * avg_credibility) - contradiction_penalty
    return round(max(0.0, min(1.0, confidence)), 4)
# NOTE: despite the name this is a plain INSERT. There is no ON CONFLICT
# clause, so every call adds a new row.
# "window" is a reserved word in PostgreSQL and must be double-quoted when
# used as a column name.
_UPSERT_TREND = """
INSERT INTO trend_windows (
    entity_type, entity_id, "window", trend_direction, trend_strength,
    confidence, top_supporting_evidence, top_opposing_evidence,
    dominant_catalysts, material_risks, contradiction_score,
    disagreement_details, market_context, generated_at
) VALUES (
    $1, $2, $3, $4, $5,
    $6, $7::jsonb, $8::jsonb,
    $9::jsonb, $10::jsonb, $11,
    $12::jsonb, $13::jsonb, $14
)
RETURNING id
"""


async def persist_trend_summary(
    pool: asyncpg.Pool,
    summary: TrendSummary,
) -> str:
    """Insert a trend summary row and return its UUID.

    List and object fields are serialised to JSON text and cast to ``jsonb``
    by the statement. A missing market context is stored as ``{}``.

    Args:
        pool: Connection pool used to run the insert.
        summary: Trend summary to store.

    Returns:
        The ``id`` of the inserted row, as a string.
    """
    market_context = summary.market_context.model_dump() if summary.market_context else {}
    row = await pool.fetchrow(
        _UPSERT_TREND,
        summary.entity_type,
        summary.entity_id,
        summary.window.value,
        summary.trend_direction.value,
        summary.trend_strength,
        summary.confidence,
        json.dumps(summary.top_supporting_evidence),
        json.dumps(summary.top_opposing_evidence),
        json.dumps(summary.dominant_catalysts),
        json.dumps(summary.material_risks),
        summary.contradiction_score,
        json.dumps([d.model_dump() for d in summary.disagreement_details]),
        json.dumps(market_context),
        summary.generated_at,
    )
    return str(row["id"])
async def aggregate_company_window(
    pool: asyncpg.Pool,
    ticker: str,
    window: str,
    reference_time: datetime | None = None,
    config: AggregationConfig | None = None,
) -> TrendSummary:
    """Compute and persist a trend summary for one ticker and one window.

    Steps:
    1. Fetch document impact records for the window's time range.
    2. Fetch market context for the ticker.
    3. Build weighted signals using the scoring module.
    4. Assemble the TrendSummary with its evidence rankings.
    5. Persist the summary to the trend_windows table.
    6. Persist the evidence mappings for that trend row.

    Args:
        pool: Connection pool used for all queries.
        ticker: Company ticker to aggregate.
        window: Trend window identifier (a ``TrendWindow`` value).
        reference_time: End of the window; defaults to the current UTC time.
        config: Aggregation settings; defaults to ``AggregationConfig()``.

    Returns:
        The assembled TrendSummary.

    Raises:
        ValueError: If ``window`` is not a valid ``TrendWindow`` value.
    """
    cfg = config or AggregationConfig()
    scoring_cfg = cfg.effective_scoring()

    # Fail fast on an invalid window. Assembling the summary would raise the
    # same ValueError, but only after the database queries had already run.
    TrendWindow(window)

    if reference_time is None:
        reference_time = datetime.now(timezone.utc)

    _agg_start = time.monotonic()
    duration = WINDOW_DURATIONS.get(window)
    if duration is None:
        duration = timedelta(days=7)
        logger.warning(
            "No lookback duration configured for window %s; defaulting to 7 days",
            window,
        )
    window_start = reference_time - duration

    # 1. Fetch impact records
    impacts = await fetch_impact_records(pool, ticker, window_start, reference_time)

    # 2. Fetch market context
    market_ctx = await fetch_market_context(pool, ticker, window, reference_time)

    # 3. Build weighted signals
    signals = build_weighted_signals(
        impacts, reference_time, window, market_ctx, scoring_cfg,
    )

    # 4. Assemble trend summary with evidence details. Market context is only
    # attached when it actually carries data; a missing context is tolerated.
    has_market_data = market_ctx is not None and market_ctx.has_data
    assembled = assemble_trend_with_evidence(
        ticker=ticker,
        window=window,
        signals=signals,
        impacts=impacts,
        market_ctx=market_ctx if has_market_data else None,
        max_evidence=cfg.max_evidence,
        reference_time=reference_time,
    )
    summary = assembled.summary

    # 5. Persist trend window
    trend_id = await persist_trend_summary(pool, summary)

    # 6. Persist evidence mappings
    evidence_count = await persist_trend_evidence(
        pool, trend_id,
        assembled.supporting_evidence,
        assembled.opposing_evidence,
    )

    logger.info(
        "Persisted trend %s for %s/%s: direction=%s strength=%.3f confidence=%.3f signals=%d evidence=%d",
        trend_id, ticker, window, summary.trend_direction.value,
        summary.trend_strength, summary.confidence, len(signals), evidence_count,
    )

    # Prometheus metrics
    AGGREGATION_WINDOWS_COMPUTED.labels(window=window).inc()
    AGGREGATION_SIGNALS_PROCESSED.labels(window=window).inc(len(signals))
    AGGREGATION_CONTRADICTION_SCORE.observe(summary.contradiction_score)
    AGGREGATION_DURATION.labels(window=window).observe(time.monotonic() - _agg_start)

    return summary
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Open shared resources on startup and release them on shutdown.

    Configures logging, creates the PostgreSQL pool stored in the
    module-level ``pool``, and closes that pool when the application stops.
    """
    global pool
    setup_logging("query_api", level=config.log_level, json_output=config.json_logs)
    pool = await get_pg_pool(config)
    try:
        yield
    finally:
        # Runs even when an exception or cancellation is raised at the yield,
        # so the pool's connections are not leaked.
        await pool.close()
def _row_to_dict(row: asyncpg.Record) -> dict[str, Any]:
    """Convert an asyncpg Record to a JSON-safe dict.

    ``datetime`` values become ISO 8601 strings. Strings, numbers, booleans,
    lists, dicts and ``None`` are passed through unchanged. Every other
    value (for example UUID or Decimal objects) is converted with ``str()``.

    Args:
        row: Record or mapping to convert.

    Returns:
        A new dict with the same keys and JSON-friendly values.
    """
    passthrough = (str, int, float, bool, list, dict, type(None))
    converted: dict[str, Any] = {}
    for key, val in dict(row).items():
        if isinstance(val, datetime):
            converted[key] = val.isoformat()
        elif isinstance(val, passthrough):
            converted[key] = val
        else:
            # Every object has __str__, so no hasattr() check is needed.
            converted[key] = str(val)
    return converted
@app.get("/api/companies/{company_id}")
async def get_company(company_id: str):
    """Get a single company with aliases and source count.

    Args:
        company_id: Identifier of the company row.

    Returns:
        The company fields plus ``aliases`` (list of alias rows) and
        ``active_source_count`` (number of active sources for the company).

    Raises:
        HTTPException: 404 when no company has this id.
    """
    row = await pool.fetchrow(
        """SELECT id, ticker, legal_name, exchange, sector, industry,
                  market_cap_bucket, active, created_at, updated_at
           FROM companies WHERE id = $1""",
        company_id,
    )
    if not row:
        raise HTTPException(404, "Company not found")

    result = _row_to_dict(row)

    aliases = await pool.fetch(
        "SELECT id, alias, alias_type FROM company_aliases WHERE company_id = $1",
        company_id,
    )
    # Use the same JSON-safe conversion as every other row in this module
    # so alias ids are rendered like the company id above.
    result["aliases"] = [_row_to_dict(a) for a in aliases]

    source_count = await pool.fetchval(
        "SELECT COUNT(*) FROM sources WHERE company_id = $1 AND active = true",
        company_id,
    )
    result["active_source_count"] = source_count

    return result
@app.get("/api/documents")
async def list_documents(
    ticker: Optional[str] = None,
    company_id: Optional[str] = None,
    document_type: Optional[str] = None,
    status: Optional[str] = None,
    since: Optional[str] = None,
    limit: int = Query(default=50, le=200),
    offset: int = 0,
):
    """List documents with optional filters, ordered by published_at desc.

    Args:
        ticker: Only documents mentioning this ticker (upper-cased).
        company_id: Only documents mentioning this company.
        document_type: Only documents of this type.
        status: Only documents with this status.
        since: ISO 8601 timestamp; only documents published at or after it.
            A value without a UTC offset is interpreted as UTC.
        limit: Maximum number of rows to return.
        offset: Number of rows to skip.

    Returns:
        A list of JSON-safe document dicts.

    Raises:
        HTTPException: 400 when ``since`` is not a valid ISO 8601 timestamp.
    """
    conditions: list[str] = []
    params: list[Any] = []
    idx = 1

    if ticker:
        conditions.append(f"""d.id IN (
            SELECT document_id FROM document_company_mentions WHERE ticker = ${idx}
        )""")
        params.append(ticker.upper())
        idx += 1
    if company_id:
        conditions.append(f"""d.id IN (
            SELECT document_id FROM document_company_mentions WHERE company_id = ${idx}
        )""")
        params.append(company_id)
        idx += 1
    if document_type:
        conditions.append(f"d.document_type = ${idx}")
        params.append(document_type)
        idx += 1
    if status:
        conditions.append(f"d.status = ${idx}")
        params.append(status)
        idx += 1
    if since:
        # asyncpg needs a datetime (not a str) for a timestamptz parameter,
        # so the query-string value is parsed here.
        try:
            # fromisoformat() only accepts a trailing "Z" from Python 3.11 on.
            iso_text = since[:-1] + "+00:00" if since.endswith("Z") else since
            since_dt = datetime.fromisoformat(iso_text)
        except ValueError:
            raise HTTPException(
                400, "Invalid 'since' value; expected an ISO 8601 timestamp"
            ) from None
        if since_dt.tzinfo is None:
            since_dt = since_dt.replace(tzinfo=timezone.utc)
        conditions.append(f"d.published_at >= ${idx}::timestamptz")
        params.append(since_dt)
        idx += 1

    where = ("WHERE " + " AND ".join(conditions)) if conditions else ""

    rows = await pool.fetch(
        f"""SELECT d.id, d.document_type, d.source_type, d.publisher, d.url,
                   d.title, d.published_at, d.retrieved_at, d.language,
                   d.content_hash, d.parse_quality_score, d.parse_confidence,
                   d.status, d.created_at
            FROM documents d
            {where}
            ORDER BY d.published_at DESC NULLS LAST
            LIMIT ${idx} OFFSET ${idx + 1}""",
        *params, limit, offset,
    )
    return [_row_to_dict(r) for r in rows]
dir.evidence_spans, + c.legal_name + FROM document_impact_records dir + JOIN companies c ON c.id = dir.company_id + WHERE dir.intelligence_id = $1""", + intel["id"], + ) + impact_list = [] + for imp in impacts: + imp_dict = _row_to_dict(imp) + imp_dict["key_facts"] = _parse_jsonb(imp_dict.get("key_facts")) + imp_dict["risks"] = _parse_jsonb(imp_dict.get("risks")) + imp_dict["evidence_spans"] = _parse_jsonb(imp_dict.get("evidence_spans")) + impact_list.append(imp_dict) + intel_dict["company_impacts"] = impact_list + result["intelligence"] = intel_dict + else: + result["intelligence"] = None + + return result + + +# --------------------------------------------------------------------------- +# Trend Summaries (Requirement 11.1) +# --------------------------------------------------------------------------- + + +@app.get("/api/trends") +async def list_trends( + ticker: Optional[str] = None, + entity_type: str = "company", + window: Optional[str] = None, + limit: int = Query(default=50, le=200), + offset: int = 0, +): + """List trend summaries with optional filters.""" + conditions = [f"entity_type = $1"] + params: list[Any] = [entity_type] + idx = 2 + + if ticker: + conditions.append(f"entity_id = ${idx}") + params.append(ticker.upper()) + idx += 1 + if window: + conditions.append(f"window = ${idx}") + params.append(window) + idx += 1 + + where = " AND ".join(conditions) + rows = await pool.fetch( + f"""SELECT id, entity_type, entity_id, window, trend_direction, + trend_strength, confidence, top_supporting_evidence, + top_opposing_evidence, dominant_catalysts, material_risks, + contradiction_score, market_context, generated_at + FROM trend_windows + WHERE {where} + ORDER BY generated_at DESC + LIMIT ${idx} OFFSET ${idx + 1}""", + *params, limit, offset, + ) + results = [] + for r in rows: + d = _row_to_dict(r) + for jsonb_field in ( + "top_supporting_evidence", "top_opposing_evidence", + "dominant_catalysts", "material_risks", "market_context", + ): + d[jsonb_field] = 
_parse_jsonb(d.get(jsonb_field)) + results.append(d) + return results + + +@app.get("/api/trends/{trend_id}") +async def get_trend(trend_id: str): + """Get a single trend summary by ID.""" + row = await pool.fetchrow( + """SELECT id, entity_type, entity_id, window, trend_direction, + trend_strength, confidence, top_supporting_evidence, + top_opposing_evidence, dominant_catalysts, material_risks, + contradiction_score, market_context, generated_at, created_at + FROM trend_windows WHERE id = $1""", + trend_id, + ) + if not row: + raise HTTPException(404, "Trend not found") + + d = _row_to_dict(row) + for jsonb_field in ( + "top_supporting_evidence", "top_opposing_evidence", + "dominant_catalysts", "material_risks", "market_context", + ): + d[jsonb_field] = _parse_jsonb(d.get(jsonb_field)) + return d + + +# --------------------------------------------------------------------------- +# Recommendations (Requirement 11.1, 11.2) +# --------------------------------------------------------------------------- + +@app.get("/api/recommendations") +async def list_recommendations( + ticker: Optional[str] = None, + action: Optional[str] = None, + mode: Optional[str] = None, + since: Optional[str] = None, + limit: int = Query(default=50, le=200), + offset: int = 0, +): + """List recommendations with optional filters.""" + conditions: list[str] = [] + params: list[Any] = [] + idx = 1 + + if ticker: + conditions.append(f"r.ticker = ${idx}") + params.append(ticker.upper()) + idx += 1 + if action: + conditions.append(f"r.action = ${idx}") + params.append(action) + idx += 1 + if mode: + conditions.append(f"r.mode = ${idx}") + params.append(mode) + idx += 1 + if since: + conditions.append(f"r.generated_at >= ${idx}::timestamptz") + params.append(since) + idx += 1 + + where = ("WHERE " + " AND ".join(conditions)) if conditions else "" + + rows = await pool.fetch( + f"""SELECT r.id, r.ticker, r.action, r.mode, r.confidence, + r.time_horizon, r.thesis, r.invalidation_conditions, + 
r.portfolio_pct, r.max_loss_pct, r.model_version, + r.risk_classification, r.generated_at + FROM recommendations r + {where} + ORDER BY r.generated_at DESC + LIMIT ${idx} OFFSET ${idx + 1}""", + *params, limit, offset, + ) + results = [] + for r in rows: + d = _row_to_dict(r) + d["invalidation_conditions"] = _parse_jsonb(d.get("invalidation_conditions")) + results.append(d) + return results + + +@app.get("/api/recommendations/{recommendation_id}") +async def get_recommendation(recommendation_id: str): + """Get a single recommendation with evidence and risk evaluation. + + Requirement 11.2: display contributing intelligence objects, raw sources, + and market context that influenced the decision. + """ + row = await pool.fetchrow( + """SELECT r.id, r.ticker, r.company_id, r.action, r.mode, r.confidence, + r.time_horizon, r.thesis, r.invalidation_conditions, + r.portfolio_pct, r.max_loss_pct, r.model_version, + r.model_provider, r.prompt_version, r.schema_version, + r.risk_classification, r.generated_at, r.created_at + FROM recommendations r WHERE r.id = $1""", + recommendation_id, + ) + if not row: + raise HTTPException(404, "Recommendation not found") + + result = _row_to_dict(row) + result["invalidation_conditions"] = _parse_jsonb(result.get("invalidation_conditions")) + + # Evidence: linked documents and intelligence objects + evidence_rows = await pool.fetch( + """SELECT re.id, re.document_id, re.intelligence_id, re.evidence_type, re.weight, + d.title, d.document_type, d.source_type, d.publisher, d.url, + d.published_at + FROM recommendation_evidence re + LEFT JOIN documents d ON d.id = re.document_id + WHERE re.recommendation_id = $1 + ORDER BY re.weight DESC""", + recommendation_id, + ) + result["evidence"] = [_row_to_dict(e) for e in evidence_rows] + + # Risk evaluation + risk_row = await pool.fetchrow( + """SELECT id, eligible, allowed_mode, rejection_reasons, risk_checks, evaluated_at + FROM risk_evaluations WHERE recommendation_id = $1 + ORDER BY 
evaluated_at DESC LIMIT 1""", + recommendation_id, + ) + if risk_row: + risk_dict = _row_to_dict(risk_row) + risk_dict["rejection_reasons"] = _parse_jsonb(risk_dict.get("rejection_reasons")) + risk_dict["risk_checks"] = _parse_jsonb(risk_dict.get("risk_checks")) + result["risk_evaluation"] = risk_dict + else: + result["risk_evaluation"] = None + + return result + + +# --------------------------------------------------------------------------- +# Evidence Drill-Down (Requirement 11.2, 10.4) +# --------------------------------------------------------------------------- + +@app.get("/api/recommendations/{recommendation_id}/evidence") +async def get_recommendation_evidence_drilldown(recommendation_id: str): + """Full evidence drill-down linking a recommendation to source documents and raw artifacts. + + Returns the complete provenance chain for each piece of evidence: + recommendation_evidence → document (with storage refs) → document_intelligence + → document_impact_records, plus the trend window that fed the recommendation. 
+ + Requirements: 11.2, 10.4 + Design: Section 9.1 (evidence drill-down and audit views) + """ + # Verify recommendation exists and get basic info + rec_row = await pool.fetchrow( + """SELECT id, ticker, company_id, action, mode, confidence, + time_horizon, thesis, model_version, model_provider, + prompt_version, schema_version, generated_at + FROM recommendations WHERE id = $1""", + recommendation_id, + ) + if not rec_row: + raise HTTPException(404, "Recommendation not found") + + result: dict[str, Any] = { + "recommendation": _row_to_dict(rec_row), + "evidence": [], + "trend_window": None, + } + + # Fetch evidence rows with full document details including storage refs + evidence_rows = await pool.fetch( + """SELECT re.id AS evidence_id, + re.document_id, + re.intelligence_id, + re.evidence_type, + re.weight, + d.document_type, + d.source_type, + d.publisher, + d.url, + d.canonical_url, + d.title, + d.published_at, + d.retrieved_at, + d.language, + d.content_hash, + d.raw_storage_ref, + d.normalized_storage_ref, + d.parse_quality_score, + d.parse_confidence, + d.status AS document_status + FROM recommendation_evidence re + LEFT JOIN documents d ON d.id = re.document_id + WHERE re.recommendation_id = $1 + ORDER BY re.weight DESC""", + recommendation_id, + ) + + for ev in evidence_rows: + ev_dict = _row_to_dict(ev) + ev_dict["intelligence"] = None + ev_dict["company_impacts"] = [] + + # Fetch intelligence extraction for this evidence + intel_id = ev["intelligence_id"] + doc_id = ev["document_id"] + + # Use the linked intelligence_id if available, otherwise look up by document_id + intel_row = None + if intel_id: + intel_row = await pool.fetchrow( + """SELECT id, document_id, summary, macro_themes, novelty_score, + source_credibility, extraction_warnings, confidence, + model_provider, model_name, prompt_version, schema_version, + raw_output_ref, prompt_ref, validation_status, + validation_errors, created_at + FROM document_intelligence WHERE id = $1""", + intel_id, + 
) + elif doc_id: + intel_row = await pool.fetchrow( + """SELECT id, document_id, summary, macro_themes, novelty_score, + source_credibility, extraction_warnings, confidence, + model_provider, model_name, prompt_version, schema_version, + raw_output_ref, prompt_ref, validation_status, + validation_errors, created_at + FROM document_intelligence WHERE document_id = $1 + ORDER BY created_at DESC LIMIT 1""", + doc_id, + ) + + if intel_row: + intel_dict = _row_to_dict(intel_row) + for jf in ("macro_themes", "extraction_warnings", "validation_errors"): + intel_dict[jf] = _parse_jsonb(intel_dict.get(jf)) + ev_dict["intelligence"] = intel_dict + + # Fetch per-company impact records for this intelligence + impacts = await pool.fetch( + """SELECT dir.company_id, dir.ticker, dir.relevance, dir.sentiment, + dir.impact_score, dir.impact_horizon, dir.catalyst_type, + dir.key_facts, dir.risks, dir.evidence_spans, + c.legal_name + FROM document_impact_records dir + JOIN companies c ON c.id = dir.company_id + WHERE dir.intelligence_id = $1""", + intel_row["id"], + ) + impact_list = [] + for imp in impacts: + imp_dict = _row_to_dict(imp) + for jf in ("key_facts", "risks", "evidence_spans"): + imp_dict[jf] = _parse_jsonb(imp_dict.get(jf)) + impact_list.append(imp_dict) + ev_dict["company_impacts"] = impact_list + + result["evidence"].append(ev_dict) + + # Fetch the most recent trend window for this ticker to show market context + ticker = rec_row["ticker"] + generated_at = rec_row["generated_at"] + if ticker and generated_at: + trend_row = await pool.fetchrow( + """SELECT id, entity_type, entity_id, window, trend_direction, + trend_strength, confidence, top_supporting_evidence, + top_opposing_evidence, dominant_catalysts, material_risks, + contradiction_score, market_context, generated_at + FROM trend_windows + WHERE entity_id = $1 AND entity_type = 'company' + AND generated_at <= $2 + ORDER BY generated_at DESC LIMIT 1""", + ticker, generated_at, + ) + if trend_row: + trend_dict = 
_row_to_dict(trend_row) + for jf in ( + "top_supporting_evidence", "top_opposing_evidence", + "dominant_catalysts", "material_risks", "market_context", + ): + trend_dict[jf] = _parse_jsonb(trend_dict.get(jf)) + + # Include trend evidence linkage: documents that contributed to this trend + trend_ev_rows = await pool.fetch( + """SELECT te.id, te.document_id, te.evidence_type, te.rank_score, + te.weight_component, te.impact_component, + te.recency_component, te.confidence_component, + te.sentiment_value, + d.title, d.document_type, d.source_type, d.publisher, + d.url, d.published_at, d.raw_storage_ref, + d.normalized_storage_ref + FROM trend_evidence te + LEFT JOIN documents d ON d.id = te.document_id + WHERE te.trend_window_id = $1 + ORDER BY te.rank_score DESC""", + trend_row["id"], + ) + trend_dict["evidence"] = [_row_to_dict(te) for te in trend_ev_rows] + + result["trend_window"] = trend_dict + + return result + + +# --------------------------------------------------------------------------- +# Trend Evidence Drill-Down (Requirement 10.4) +# --------------------------------------------------------------------------- + +@app.get("/api/trends/{trend_id}/evidence") +async def get_trend_evidence_drilldown(trend_id: str): + """Drill down from a trend window to its contributing documents and raw artifacts. + + Returns the trend summary plus each contributing document with storage refs, + intelligence extraction, and impact records — full provenance chain. 
+ + Requirements: 10.4, 6.5 + """ + trend_row = await pool.fetchrow( + """SELECT id, entity_type, entity_id, window, trend_direction, + trend_strength, confidence, top_supporting_evidence, + top_opposing_evidence, dominant_catalysts, material_risks, + contradiction_score, market_context, generated_at + FROM trend_windows WHERE id = $1""", + trend_id, + ) + if not trend_row: + raise HTTPException(404, "Trend not found") + + trend_dict = _row_to_dict(trend_row) + for jf in ( + "top_supporting_evidence", "top_opposing_evidence", + "dominant_catalysts", "material_risks", "market_context", + ): + trend_dict[jf] = _parse_jsonb(trend_dict.get(jf)) + + # Fetch trend evidence with full document details + evidence_rows = await pool.fetch( + """SELECT te.id AS evidence_id, + te.document_id, + te.evidence_type, + te.rank_score, + te.weight_component, + te.impact_component, + te.recency_component, + te.confidence_component, + te.sentiment_value, + d.document_type, + d.source_type, + d.publisher, + d.url, + d.canonical_url, + d.title, + d.published_at, + d.retrieved_at, + d.content_hash, + d.raw_storage_ref, + d.normalized_storage_ref, + d.parse_quality_score, + d.parse_confidence, + d.status AS document_status + FROM trend_evidence te + LEFT JOIN documents d ON d.id = te.document_id + WHERE te.trend_window_id = $1 + ORDER BY te.rank_score DESC""", + trend_id, + ) + + evidence_list = [] + for ev in evidence_rows: + ev_dict = _row_to_dict(ev) + ev_dict["intelligence"] = None + ev_dict["company_impacts"] = [] + + doc_id = ev["document_id"] + if doc_id: + intel_row = await pool.fetchrow( + """SELECT id, document_id, summary, macro_themes, novelty_score, + source_credibility, extraction_warnings, confidence, + model_provider, model_name, prompt_version, schema_version, + raw_output_ref, prompt_ref, validation_status, + validation_errors, created_at + FROM document_intelligence WHERE document_id = $1 + ORDER BY created_at DESC LIMIT 1""", + doc_id, + ) + if intel_row: + intel_dict = 
_row_to_dict(intel_row) + for jf in ("macro_themes", "extraction_warnings", "validation_errors"): + intel_dict[jf] = _parse_jsonb(intel_dict.get(jf)) + ev_dict["intelligence"] = intel_dict + + impacts = await pool.fetch( + """SELECT dir.company_id, dir.ticker, dir.relevance, dir.sentiment, + dir.impact_score, dir.impact_horizon, dir.catalyst_type, + dir.key_facts, dir.risks, dir.evidence_spans, + c.legal_name + FROM document_impact_records dir + JOIN companies c ON c.id = dir.company_id + WHERE dir.intelligence_id = $1""", + intel_row["id"], + ) + for imp in impacts: + imp_dict = _row_to_dict(imp) + for jf in ("key_facts", "risks", "evidence_spans"): + imp_dict[jf] = _parse_jsonb(imp_dict.get(jf)) + ev_dict["company_impacts"].append(imp_dict) + + evidence_list.append(ev_dict) + + return { + "trend": trend_dict, + "evidence": evidence_list, + } + + +# --------------------------------------------------------------------------- +# Order History (Requirement 11.1, 11.3) +# --------------------------------------------------------------------------- + +@app.get("/api/orders") +async def list_orders( + ticker: Optional[str] = None, + status: Optional[str] = None, + side: Optional[str] = None, + since: Optional[str] = None, + limit: int = Query(default=50, le=200), + offset: int = 0, +): + """List orders with optional filters.""" + conditions: list[str] = [] + params: list[Any] = [] + idx = 1 + + if ticker: + conditions.append(f"o.ticker = ${idx}") + params.append(ticker.upper()) + idx += 1 + if status: + conditions.append(f"o.status = ${idx}") + params.append(status) + idx += 1 + if side: + conditions.append(f"o.side = ${idx}") + params.append(side) + idx += 1 + if since: + conditions.append(f"o.created_at >= ${idx}::timestamptz") + params.append(since) + idx += 1 + + where = ("WHERE " + " AND ".join(conditions)) if conditions else "" + + rows = await pool.fetch( + f"""SELECT o.id, o.recommendation_id, o.broker_account_id, o.ticker, + o.side, o.order_type, o.quantity, 
o.limit_price, o.stop_price, + o.status, o.broker_order_id, o.submitted_at, o.acknowledged_at, + o.filled_at, o.cancelled_at, o.rejected_at, o.rejection_reason, + o.fill_price, o.fill_quantity, o.created_at + FROM orders o + {where} + ORDER BY o.created_at DESC + LIMIT ${idx} OFFSET ${idx + 1}""", + *params, limit, offset, + ) + return [_row_to_dict(r) for r in rows] + + +@app.get("/api/orders/{order_id}") +async def get_order(order_id: str): + """Get a single order with its events, decision trace, and full audit trail. + + Requirement 11.3: expose full audit trail from ingestion through broker + execution and eventual market outcome. + """ + row = await pool.fetchrow( + """SELECT o.id, o.recommendation_id, o.broker_account_id, o.ticker, + o.side, o.order_type, o.quantity, o.limit_price, o.stop_price, + o.status, o.idempotency_key, o.broker_order_id, + o.decision_trace, o.submitted_at, o.acknowledged_at, + o.filled_at, o.cancelled_at, o.rejected_at, o.rejection_reason, + o.fill_price, o.fill_quantity, o.created_at, o.updated_at + FROM orders o WHERE o.id = $1""", + order_id, + ) + if not row: + raise HTTPException(404, "Order not found") + + result = _row_to_dict(row) + result["decision_trace"] = _parse_jsonb(result.get("decision_trace")) + + # Order events + events = await pool.fetch( + """SELECT id, event_type, data, broker_timestamp, created_at + FROM order_events WHERE order_id = $1 ORDER BY created_at ASC""", + order_id, + ) + result["events"] = [] + for ev in events: + ev_dict = _row_to_dict(ev) + ev_dict["data"] = _parse_jsonb(ev_dict.get("data")) + result["events"].append(ev_dict) + + # Full audit trail (Requirement 11.3) + recommendation_id = str(row["recommendation_id"]) if row["recommendation_id"] else None + result["audit_trail"] = await get_order_audit_trail(pool, order_id, recommendation_id) + + return result + + +# --------------------------------------------------------------------------- +# Positions (Requirement 11.1) +# 
--------------------------------------------------------------------------- + +@app.get("/api/positions") +async def list_positions( + ticker: Optional[str] = None, +): + """List current positions.""" + if ticker: + rows = await pool.fetch( + """SELECT p.id, p.broker_account_id, p.ticker, p.quantity, + p.avg_entry_price, p.current_price, + p.unrealized_pnl, p.realized_pnl, p.updated_at + FROM positions p WHERE p.ticker = $1 ORDER BY p.ticker""", + ticker.upper(), + ) + else: + rows = await pool.fetch( + """SELECT p.id, p.broker_account_id, p.ticker, p.quantity, + p.avg_entry_price, p.current_price, + p.unrealized_pnl, p.realized_pnl, p.updated_at + FROM positions p ORDER BY p.ticker""", + ) + return [_row_to_dict(r) for r in rows] + + +# --------------------------------------------------------------------------- +# Audit Trail (Requirement 11.3) +# --------------------------------------------------------------------------- + +@app.get("/api/audit/{entity_type}/{entity_id}") +async def get_audit_trail(entity_type: str, entity_id: str): + """Get audit events for any entity type and ID.""" + events = await get_entity_audit_trail(pool, entity_type, entity_id) + if not events: + raise HTTPException(404, "No audit events found") + return events + + +# --------------------------------------------------------------------------- +# Admin: Source Health (Requirement 11.1 - source health) +# --------------------------------------------------------------------------- + +@app.get("/api/admin/sources/health") +async def get_source_health( + source_type: Optional[str] = None, + company_id: Optional[str] = None, + active_only: bool = True, +): + """Source health overview: each source with its latest ingestion status and failure counts. 
+ + Design: Section 9.1 (source health and job state) + """ + conditions = [] + params: list[Any] = [] + idx = 1 + + if active_only: + conditions.append(f"s.active = ${idx}") + params.append(True) + idx += 1 + if source_type: + conditions.append(f"s.source_type = ${idx}") + params.append(source_type) + idx += 1 + if company_id: + conditions.append(f"s.company_id = ${idx}") + params.append(company_id) + idx += 1 + + where = ("WHERE " + " AND ".join(conditions)) if conditions else "" + + rows = await pool.fetch( + f"""SELECT s.id AS source_id, s.source_type, s.source_name, + s.credibility_score, s.active, + c.ticker, c.legal_name, c.id AS company_id, + latest.status AS last_run_status, + latest.started_at AS last_run_at, + latest.error_message AS last_error, + latest.items_fetched AS last_items_fetched, + latest.items_new AS last_items_new, + COALESCE(stats.total_runs, 0) AS total_runs_24h, + COALESCE(stats.failed_runs, 0) AS failed_runs_24h, + COALESCE(stats.total_items, 0) AS total_items_24h + FROM sources s + JOIN companies c ON c.id = s.company_id + LEFT JOIN LATERAL ( + SELECT ir.status, ir.started_at, ir.error_message, + ir.items_fetched, ir.items_new + FROM ingestion_runs ir + WHERE ir.source_id = s.id + ORDER BY ir.started_at DESC + LIMIT 1 + ) latest ON TRUE + LEFT JOIN LATERAL ( + SELECT COUNT(*) AS total_runs, + COUNT(*) FILTER (WHERE ir2.status = 'failed') AS failed_runs, + COALESCE(SUM(ir2.items_fetched), 0) AS total_items + FROM ingestion_runs ir2 + WHERE ir2.source_id = s.id + AND ir2.started_at >= NOW() - INTERVAL '24 hours' + ) stats ON TRUE + {where} + ORDER BY c.ticker, s.source_type""", + *params, + ) + return [_row_to_dict(r) for r in rows] + + +@app.get("/api/admin/sources/{source_id}/runs") +async def get_source_runs( + source_id: str, + limit: int = Query(default=20, le=100), + offset: int = 0, +): + """Recent ingestion runs for a specific source.""" + rows = await pool.fetch( + """SELECT id, source_id, company_id, source_type, status, + 
started_at, completed_at, items_fetched, items_new, + error_message, retry_count, next_retry_at + FROM ingestion_runs + WHERE source_id = $1 + ORDER BY started_at DESC + LIMIT $2 OFFSET $3""", + source_id, limit, offset, + ) + return [_row_to_dict(r) for r in rows] + + +@app.put("/api/admin/sources/{source_id}/toggle") +async def toggle_source(source_id: str, active: bool = True): + """Enable or disable a source.""" + row = await pool.fetchrow( + """UPDATE sources SET active = $2, updated_at = NOW() + WHERE id = $1 + RETURNING id, source_type, source_name, active""", + source_id, active, + ) + if not row: + raise HTTPException(404, "Source not found") + return _row_to_dict(row) + + +@app.put("/api/admin/sources/{source_id}/credibility") +async def update_source_credibility(source_id: str, credibility_score: float = Query(ge=0.0, le=1.0)): + """Update a source's credibility score.""" + row = await pool.fetchrow( + """UPDATE sources SET credibility_score = $2, updated_at = NOW() + WHERE id = $1 + RETURNING id, source_type, source_name, credibility_score""", + source_id, credibility_score, + ) + if not row: + raise HTTPException(404, "Source not found") + return _row_to_dict(row) + + +# --------------------------------------------------------------------------- +# Admin: Symbol Configs (Requirement 11.1 - symbol configs) +# --------------------------------------------------------------------------- + +@app.put("/api/admin/companies/{company_id}/toggle") +async def toggle_company(company_id: str, active: bool = True): + """Enable or disable a tracked company.""" + row = await pool.fetchrow( + """UPDATE companies SET active = $2, updated_at = NOW() + WHERE id = $1 + RETURNING id, ticker, legal_name, active""", + company_id, active, + ) + if not row: + raise HTTPException(404, "Company not found") + return _row_to_dict(row) + + +@app.put("/api/admin/companies/{company_id}/sector") +async def update_company_sector( + company_id: str, + sector: str = Query(...), + 
industry: Optional[str] = None, +): + """Update a company's sector and industry classification.""" + if industry is not None: + row = await pool.fetchrow( + """UPDATE companies SET sector = $2, industry = $3, updated_at = NOW() + WHERE id = $1 + RETURNING id, ticker, legal_name, sector, industry""", + company_id, sector, industry, + ) + else: + row = await pool.fetchrow( + """UPDATE companies SET sector = $2, updated_at = NOW() + WHERE id = $1 + RETURNING id, ticker, legal_name, sector, industry""", + company_id, sector, + ) + if not row: + raise HTTPException(404, "Company not found") + return _row_to_dict(row) + + +@app.get("/api/admin/companies/coverage") +async def get_symbol_coverage(): + """Overview of source coverage per active company. + + Shows how many active sources of each type are configured per symbol, + useful for identifying coverage gaps. + """ + rows = await pool.fetch( + """SELECT c.id AS company_id, c.ticker, c.legal_name, c.sector, + c.active, + COUNT(s.id) FILTER (WHERE s.active) AS active_sources, + COUNT(s.id) FILTER (WHERE s.source_type = 'market_api' AND s.active) AS market_sources, + COUNT(s.id) FILTER (WHERE s.source_type = 'news_api' AND s.active) AS news_sources, + COUNT(s.id) FILTER (WHERE s.source_type = 'filings_api' AND s.active) AS filings_sources, + COUNT(s.id) FILTER (WHERE s.source_type = 'web_scrape' AND s.active) AS web_scrape_sources, + COUNT(s.id) FILTER (WHERE s.source_type = 'broker' AND s.active) AS broker_sources + FROM companies c + LEFT JOIN sources s ON s.company_id = c.id + WHERE c.active = TRUE + GROUP BY c.id, c.ticker, c.legal_name, c.sector, c.active + ORDER BY c.ticker""", + ) + return [_row_to_dict(r) for r in rows] + + +# --------------------------------------------------------------------------- +# Admin: Trading Mode (Requirement 8.1, 8.2, 11.1) +# --------------------------------------------------------------------------- + +@app.get("/api/admin/trading/config") +async def get_trading_config(): + """Get 
the current active risk/trading configuration.""" + row = await pool.fetchrow( + """SELECT id, name, trading_mode, config, active, created_at, updated_at + FROM risk_configs + WHERE active = TRUE + ORDER BY updated_at DESC + LIMIT 1""", + ) + if not row: + return {"trading_mode": "paper", "config": {}, "message": "No active config found, using defaults"} + + result = _row_to_dict(row) + result["config"] = _parse_jsonb(result.get("config")) + return result + + +@app.put("/api/admin/trading/mode") +async def set_trading_mode(mode: str = Query(..., pattern="^(paper|live|disabled)$")): + """Switch the active trading mode. + + Requirement 8.1: support paper and live as separate execution environments. + Requirement 8.2: live mode requires operator approval controls. + """ + row = await pool.fetchrow( + """UPDATE risk_configs SET trading_mode = $1, updated_at = NOW() + WHERE active = TRUE + RETURNING id, name, trading_mode""", + mode, + ) + if not row: + # No active config exists yet — create one with the requested mode + row = await pool.fetchrow( + """INSERT INTO risk_configs (name, trading_mode, config, active) + VALUES ('default', $1, '{}', TRUE) + RETURNING id, name, trading_mode""", + mode, + ) + return _row_to_dict(row) + + +@app.put("/api/admin/trading/config") +async def update_trading_config(config: dict[str, Any]): + """Update the active risk configuration JSON. + + Accepts a partial or full risk config object. The config is stored + as JSONB alongside the trading_mode in risk_configs. 
+ """ + config_json = json.dumps(config) + + row = await pool.fetchrow( + """UPDATE risk_configs SET config = $1::jsonb, updated_at = NOW() + WHERE active = TRUE + RETURNING id, name, trading_mode, config""", + config_json, + ) + if not row: + row = await pool.fetchrow( + """INSERT INTO risk_configs (name, trading_mode, config, active) + VALUES ('default', 'paper', $1::jsonb, TRUE) + RETURNING id, name, trading_mode, config""", + config_json, + ) + result = _row_to_dict(row) + result["config"] = _parse_jsonb(result.get("config")) + return result + + +@app.get("/api/admin/trading/approvals") +async def list_pending_approvals(): + """List pending operator approval requests for live trading orders.""" + rows = await pool.fetch( + """SELECT id, order_job, recommendation_id, ticker, side, quantity, + estimated_value, status, risk_evaluation_id, requested_by, + reviewed_by, review_note, expires_at, requested_at, reviewed_at + FROM operator_approvals + WHERE status = 'pending' + ORDER BY requested_at ASC""", + ) + results = [] + for r in rows: + d = _row_to_dict(r) + d["order_job"] = _parse_jsonb(d.get("order_job")) + results.append(d) + return results + + +@app.put("/api/admin/trading/approvals/{approval_id}") +async def review_approval_request( + approval_id: str, + approved: bool = Query(...), + reviewed_by: str = "operator", + review_note: str = "", +): + """Approve or reject a pending operator approval request. + + Requirement 8.2: live orders require operator approval controls. 
+ """ + now = datetime.now(timezone.utc) + new_status = "approved" if approved else "rejected" + + row = await pool.fetchrow( + """UPDATE operator_approvals + SET status = $2, reviewed_by = $3, review_note = $4, + reviewed_at = $5, updated_at = NOW() + WHERE id = $1::uuid AND status = 'pending' + RETURNING id, ticker, status, reviewed_by""", + approval_id, new_status, reviewed_by, review_note, now, + ) + if not row: + raise HTTPException(404, "Approval not found or no longer pending") + return _row_to_dict(row) + + +@app.get("/api/admin/trading/lockouts") +async def list_active_lockouts(): + """List active symbol lockouts (news-shock, cooldown).""" + rows = await pool.fetch( + """SELECT id, ticker, lockout_type, reason, expires_at, created_at + FROM symbol_lockouts + WHERE expires_at > NOW() + ORDER BY expires_at ASC""", + ) + return [_row_to_dict(r) for r in rows] + + +# --------------------------------------------------------------------------- +# Operational Dashboard (Requirement 12.1, 12.2, 12.3) +# --------------------------------------------------------------------------- + +@app.get("/api/ops/ingestion/throughput") +async def get_ingestion_throughput( + hours: int = Query(default=24, ge=1, le=168), + bucket: str = Query(default="1h", pattern="^(15m|1h|6h|1d)$"), +): + """Ingestion throughput over time, bucketed by interval. + + Returns document counts and item counts per time bucket, broken down + by source type. Powers the ingestion throughput chart. 
+ + Requirements: 12.1, 12.3 + """ + bucket_interval = { + "15m": "15 minutes", + "1h": "1 hour", + "6h": "6 hours", + "1d": "1 day", + }[bucket] + + rows = await pool.fetch( + f"""SELECT + date_trunc('hour', ir.started_at) + - (EXTRACT(minute FROM ir.started_at)::int + % EXTRACT(epoch FROM INTERVAL '{bucket_interval}')::int / 60) + * INTERVAL '1 minute' AS bucket_start, + ir.source_type, + COUNT(*) AS run_count, + COUNT(*) FILTER (WHERE ir.status = 'completed') AS completed, + COUNT(*) FILTER (WHERE ir.status = 'failed') AS failed, + COALESCE(SUM(ir.items_fetched), 0) AS items_fetched, + COALESCE(SUM(ir.items_new), 0) AS items_new + FROM ingestion_runs ir + WHERE ir.started_at >= NOW() - INTERVAL '1 hour' * $1 + GROUP BY bucket_start, ir.source_type + ORDER BY bucket_start DESC, ir.source_type""", + hours, + ) + return [_row_to_dict(r) for r in rows] + + +@app.get("/api/ops/ingestion/summary") +async def get_ingestion_summary( + hours: int = Query(default=24, ge=1, le=168), +): + """High-level ingestion summary for the operational dashboard. + + Returns total runs, success/failure counts, items processed, and + per-source-type breakdown for the given time window. 
+ + Requirements: 12.1 + """ + row = await pool.fetchrow( + """SELECT + COUNT(*) AS total_runs, + COUNT(*) FILTER (WHERE status = 'completed') AS completed, + COUNT(*) FILTER (WHERE status = 'failed') AS failed, + COUNT(*) FILTER (WHERE status = 'pending') AS pending, + COUNT(*) FILTER (WHERE status = 'running') AS running, + COALESCE(SUM(items_fetched), 0) AS total_items_fetched, + COALESCE(SUM(items_new), 0) AS total_items_new, + COUNT(DISTINCT source_id) AS active_sources, + COUNT(DISTINCT company_id) AS active_companies + FROM ingestion_runs + WHERE started_at >= NOW() - INTERVAL '1 hour' * $1""", + hours, + ) + + by_type = await pool.fetch( + """SELECT + source_type, + COUNT(*) AS runs, + COUNT(*) FILTER (WHERE status = 'completed') AS completed, + COUNT(*) FILTER (WHERE status = 'failed') AS failed, + COALESCE(SUM(items_fetched), 0) AS items_fetched, + COALESCE(SUM(items_new), 0) AS items_new + FROM ingestion_runs + WHERE started_at >= NOW() - INTERVAL '1 hour' * $1 + GROUP BY source_type + ORDER BY runs DESC""", + hours, + ) + + result = _row_to_dict(row) if row else {} + result["by_source_type"] = [_row_to_dict(r) for r in by_type] + result["hours"] = hours + return result + + +@app.get("/api/ops/model/failures") +async def get_model_failures( + hours: int = Query(default=24, ge=1, le=168), + limit: int = Query(default=50, le=200), +): + """Recent model extraction failures with error details. + + Returns individual failed extraction attempts for debugging. 
+ + Requirements: 12.2 + """ + rows = await pool.fetch( + """SELECT + mpm.id, mpm.document_id, mpm.ticker, mpm.model_name, + mpm.prompt_version, mpm.schema_version, + mpm.attempt_count, mpm.total_duration_ms, + mpm.validation_status, mpm.validation_error_count, + mpm.validation_errors, mpm.retry_count, + mpm.confidence, mpm.recorded_at, + d.title AS document_title, d.document_type, d.source_type + FROM model_performance_metrics mpm + LEFT JOIN documents d ON d.id = mpm.document_id + WHERE mpm.success = FALSE + AND mpm.recorded_at >= NOW() - INTERVAL '1 hour' * $1 + ORDER BY mpm.recorded_at DESC + LIMIT $2""", + hours, limit, + ) + results = [] + for r in rows: + d = _row_to_dict(r) + d["validation_errors"] = _parse_jsonb(d.get("validation_errors")) + results.append(d) + return results + + +@app.get("/api/ops/model/performance") +async def get_model_performance( + hours: int = Query(default=24, ge=1, le=168), + model_name: Optional[str] = None, +): + """Aggregated model performance metrics for the operational dashboard. + + Returns success rate, latency percentiles, retry rate, confidence + distribution, and token usage for the given time window. + + Requirements: 12.2 + """ + return await get_model_performance_summary( + pool, + model_name=model_name, + hours=hours, + ) + + +@app.get("/api/ops/pipeline/health") +async def get_pipeline_health( + hours: int = Query(default=24, ge=1, le=168), +): + """Pipeline stage health summary across ingestion, parsing, extraction, and aggregation. + + Shows document counts at each processing stage and identifies bottlenecks. 
+ + Requirements: 12.1 + """ + # Document status distribution (pipeline stages) + doc_stages = await pool.fetch( + """SELECT + status, + COUNT(*) AS doc_count + FROM documents + WHERE created_at >= NOW() - INTERVAL '1 hour' * $1 + GROUP BY status + ORDER BY doc_count DESC""", + hours, + ) + + # Parsing quality distribution + parse_quality = await pool.fetchrow( + """SELECT + COUNT(*) AS total_parsed, + COUNT(*) FILTER (WHERE parse_confidence = 'high') AS high_confidence, + COUNT(*) FILTER (WHERE parse_confidence = 'medium') AS medium_confidence, + COUNT(*) FILTER (WHERE parse_confidence = 'low') AS low_confidence, + COUNT(*) FILTER (WHERE parse_confidence = 'unknown' OR parse_confidence IS NULL) AS unknown_confidence, + ROUND(AVG(parse_quality_score)::numeric, 3) AS avg_quality_score + FROM documents + WHERE created_at >= NOW() - INTERVAL '1 hour' * $1 + AND status IN ('parsed', 'extracted', 'aggregated')""", + hours, + ) + + # Extraction validation distribution + extraction_stats = await pool.fetchrow( + """SELECT + COUNT(*) AS total_extractions, + COUNT(*) FILTER (WHERE validation_status = 'valid') AS valid, + COUNT(*) FILTER (WHERE validation_status = 'failed') AS failed, + COUNT(*) FILTER (WHERE validation_status = 'pending') AS pending, + ROUND(AVG(confidence)::numeric, 3) AS avg_confidence, + ROUND(AVG(retry_count)::numeric, 2) AS avg_retries + FROM document_intelligence + WHERE created_at >= NOW() - INTERVAL '1 hour' * $1""", + hours, + ) + + # Aggregation output (trend windows generated) + trend_stats = await pool.fetchrow( + """SELECT + COUNT(*) AS trends_generated, + COUNT(DISTINCT entity_id) AS symbols_covered, + ROUND(AVG(confidence)::numeric, 3) AS avg_trend_confidence, + ROUND(AVG(contradiction_score)::numeric, 3) AS avg_contradiction + FROM trend_windows + WHERE created_at >= NOW() - INTERVAL '1 hour' * $1""", + hours, + ) + + return { + "hours": hours, + "document_stages": [_row_to_dict(r) for r in doc_stages], + "parsing": 
@app.get("/api/ops/sources/coverage-gaps")
async def get_source_coverage_gaps():
    """Report companies and sources whose coverage looks incomplete.

    Two independent checks are run:

    * active companies whose set of active source types does not contain
      all of ``market_api``, ``news_api`` and ``filings_api`` (or that have
      no active source at all);
    * active sources of active companies whose most recent ``completed``
      ingestion run is older than 24 hours, or that never completed one.

    Returns:
        A dict with ``missing_source_types`` and ``stale_sources``, each a
        list of row dicts.

    Requirements: 12.3
    """
    # Check 1: companies missing at least one expected source type.
    missing_sql = """SELECT
            c.id AS company_id, c.ticker, c.legal_name, c.sector,
            ARRAY_AGG(DISTINCT s.source_type) FILTER (WHERE s.active) AS active_types,
            ARRAY['market_api', 'news_api', 'filings_api'] AS expected_types
        FROM companies c
        LEFT JOIN sources s ON s.company_id = c.id AND s.active = TRUE
        WHERE c.active = TRUE
        GROUP BY c.id, c.ticker, c.legal_name, c.sector
        HAVING NOT ARRAY['market_api', 'news_api', 'filings_api'] <@ ARRAY_AGG(DISTINCT s.source_type) FILTER (WHERE s.active)
            OR ARRAY_AGG(DISTINCT s.source_type) FILTER (WHERE s.active) IS NULL
        ORDER BY c.ticker"""

    # Check 2: sources without a completed run inside the last 24 hours.
    stale_sql = """SELECT
            s.id AS source_id, s.source_type, s.source_name,
            c.ticker, c.legal_name,
            MAX(ir.started_at) FILTER (WHERE ir.status = 'completed') AS last_success,
            MAX(ir.started_at) AS last_attempt,
            COUNT(*) FILTER (WHERE ir.status = 'failed'
                AND ir.started_at >= NOW() - INTERVAL '24 hours') AS recent_failures
        FROM sources s
        JOIN companies c ON c.id = s.company_id
        LEFT JOIN ingestion_runs ir ON ir.source_id = s.id
        WHERE s.active = TRUE AND c.active = TRUE
        GROUP BY s.id, s.source_type, s.source_name, c.ticker, c.legal_name
        HAVING MAX(ir.started_at) FILTER (WHERE ir.status = 'completed')
            < NOW() - INTERVAL '24 hours'
            OR MAX(ir.started_at) FILTER (WHERE ir.status = 'completed') IS NULL
        ORDER BY c.ticker, s.source_type"""

    gap_rows = await pool.fetch(missing_sql)
    stale_rows = await pool.fetch(stale_sql)

    payload = {}
    payload["missing_source_types"] = list(map(_row_to_dict, gap_rows))
    payload["stale_sources"] = list(map(_row_to_dict, stale_rows))
    return payload
class OllamaClient:
    """Async client for Ollama structured extraction.

    Builds the extraction prompts, posts them to the Ollama ``/api/chat``
    endpoint with the JSON schema as the ``format`` parameter, validates the
    reply, and retries with exponential backoff. Every attempt is kept on
    the returned ``ExtractionResponse`` for audit.

    The client can be used as an async context manager so the underlying
    HTTP client is released when the block exits.

    Usage::

        config = OllamaConfig(base_url="http://localhost:11434", model="llama3.1:8b")
        async with OllamaClient(config) as client:
            response = await client.extract(
                document_text="Apple reported record earnings...",
                document_type="article",
                document_id="abc-123",
            )
            if response.success:
                print(response.result)
    """

    _config: OllamaConfig
    _max_retries: int
    _base_delay: float
    _max_delay: float
    _backoff_multiplier: float
    _owns_client: bool
    _http: httpx.AsyncClient

    def __init__(
        self,
        config: OllamaConfig,
        max_retries: int | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create a client.

        Args:
            config: Ollama connection and retry settings.
            max_retries: Overrides ``config.max_retries`` when not ``None``.
            http_client: Optional pre-built HTTP client. When supplied, the
                caller keeps ownership and ``close()`` leaves it open.
        """
        self._config = config
        self._max_retries = max_retries if max_retries is not None else config.max_retries
        self._base_delay = config.retry_base_delay
        self._max_delay = config.retry_max_delay
        self._backoff_multiplier = config.retry_backoff_multiplier
        self._owns_client = http_client is None
        self._http = http_client or httpx.AsyncClient(timeout=config.timeout)

    async def __aenter__(self) -> "OllamaClient":
        """Enter the async context; returns the client itself."""
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        """Leave the async context, closing the HTTP client if owned."""
        await self.close()

    async def close(self) -> None:
        """Close the underlying HTTP client if we own it."""
        if self._owns_client:
            await self._http.aclose()

    async def extract(
        self,
        document_text: str,
        document_type: str = "article",
        document_id: str = "",
        known_tickers: list[str] | None = None,
    ) -> ExtractionResponse:
        """Send a document to Ollama for structured intelligence extraction.

        Retries up to ``max_retries`` times when the model returns invalid
        or incomplete JSON. Uses exponential backoff between retries.
        Non-retryable errors (e.g. HTTP 400) stop retries immediately.
        Each attempt and its validation result are preserved for audit.

        Args:
            document_text: Normalized text content of the document.
            document_type: One of article, filing, transcript, press_release.
            document_id: Optional document ID for traceability.
            known_tickers: Optional ticker hints for the model.

        Returns:
            An ``ExtractionResponse`` with the parsed result on success.
        """
        prompts = build_extraction_prompt(
            document_text=document_text,
            document_type=document_type,
            document_id=document_id,
            known_tickers=known_tickers,
        )
        json_schema = get_json_schema()
        prompt_meta = get_prompt_metadata()

        response = ExtractionResponse(
            prompt_metadata=prompt_meta,
            model=self._config.model,
        )

        total_start = time.monotonic()

        # One initial attempt plus up to ``_max_retries`` retries.
        for attempt_num in range(self._max_retries + 1):
            attempt = await self._call_ollama(prompts, json_schema, document_text)
            response.attempts.append(attempt)

            if attempt.error is None and attempt.validation and attempt.validation.valid:
                response.success = True
                response.result = attempt.validation.parsed
                break

            # Check if the error is non-retryable — stop immediately
            if not _is_retryable(attempt.error):
                attempt.retryable = False
                logger.warning(
                    "Non-retryable error for doc %s: %s — stopping retries",
                    document_id or "unknown",
                    attempt.error,
                )
                break

            # Sleep only when another attempt will follow.
            if attempt_num < self._max_retries:
                delay = _compute_backoff(
                    attempt_num,
                    self._base_delay,
                    self._max_delay,
                    self._backoff_multiplier,
                )
                logger.warning(
                    "Extraction attempt %d/%d failed for doc %s: %s — retrying in %.1fs",
                    attempt_num + 1,
                    self._max_retries + 1,
                    document_id or "unknown",
                    attempt.error or "validation failed",
                    delay,
                )
                await asyncio.sleep(delay)

        response.total_duration_ms = int((time.monotonic() - total_start) * 1000)
        return response

    async def _call_ollama(
        self,
        prompts: dict[str, str],
        json_schema: dict[str, object],
        document_text: str = "",
    ) -> ExtractionAttempt:
        """Make a single call to the Ollama /api/chat endpoint.

        Args:
            prompts: Dict with ``system`` and ``user`` prompt strings.
            json_schema: Schema passed as the ``format`` request field.
            document_text: Original document text, forwarded to validation.

        Returns:
            An ``ExtractionAttempt`` whose ``error`` is ``None`` only when
            the reply was received, non-empty, and passed validation.
        """
        attempt = ExtractionAttempt(model=self._config.model)
        start = time.monotonic()

        payload = {
            "model": self._config.model,
            "messages": [
                {"role": "system", "content": prompts["system"]},
                {"role": "user", "content": prompts["user"]},
            ],
            "format": json_schema,
            "stream": False,
        }

        # Tolerate a configured base_url with a trailing slash so the request
        # path never contains "//".
        url = f"{self._config.base_url.rstrip('/')}/api/chat"

        try:
            resp = await self._http.post(url, json=payload)
            _ = resp.raise_for_status()
        except httpx.TimeoutException:
            attempt.error = "timeout"
            attempt.duration_ms = int((time.monotonic() - start) * 1000)
            return attempt
        except httpx.HTTPStatusError as exc:
            attempt.error = f"http_{exc.response.status_code}"
            attempt.retryable = _is_retryable(attempt.error)
            attempt.duration_ms = int((time.monotonic() - start) * 1000)
            return attempt
        except httpx.HTTPError as exc:
            attempt.error = f"connection_error: {exc}"
            attempt.duration_ms = int((time.monotonic() - start) * 1000)
            return attempt

        attempt.duration_ms = int((time.monotonic() - start) * 1000)

        # Parse the Ollama response envelope. ``json.JSONDecodeError`` is a
        # ``ValueError``; catching the base class also covers undecodable bytes.
        try:
            body = resp.json()
        except ValueError:
            attempt.error = "invalid_response_json"
            attempt.raw_output = resp.text
            return attempt

        # A syntactically valid but non-object envelope (list, string, ...)
        # has no ``message`` to read; report it instead of raising.
        if not isinstance(body, dict):
            attempt.error = "invalid_response_json"
            attempt.raw_output = resp.text
            return attempt

        msg = body.get("message")
        content: str = msg.get("content", "") if isinstance(msg, dict) else ""
        attempt.raw_output = content

        if not content:
            attempt.error = "empty_model_response"
            return attempt

        # Validate against extraction schema
        attempt.validation = validate_extraction(content, document_text=document_text)
        if not attempt.validation.valid:
            attempt.error = "; ".join(attempt.validation.errors)

        return attempt
async def main() -> None:
    """Run the extractor worker: poll Redis for jobs and persist extractions.

    Each job is a JSON object read from the extraction queue with the keys
    ``document_id``, ``ticker`` and ``text`` (all optional, defaulting to
    empty strings). The text is sent to Ollama and the outcome is handed to
    ``persist_extraction``.

    A malformed job or a failing extraction is logged and skipped so a
    single bad message cannot stop the worker. All clients opened here are
    closed when the loop exits.
    """
    config = load_config()
    setup_logging("extractor", level=config.log_level, json_output=config.json_logs)

    pool = await asyncpg.create_pool(dsn=config.postgres.dsn, min_size=2, max_size=8)
    minio_client = Minio(
        config.minio.endpoint,
        access_key=config.minio.access_key,
        secret_key=config.minio.secret_key,
        secure=config.minio.secure,
    )
    ollama = OllamaClient(config.ollama)

    import json
    import redis.asyncio as aioredis

    redis_client = aioredis.from_url(config.redis.url)
    queue = queue_key(QUEUE_EXTRACTION)
    logger.info("Extractor worker started, polling %s", queue)

    try:
        while True:
            raw = await redis_client.lpop(queue)
            if raw is None:
                # Queue is empty; back off briefly before polling again.
                await asyncio.sleep(1)
                continue

            # Decode inside a handler: the payload comes from outside this
            # process, and an unparseable message must not end the loop.
            try:
                job = json.loads(raw)
            except (TypeError, ValueError):
                logger.exception("Discarding malformed extraction job: %r", raw)
                continue
            if not isinstance(job, dict):
                logger.error("Discarding extraction job that is not a JSON object: %r", raw)
                continue

            document_id = job.get("document_id", "")
            ticker = job.get("ticker", "")
            text = job.get("text", "")

            logger.info("Processing extraction job for doc %s / %s", document_id, ticker)

            try:
                # Pass the document ID through so the client's retry logs and
                # prompt carry it instead of "unknown".
                extraction_response = await ollama.extract(text, document_id=document_id)
                await persist_extraction(
                    pool=pool,
                    minio_client=minio_client,
                    document_id=document_id,
                    ticker=ticker,
                    extraction_response=extraction_response,
                    document_text_length=len(text),
                )
            except Exception:
                # Worker boundary: log with traceback and move on to the next job.
                logger.exception("Extraction failed for doc %s", document_id)
    finally:
        await ollama.close()
        await pool.close()
        await redis_client.close()


if __name__ == "__main__":
    asyncio.run(main())
def collect_metrics(
    extraction_response: ExtractionResponse,
    *,
    document_id: str = "",
    ticker: str = "",
    document_text_length: int = 0,
) -> ExtractionMetrics:
    """Derive an ``ExtractionMetrics`` record from one extraction response.

    Args:
        extraction_response: The full response from OllamaClient.extract().
        document_id: UUID of the source document.
        ticker: Primary ticker symbol.
        document_text_length: Length of the input document text in characters.

    Returns:
        An ExtractionMetrics dataclass with all computed fields. Validation
        details and the output token estimate come from the last attempt;
        confidence and company count come from the parsed result, or stay
        at zero when there is none.
    """
    history = extraction_response.attempts
    attempt_total = len(history)
    opening = history[0] if history else None
    closing = history[-1] if history else None

    # Validation errors/warnings are taken from the closing attempt only.
    errors: list[str] = []
    warnings: list[str] = []
    if closing and closing.validation:
        errors = closing.validation.errors
        warnings = closing.validation.warnings

    # "unknown" is reserved for a response that recorded no attempts at all.
    if extraction_response.success:
        status = "valid"
    else:
        status = "failed" if history else "unknown"

    parsed = extraction_response.result
    result_confidence = parsed.confidence if parsed else 0.0
    companies_found = len(parsed.companies) if parsed else 0

    # Character-based token estimates; non-positive lengths count as zero.
    tokens_in = max(document_text_length, 0) // _CHARS_PER_TOKEN
    tokens_out = 0
    if closing and closing.raw_output:
        tokens_out = len(closing.raw_output) // _CHARS_PER_TOKEN

    meta = extraction_response.prompt_metadata
    return ExtractionMetrics(
        document_id=document_id,
        ticker=ticker,
        model_name=extraction_response.model,
        prompt_version=meta.get("prompt_version", ""),
        schema_version=meta.get("schema_version", ""),
        success=extraction_response.success,
        attempt_count=attempt_total,
        total_duration_ms=extraction_response.total_duration_ms,
        first_attempt_duration_ms=opening.duration_ms if history else 0,
        final_attempt_duration_ms=closing.duration_ms if history else 0,
        confidence=result_confidence,
        validation_status=status,
        validation_error_count=len(errors),
        validation_warning_count=len(warnings),
        validation_errors=errors,
        retry_count=max(0, attempt_total - 1),
        input_token_estimate=tokens_in,
        output_token_estimate=tokens_out,
        company_count=companies_found,
    )
async def get_model_performance_summary(
    pool: asyncpg.Pool,
    *,
    model_name: str | None = None,
    hours: int = 24,
) -> dict[str, object]:
    """Query aggregated model performance metrics for dashboards.

    Returns a summary dict with success rate, latency average and
    percentiles, retry rate, confidence, token totals and validation
    averages for the given time window.

    The returned dict always has the same keys. When the window holds no
    rows, every numeric field is zero and ``hours`` is still reported, so
    consumers do not need a separate code path for an empty window.

    Args:
        pool: PostgreSQL connection pool.
        model_name: Optional filter by model name.
        hours: Lookback window in hours (default 24).

    Returns:
        Dict with aggregated performance metrics.
    """
    # Only the fixed filter fragment is interpolated; the user-supplied
    # values travel as bind parameters.
    model_filter = "AND model_name = $2" if model_name else ""
    params: list[object] = [hours]
    if model_name:
        params.append(model_name)

    row = await pool.fetchrow(
        f"""SELECT
            COUNT(*) AS total_extractions,
            COUNT(*) FILTER (WHERE success) AS successful,
            COUNT(*) FILTER (WHERE NOT success) AS failed,
            ROUND(AVG(total_duration_ms)::numeric, 1) AS avg_duration_ms,
            ROUND(PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 1) AS p50_duration_ms,
            ROUND(PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 1) AS p95_duration_ms,
            ROUND(PERCENTILE_CONT(0.99) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 1) AS p99_duration_ms,
            ROUND(AVG(retry_count)::numeric, 2) AS avg_retries,
            ROUND(AVG(confidence)::numeric, 3) AS avg_confidence,
            SUM(input_token_estimate) AS total_input_tokens,
            SUM(output_token_estimate) AS total_output_tokens,
            ROUND(AVG(company_count)::numeric, 2) AS avg_companies_per_doc,
            ROUND(AVG(validation_error_count)::numeric, 2) AS avg_validation_errors,
            ROUND(AVG(validation_warning_count)::numeric, 2) AS avg_validation_warnings
        FROM model_performance_metrics
        WHERE recorded_at >= NOW() - INTERVAL '1 hour' * $1
        {model_filter}""",
        *params,
    )

    # Treat "no row" like "row of NULLs": the aggregates other than COUNT
    # are NULL on an empty window, and ``or 0`` maps both cases to zero.
    stats = dict(row) if row is not None else {}

    total = int(stats.get("total_extractions") or 0)
    successful = int(stats.get("successful") or 0)

    return {
        "total_extractions": total,
        "successful": successful,
        "failed": int(stats.get("failed") or 0),
        "success_rate": round(successful / total, 4) if total > 0 else 0.0,
        "avg_duration_ms": float(stats.get("avg_duration_ms") or 0),
        "p50_duration_ms": float(stats.get("p50_duration_ms") or 0),
        "p95_duration_ms": float(stats.get("p95_duration_ms") or 0),
        "p99_duration_ms": float(stats.get("p99_duration_ms") or 0),
        "avg_retries": float(stats.get("avg_retries") or 0),
        "avg_confidence": float(stats.get("avg_confidence") or 0),
        "total_input_tokens": int(stats.get("total_input_tokens") or 0),
        "total_output_tokens": int(stats.get("total_output_tokens") or 0),
        "avg_companies_per_doc": float(stats.get("avg_companies_per_doc") or 0),
        "avg_validation_errors": float(stats.get("avg_validation_errors") or 0),
        "avg_validation_warnings": float(stats.get("avg_validation_warnings") or 0),
        "hours": hours,
    }
If the document does not mention a company, do NOT include that company. +5. If the document is ambiguous about sentiment or impact, use "neutral" or "mixed" \ +and set confidence lower. +6. evidence_spans MUST be short verbatim quotes copied from the document. \ +Do NOT paraphrase or invent quotes. +7. key_facts MUST be directly stated in the document. Do NOT add external knowledge. +8. If you are uncertain about any field, lower the confidence score and add a warning \ +to extraction_warnings. +9. If the document text is too short, garbled, or uninformative, return an empty \ +companies array, set confidence below 0.3, and add "insufficient_content" to warnings. +10. Return ONLY valid JSON matching the provided schema. No commentary, no markdown fences.""" + +# --- Document-type-specific guidance --- + +_DOCTYPE_GUIDANCE: dict[str, str] = { + DocumentType.ARTICLE: ( + "This is a news article. Focus on reported facts, quoted sources, and stated " + "analyst opinions. Distinguish between the journalist's framing and actual " + "company developments. Do not treat speculative language as confirmed fact." + ), + DocumentType.FILING: ( + "This is a regulatory filing (e.g. SEC 10-K, 10-Q, 8-K). Extract concrete " + "financial figures, risk factors, and material events as stated. Filings use " + "precise legal language — preserve that precision in your extraction." + ), + DocumentType.TRANSCRIPT: ( + "This is an earnings call or event transcript. Distinguish between management " + "forward-looking statements and reported results. Flag forward-looking language " + "as lower confidence. Extract specific guidance numbers when stated." + ), + DocumentType.PRESS_RELEASE: ( + "This is a company press release. Be aware that press releases are promotional. " + "Extract stated facts and figures but note that sentiment may be biased positive. " + "Look for concrete metrics rather than marketing language." 
def build_extraction_prompt(
    document_text: str,
    document_type: str = DocumentType.ARTICLE,
    known_tickers: list[str] | None = None,
    document_id: str = "",
) -> dict[str, str]:
    """Assemble the system and user prompts for one extraction request.

    Args:
        document_text: Normalized text content of the document.
        document_type: One of the DocumentType enum values; selects the
            type-specific guidance paragraph.
        known_tickers: Optional list of tickers the document may reference.
            Rendered as a hint that explicitly tells the model not to
            include a ticker merely because it is listed.
        document_id: Optional document ID for traceability; adds a
            ``Document ID:`` line to the user prompt when non-empty.

    Returns:
        Dict with 'system' and 'user' prompt strings.
    """
    # Optional leading line; empty string keeps the template unchanged.
    if document_id:
        doc_id_line = f"Document ID: {document_id}\n"
    else:
        doc_id_line = ""

    # Optional ticker hint block.
    if not known_tickers:
        ticker_hint = ""
    else:
        tickers_str = ", ".join(known_tickers)
        ticker_hint = (
            f"\nThe following tickers may be referenced in this document: {tickers_str}\n"
            "Only include a ticker in your output if the document actually discusses that company. "
            "Do NOT include a ticker just because it appears in this hint."
        )

    schema_str = json.dumps(EXTRACTION_JSON_SCHEMA, indent=2)
    doctype_guidance = _get_doctype_guidance(document_type)

    user_prompt = f"""\
Extract structured intelligence from the following document.

{doc_id_line}Document type: {document_type}
{doctype_guidance}
{ticker_hint}
Your output MUST be a single JSON object conforming to this schema:
{schema_str}

REMEMBER:
- Only extract what is explicitly in the text below.
- evidence_spans must be verbatim quotes from the text.
- If the text is insufficient, return empty companies and low confidence.
- Return ONLY the JSON object. No other text.

--- DOCUMENT TEXT ---
{document_text}
--- END DOCUMENT TEXT ---"""

    prompts: dict[str, str] = {}
    prompts["system"] = SYSTEM_PROMPT
    prompts["user"] = user_prompt
    return prompts
@dataclass
class ReplayFixture:
    """One archived document together with its expected extraction output.

    Instances are plain records; ``expected_result`` parses the stored
    expectation on each access.
    """

    # Identity and content of the archived document.
    document_id: str
    document_type: str
    document_text: str
    # Ticker hints recorded with the fixture.
    known_tickers: list[str]
    # Raw expected extraction payload, as loaded from JSON.
    expected_extraction: dict[str, Any]
    # Free-form string metadata stored alongside the fixture.
    metadata: dict[str, str]
    # Where the fixture was loaded from; empty when not loaded from disk.
    source_path: str = ""

    @property
    def expected_result(self) -> ExtractionResult:
        """Return ``expected_extraction`` parsed into an ``ExtractionResult``."""
        payload = self.expected_extraction
        parsed = ExtractionResult.model_validate(payload)
        return parsed
def load_fixture(path: Path) -> ReplayFixture:
    """Load a single replay fixture from a JSON file.

    The file is read as UTF-8 regardless of the platform's default
    encoding, so fixtures containing non-ASCII text load the same way
    everywhere.

    Args:
        path: Path to the fixture JSON file.

    Returns:
        A ReplayFixture with all fields populated. ``known_tickers`` and
        ``metadata`` default to empty when absent from the file.

    Raises:
        ValueError: If the top-level JSON value is not an object, or if
            the fixture is missing required fields.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    with open(path, encoding="utf-8") as f:
        data = json.load(f)

    # Valid JSON is not necessarily an object; reject other shapes with the
    # same exception type used for other malformed fixtures.
    if not isinstance(data, dict):
        raise ValueError(
            f"Fixture {path.name} must contain a JSON object, got {type(data).__name__}"
        )

    required = {"document_id", "document_type", "document_text", "expected_extraction"}
    missing = required - set(data.keys())
    if missing:
        raise ValueError(f"Fixture {path.name} missing required fields: {missing}")

    return ReplayFixture(
        document_id=data["document_id"],
        document_type=data["document_type"],
        document_text=data["document_text"],
        known_tickers=data.get("known_tickers", []),
        expected_extraction=data["expected_extraction"],
        metadata=data.get("metadata", {}),
        source_path=str(path),
    )
+ """ + directory = fixtures_dir or FIXTURES_DIR + if not directory.is_dir(): + logger.warning("Fixtures directory not found: %s", directory) + return [] + + fixtures: list[ReplayFixture] = [] + for path in sorted(directory.glob("*.json")): + try: + fixture = load_fixture(path) + fixtures.append(fixture) + except (ValueError, json.JSONDecodeError) as exc: + logger.warning("Skipping invalid fixture %s: %s", path.name, exc) + + logger.info("Loaded %d replay fixtures from %s", len(fixtures), directory) + return fixtures + + +def validate_fixture(fixture: ReplayFixture) -> ReplayValidationResult: + """Validate a fixture's expected extraction against the current schema. + + This is the core deterministic test: the expected output must still + pass schema and semantic validation with the current code. If it + doesn't, either the fixture is stale or the schema has regressed. + + Args: + fixture: The replay fixture to validate. + + Returns: + A ReplayValidationResult indicating pass/fail. + """ + result = ReplayValidationResult( + fixture_id=fixture.document_id, + schema_version=get_schema_version(), + ) + + try: + report = validate_extraction( + fixture.expected_extraction, + document_text=fixture.document_text, + ) + result.validation_report = report + result.schema_valid = report.valid + except Exception as exc: # noqa: BLE001 + result.error = str(exc) + result.schema_valid = False + + return result + + +def validate_all_fixtures( + fixtures_dir: Path | None = None, +) -> list[ReplayValidationResult]: + """Load and validate all fixtures against the current schema. + + Args: + fixtures_dir: Override path to fixtures directory. + + Returns: + List of validation results, one per fixture. 
+ """ + fixtures = load_all_fixtures(fixtures_dir) + return [validate_fixture(f) for f in fixtures] + + +def compare_extraction( + fixture: ReplayFixture, + actual_result: ExtractionResult, +) -> ReplayComparisonResult: + """Compare a live extraction result against the fixture's expected output. + + Checks structural alignment (same companies detected, same sentiments, + same catalyst types) rather than exact string equality, since LLM + outputs vary in wording across runs. + + Args: + fixture: The replay fixture with expected output. + actual_result: The ExtractionResult from a live extraction. + + Returns: + A ReplayComparisonResult with match details. + """ + expected = fixture.expected_result + comparison = ReplayComparisonResult(fixture_id=fixture.document_id) + + # Company ticker sets + comparison.expected_companies = sorted(c.ticker for c in expected.companies) + comparison.actual_companies = sorted(c.ticker for c in actual_result.companies) + comparison.companies_match = ( + set(comparison.expected_companies) == set(comparison.actual_companies) + ) + + # Sentiment by ticker + comparison.expected_sentiment_map = { + c.ticker: c.sentiment for c in expected.companies + } + comparison.actual_sentiment_map = { + c.ticker: c.sentiment for c in actual_result.companies + } + comparison.sentiment_match = ( + comparison.expected_sentiment_map == comparison.actual_sentiment_map + ) + + # Catalyst type by ticker + comparison.expected_catalyst_map = { + c.ticker: c.catalyst_type for c in expected.companies + } + comparison.actual_catalyst_map = { + c.ticker: c.catalyst_type for c in actual_result.companies + } + comparison.catalyst_match = ( + comparison.expected_catalyst_map == comparison.actual_catalyst_map + ) + + # Schema validity of actual result + actual_report = validate_extraction( + actual_result.model_dump(mode="json"), + document_text=fixture.document_text, + ) + comparison.actual_schema_valid = actual_report.valid + if actual_report.warnings: + 
comparison.warnings = actual_report.warnings + + if not comparison.companies_match: + comparison.warnings.append( + f"company_mismatch: expected={comparison.expected_companies} actual={comparison.actual_companies}" + ) + + return comparison diff --git a/services/extractor/schemas.py b/services/extractor/schemas.py new file mode 100644 index 0000000..75ff407 --- /dev/null +++ b/services/extractor/schemas.py @@ -0,0 +1,316 @@ +"""JSON schema definitions for document intelligence extraction. + +Generates Ollama-compatible JSON schemas from Pydantic models so the +extraction contract stays in sync with the shared data models. Also +provides schema validation and semantic validation helpers. + +Requirements: 5.1, 5.2, 5.3, 5.4, 5.5 +""" +from __future__ import annotations + +import json +import re +from typing import Any + +from pydantic import BaseModel, Field + +from services.shared.schemas import ( + CatalystType, + Sentiment, +) + +SCHEMA_VERSION = "2.0.0" + + +# --------------------------------------------------------------------------- +# Pydantic model that mirrors the Ollama extraction output contract. +# This is the *response* shape we ask the model to produce — it intentionally +# omits server-side fields like document_id, source_credibility, and model +# metadata that are attached after extraction. +# --------------------------------------------------------------------------- + + +class CompanyExtractionItem(BaseModel): + """Per-company extraction output expected from the model. + + All fields are required (no defaults) so the generated JSON schema + forces the model to produce every field explicitly. + """ + + ticker: str = Field(description="Stock ticker symbol mentioned in the document.") + company_name: str = Field(description="Full company name as referenced in the document.") + relevance: float = Field( + ge=0, + le=1, + description="How relevant the document is to this company. 
0=tangential, 1=primary subject.", + ) + sentiment: Sentiment = Field(description="Overall sentiment toward this company in the document.") + impact_score: float = Field( + ge=0, + le=1, + description="Estimated magnitude of impact. 0=negligible, 1=highly material.", + ) + impact_horizon: str = Field( + description="One of: intraday, 1d, 1d_7d, 1d_30d, 30d_90d, 90d_plus", + ) + catalyst_type: CatalystType = Field(description="Primary catalyst category.") + key_facts: list[str] = Field( + description="Facts explicitly stated in the document. Do NOT infer or fabricate.", + ) + risks: list[str] = Field( + description="Risks explicitly mentioned in the document.", + ) + evidence_spans: list[str] = Field( + description="Short verbatim quotes from the document supporting the analysis.", + ) + + +class ExtractionResult(BaseModel): + """Top-level structured output the model must return. + + All fields are required (no defaults) so the generated JSON schema + forces the model to produce every field explicitly. + """ + + summary: str = Field( + description="A concise 1-3 sentence summary of the document's main point.", + ) + companies: list[CompanyExtractionItem] = Field( + description="Per-company intelligence extracted from the document.", + ) + macro_themes: list[str] = Field( + description="Broad economic or market themes mentioned (e.g. rates, inflation, ai_capex).", + ) + novelty_score: float = Field( + ge=0, + le=1, + description="How novel or surprising the information is. 0=routine, 1=highly novel.", + ) + confidence: float = Field( + ge=0, + le=1, + description="Model confidence in the accuracy of this extraction. 
Lower if text is ambiguous.", + ) + extraction_warnings: list[str] = Field( + description="Any issues encountered: ambiguous_ticker, incomplete_text, low_confidence, etc.", + ) + + +# --------------------------------------------------------------------------- +# Schema generation +# --------------------------------------------------------------------------- + + +def generate_json_schema() -> dict[str, Any]: + """Generate the JSON schema from the Pydantic model. + + Returns a plain JSON Schema dict suitable for Ollama's ``format`` + parameter. Pydantic ``$defs`` are inlined so the schema is + self-contained. + """ + raw = ExtractionResult.model_json_schema() + # Inline $defs so the schema is flat and Ollama-friendly + return _inline_defs(raw) + + +def get_schema_version() -> str: + """Return the current schema version string.""" + return SCHEMA_VERSION + + +# --------------------------------------------------------------------------- +# Validation helpers +# --------------------------------------------------------------------------- + + +class ValidationReport(BaseModel): + """Result of validating a raw model response.""" + + valid: bool = False + errors: list[str] = Field(default_factory=list) + warnings: list[str] = Field(default_factory=list) + parsed: ExtractionResult | None = None + + +def validate_extraction( + raw_json: str | dict[str, Any], + *, + document_text: str = "", +) -> ValidationReport: + """Validate raw model output against the extraction schema. + + Performs structural (JSON / Pydantic) validation followed by semantic + checks that catch hallucination indicators, cross-field inconsistencies, + and data-quality issues. + + Args: + raw_json: Either a JSON string or an already-parsed dict. + document_text: Optional original document text used for evidence + span verification. + + Returns: + A ``ValidationReport`` with parsed result on success. 
+ """ + errors: list[str] = [] + warnings: list[str] = [] + + # --- Parse JSON string if needed --- + if isinstance(raw_json, str): + try: + data = json.loads(raw_json) + except json.JSONDecodeError as exc: + return ValidationReport(valid=False, errors=[f"Invalid JSON: {exc}"]) + else: + data = raw_json + + if not isinstance(data, dict): + return ValidationReport(valid=False, errors=["Expected a JSON object at top level."]) + + # --- Pydantic structural validation --- + try: + result = ExtractionResult.model_validate(data) + except Exception as exc: # noqa: BLE001 + return ValidationReport(valid=False, errors=[f"Schema validation failed: {exc}"]) + + # --- Semantic checks --- + sem_errors, sem_warnings = _semantic_checks(result, document_text) + errors.extend(sem_errors) + warnings.extend(sem_warnings) + + # Semantic errors make the report invalid — the caller should retry. + valid = len(errors) == 0 + return ValidationReport( + valid=valid, + errors=errors, + warnings=warnings, + parsed=result, + ) + + +# --------------------------------------------------------------------------- +# Known valid impact horizons +# --------------------------------------------------------------------------- + +VALID_IMPACT_HORIZONS = frozenset({ + "intraday", + "1d", + "1d_7d", + "1d_30d", + "30d_90d", + "90d_plus", +}) + +# Ticker: 1-5 uppercase letters (covers NYSE, NASDAQ, etc.) +_TICKER_RE = re.compile(r"^[A-Z]{1,5}$") + +# Evidence span length bounds (characters) +_MIN_EVIDENCE_LEN = 8 +_MAX_EVIDENCE_LEN = 500 + + +# --------------------------------------------------------------------------- +# Semantic validation rules +# --------------------------------------------------------------------------- + + +def _semantic_checks( + result: ExtractionResult, + document_text: str = "", +) -> tuple[list[str], list[str]]: + """Run semantic checks on a parsed extraction. + + Returns a tuple of (errors, warnings). 
Errors are issues severe enough + to warrant a retry; warnings are informational. + """ + errors: list[str] = [] + warnings: list[str] = [] + + # --- Top-level checks --- + if not result.summary: + warnings.append("empty_summary") + + if result.confidence < 0.3 and len(result.companies) > 0: + warnings.append("low_confidence_with_companies") + + # Duplicate tickers across company entries + tickers_seen: list[str] = [] + for comp in result.companies: + if comp.ticker in tickers_seen: + errors.append(f"duplicate_ticker_{comp.ticker}") + tickers_seen.append(comp.ticker) + + # --- Per-company checks --- + for comp in result.companies: + tag = comp.ticker or "unknown" + + # Ticker format + if not comp.ticker: + errors.append("company_missing_ticker") + elif not _TICKER_RE.match(comp.ticker): + warnings.append(f"invalid_ticker_format_{tag}") + + # Impact horizon must be a known value + if comp.impact_horizon not in VALID_IMPACT_HORIZONS: + errors.append(f"invalid_impact_horizon_{comp.impact_horizon}_for_{tag}") + + # Evidence spans + if not comp.evidence_spans: + warnings.append(f"no_evidence_spans_for_{tag}") + else: + for idx, span in enumerate(comp.evidence_spans): + if len(span) < _MIN_EVIDENCE_LEN: + warnings.append(f"evidence_span_too_short_for_{tag}_{idx}") + if len(span) > _MAX_EVIDENCE_LEN: + warnings.append(f"evidence_span_too_long_for_{tag}_{idx}") + + # Cross-field: high impact but no facts + if not comp.key_facts and comp.impact_score > 0.5: + warnings.append(f"high_impact_no_facts_for_{tag}") + + # Cross-field: very low relevance + if comp.relevance < 0.2: + warnings.append(f"very_low_relevance_for_{tag}") + + # Cross-field: strong sentiment but low impact + if comp.sentiment in (Sentiment.POSITIVE, Sentiment.NEGATIVE) and comp.impact_score < 0.1: + warnings.append(f"strong_sentiment_low_impact_for_{tag}") + + # --- Evidence grounding check (when source text is available) --- + if document_text: + doc_lower = document_text.lower() + for comp in 
result.companies: + for idx, span in enumerate(comp.evidence_spans): + if span.lower() not in doc_lower: + warnings.append( + f"evidence_span_not_found_in_document_for_{comp.ticker or 'unknown'}_{idx}" + ) + + return errors, warnings + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _inline_defs(schema: dict[str, Any]) -> dict[str, Any]: + """Recursively inline ``$defs`` / ``$ref`` so the schema is self-contained.""" + defs = schema.pop("$defs", {}) + return _resolve_refs(schema, defs) + + +def _resolve_refs(node: Any, defs: dict[str, Any]) -> Any: + """Walk the schema tree and replace ``$ref`` pointers with their definitions.""" + if isinstance(node, dict): + if "$ref" in node: + ref_path = node["$ref"] # e.g. "#/$defs/CompanyExtractionItem" + ref_name = ref_path.rsplit("/", 1)[-1] + if ref_name in defs: + resolved = defs[ref_name].copy() + # The resolved def may itself contain refs + return _resolve_refs(resolved, defs) + return node # unresolvable ref, leave as-is + return {k: _resolve_refs(v, defs) for k, v in node.items()} + if isinstance(node, list): + return [_resolve_refs(item, defs) for item in node] + return node diff --git a/services/extractor/worker.py b/services/extractor/worker.py index 7f47a87..d6afe69 100644 --- a/services/extractor/worker.py +++ b/services/extractor/worker.py @@ -1 +1,291 @@ -"""Extraction worker - sends documents to Ollama for structured intelligence extraction.""" +"""Extraction worker - sends documents to Ollama for structured intelligence extraction. + +Orchestrates the full extraction pipeline for a single document: +1. Calls OllamaClient to get structured extraction +2. Uploads prompts, raw outputs, and validation reports to MinIO +3. Persists the final intelligence object and per-company impact records to PostgreSQL +4. 
Updates document status + +Requirements: 5.1, 5.2, 5.3, 5.4, 5.5, 9.1, 9.2 +""" +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass +from datetime import datetime, timezone + +import asyncpg +from minio import Minio + +from services.extractor.client import ExtractionResponse +from services.extractor.metrics import collect_metrics, persist_metrics +from services.shared.metadata import ( + persist_document_impact, + persist_document_intelligence, + update_document_status, +) +from services.shared.storage import ( + upload_extraction_intelligence, + upload_extraction_prompt, + upload_extraction_raw_output, + upload_extraction_validation, +) +from services.shared.logging import Span +from services.shared.metrics import ( + EXTRACTION_ATTEMPTS, + EXTRACTION_CONFIDENCE, + EXTRACTION_DURATION, + EXTRACTION_JOBS_TOTAL, + EXTRACTION_RETRIES, + EXTRACTION_TOKEN_ESTIMATE, + EXTRACTION_VALIDATION_ERRORS, +) + +logger = logging.getLogger("extractor_worker") + + +@dataclass +class ExtractionPersistResult: + """Result of persisting an extraction to storage and database.""" + + intelligence_id: str | None = None + prompt_ref: str | None = None + raw_output_ref: str | None = None + validation_ref: str | None = None + intelligence_ref: str | None = None + impact_ids: list[str] | None = None + metrics_id: str | None = None + success: bool = False + + +async def persist_extraction( + *, + pool: asyncpg.Pool, + minio_client: Minio, + document_id: str, + ticker: str, + extraction_response: ExtractionResponse, + company_id_map: dict[str, str] | None = None, + source_credibility: float = 0.5, + timestamp: datetime | None = None, + document_text_length: int = 0, +) -> ExtractionPersistResult: + """Persist all extraction artifacts to MinIO and PostgreSQL. + + Uploads prompts, raw model outputs, validation reports, and the final + intelligence object to MinIO. Persists the intelligence record and + per-company impact records to PostgreSQL. 
Updates document status. + Also collects and persists model performance metrics. + + Args: + pool: PostgreSQL connection pool. + minio_client: MinIO client. + document_id: UUID of the source document. + ticker: Primary ticker for path construction. + extraction_response: Full response from OllamaClient.extract(). + company_id_map: Optional mapping of ticker -> company UUID for impact records. + source_credibility: Credibility score to attach to the intelligence record. + timestamp: Override timestamp for MinIO paths (defaults to UTC now). + document_text_length: Length of the input document text for token estimation. + + Returns: + ExtractionPersistResult with references to all persisted artifacts. + """ + ts = timestamp or datetime.now(timezone.utc) + result = ExtractionPersistResult() + company_id_map = company_id_map or {} + + # 1. Upload prompt metadata to MinIO + prompt_payload = json.dumps({ + "prompt_metadata": extraction_response.prompt_metadata, + "model": extraction_response.model, + }, indent=2).encode() + result.prompt_ref = upload_extraction_prompt( + minio_client, ticker, document_id, prompt_payload, timestamp=ts, + ) + + # 2. 
Upload raw outputs for each attempt + attempts_data: list[dict[str, object]] = [] + for idx, attempt in enumerate(extraction_response.attempts): + attempt_record: dict[str, object] = { + "attempt_index": idx, + "raw_output": attempt.raw_output, + "error": attempt.error, + "duration_ms": attempt.duration_ms, + "model": attempt.model, + "retryable": attempt.retryable, + } + if attempt.validation: + attempt_record["validation"] = { + "valid": attempt.validation.valid, + "errors": attempt.validation.errors, + "warnings": attempt.validation.warnings, + } + attempts_data.append(attempt_record) + + raw_output_payload = json.dumps({ + "document_id": document_id, + "attempts": attempts_data, + "total_duration_ms": extraction_response.total_duration_ms, + "success": extraction_response.success, + }, indent=2).encode() + result.raw_output_ref = upload_extraction_raw_output( + minio_client, ticker, document_id, raw_output_payload, timestamp=ts, + ) + + # 3. Upload validation report + final_attempt = extraction_response.attempts[-1] if extraction_response.attempts else None + validation_payload = json.dumps({ + "document_id": document_id, + "success": extraction_response.success, + "attempt_count": len(extraction_response.attempts), + "final_validation": { + "valid": final_attempt.validation.valid if final_attempt and final_attempt.validation else False, + "errors": final_attempt.validation.errors if final_attempt and final_attempt.validation else [], + "warnings": final_attempt.validation.warnings if final_attempt and final_attempt.validation else [], + } if final_attempt else None, + }, indent=2).encode() + result.validation_ref = upload_extraction_validation( + minio_client, ticker, document_id, validation_payload, timestamp=ts, + ) + + # 4. 
Determine validation status and persist intelligence + if extraction_response.success and extraction_response.result: + extraction = extraction_response.result + validation_status = "valid" + validation_errors: list[str] = [] + + # Upload final intelligence object to MinIO + intelligence_payload = json.dumps( + extraction.model_dump(mode="json"), indent=2, + ).encode() + result.intelligence_ref = upload_extraction_intelligence( + minio_client, ticker, document_id, intelligence_payload, timestamp=ts, + ) + + # Persist to PostgreSQL + intel_id = await persist_document_intelligence( + pool, + document_id=document_id, + summary=extraction.summary, + macro_themes=extraction.macro_themes, + novelty_score=extraction.novelty_score, + source_credibility=source_credibility, + extraction_warnings=extraction.extraction_warnings, + confidence=extraction.confidence, + model_provider="ollama", + model_name=extraction_response.model, + prompt_version=extraction_response.prompt_metadata.get("prompt_version", ""), + schema_version=extraction_response.prompt_metadata.get("schema_version", ""), + raw_output_ref=result.raw_output_ref, + prompt_ref=result.prompt_ref, + validation_status=validation_status, + validation_errors=validation_errors, + retry_count=len(extraction_response.attempts) - 1, + ) + result.intelligence_id = intel_id + + # Persist per-company impact records + result.impact_ids = [] + for company in extraction.companies: + cid = company_id_map.get(company.ticker) + if not cid: + logger.warning( + "No company_id for ticker %s in doc %s, skipping impact record", + company.ticker, document_id, + ) + continue + impact_id = await persist_document_impact( + pool, + intelligence_id=intel_id, + company_id=cid, + ticker=company.ticker, + relevance=company.relevance, + sentiment=company.sentiment, + impact_score=company.impact_score, + impact_horizon=company.impact_horizon, + catalyst_type=company.catalyst_type, + key_facts=company.key_facts, + risks=company.risks, + 
evidence_spans=company.evidence_spans, + ) + result.impact_ids.append(impact_id) + + await update_document_status(pool, document_id=document_id, status="extracted") + result.success = True + logger.info( + "Extraction persisted for doc %s: intel=%s, impacts=%d", + document_id, intel_id, len(result.impact_ids), + ) + else: + # Failed extraction — still persist the attempt data + all_errors: list[str] = [] + for attempt in extraction_response.attempts: + if attempt.error: + all_errors.append(attempt.error) + + intel_id = await persist_document_intelligence( + pool, + document_id=document_id, + summary="", + macro_themes=[], + novelty_score=0.0, + source_credibility=source_credibility, + extraction_warnings=["extraction_failed"], + confidence=0.0, + model_provider="ollama", + model_name=extraction_response.model, + prompt_version=extraction_response.prompt_metadata.get("prompt_version", ""), + schema_version=extraction_response.prompt_metadata.get("schema_version", ""), + raw_output_ref=result.raw_output_ref, + prompt_ref=result.prompt_ref, + validation_status="failed", + validation_errors=all_errors, + retry_count=len(extraction_response.attempts), + ) + result.intelligence_id = intel_id + + await update_document_status(pool, document_id=document_id, status="extraction_failed") + logger.warning( + "Extraction failed for doc %s after %d attempts: %s", + document_id, len(extraction_response.attempts), "; ".join(all_errors), + ) + + # Collect and persist model performance metrics + try: + metrics = collect_metrics( + extraction_response, + document_id=document_id, + ticker=ticker, + document_text_length=document_text_length, + ) + metrics.recorded_at = ts + metrics_id = await persist_metrics(pool, metrics) + result.metrics_id = metrics_id + except Exception: + logger.exception("Failed to persist extraction metrics for doc %s", document_id) + + # Prometheus metrics + EXTRACTION_ATTEMPTS.inc(len(extraction_response.attempts)) + 
EXTRACTION_DURATION.observe(extraction_response.total_duration_ms / 1000.0) + retry_count = max(0, len(extraction_response.attempts) - 1) + if retry_count > 0: + EXTRACTION_RETRIES.inc(retry_count) + if extraction_response.success: + EXTRACTION_JOBS_TOTAL.labels(status="success").inc() + if extraction_response.result: + EXTRACTION_CONFIDENCE.observe(extraction_response.result.confidence) + else: + EXTRACTION_JOBS_TOTAL.labels(status="failed").inc() + # Count validation errors from final attempt + final = extraction_response.attempts[-1] if extraction_response.attempts else None + if final and final.validation and final.validation.errors: + EXTRACTION_VALIDATION_ERRORS.inc(len(final.validation.errors)) + # Token estimates + if document_text_length > 0: + EXTRACTION_TOKEN_ESTIMATE.labels(direction="input").inc(document_text_length // 4) + if final and final.raw_output: + EXTRACTION_TOKEN_ESTIMATE.labels(direction="output").inc(len(final.raw_output) // 4) + + return result diff --git a/services/ingestion/worker.py b/services/ingestion/worker.py index 331dd01..b1b8728 100644 --- a/services/ingestion/worker.py +++ b/services/ingestion/worker.py @@ -1,47 +1,50 @@ """Ingestion worker - processes jobs from the ingestion queue.""" import asyncio -import hashlib -import io import json import logging -from datetime import datetime import asyncpg import redis.asyncio as aioredis from minio import Minio from services.adapters.base import AdapterResult -from services.adapters.filings_adapter import FilingsAdapter -from services.adapters.market_adapter import MarketDataAdapter -from services.adapters.news_adapter import NewsApiAdapter +from services.adapters.broker_adapter import AlpacaBrokerAdapter, TradingMode +from services.adapters.filings_adapter import SECEdgarAdapter +from services.adapters.market_adapter import PolygonMarketAdapter +from services.adapters.news_adapter import PolygonNewsAdapter +from services.adapters.web_scrape_adapter import WebScrapeAdapter from 
services.shared.config import load_config from services.shared.db import get_minio, get_pg_pool, get_redis +from services.shared.dedupe import dedupe_items, mark_as_seen +from services.shared.metadata import ( + persist_ingestion_items, + record_retrieval_failure, + reset_source_retry_state, +) from services.shared.redis_keys import ( QUEUE_INGESTION, QUEUE_PARSING, dedupe_key, queue_key, ) +from services.shared.logging import Span, extract_trace_context, inject_trace_context, new_trace_id, set_trace_context, setup_logging +from services.shared.metrics import ( + ACTIVE_JOBS, + INGESTION_ADAPTER_DURATION, + INGESTION_ERRORS, + INGESTION_ITEMS_DEDUPED, + INGESTION_ITEMS_FETCHED, + INGESTION_ITEMS_NEW, + INGESTION_JOBS_TOTAL, +) +from services.shared.storage import ( + bucket_for_source, + ensure_buckets, + upload_raw_artifact, +) -logging.basicConfig(level=logging.INFO) logger = logging.getLogger("ingestion_worker") -BUCKET_MAP = { - "market_api": "stonks-raw-market", - "news_api": "stonks-raw-news", - "filings_api": "stonks-raw-filings", - "broker": "stonks-raw-market", -} - - -def build_storage_path(source_type: str, ticker: str, doc_id: str) -> str: - now = datetime.utcnow() - return f"{source_type}/{ticker}/{now.year}/{now.month:02d}/{now.day:02d}/{doc_id}/raw.json" - - -async def store_raw_artifact(minio_client: Minio, bucket: str, path: str, data: bytes): - minio_client.put_object(bucket, path, io.BytesIO(data), len(data), content_type="application/json") - async def process_job( job: dict, @@ -55,9 +58,11 @@ async def process_job( source_id = job["source_id"] config = job.get("config", {}) + set_trace_context(trace_id=job.get("_trace_id") or new_trace_id()) + adapter = adapters.get(source_type) if not adapter: - logger.warning(f"No adapter for source_type={source_type}") + logger.warning("No adapter for source_type=%s", source_type) return # Record ingestion run @@ -68,25 +73,37 @@ async def process_job( ) try: - result: AdapterResult = await 
adapter.fetch(ticker, config) + with Span("adapter_fetch", ticker=ticker, source_type=source_type): + with INGESTION_ADAPTER_DURATION.labels(source_type=source_type).time(): + result: AdapterResult = await adapter.fetch(ticker, config) if result.error: - await pool.execute( - "UPDATE ingestion_runs SET status='failed', error_message=$2, completed_at=NOW() WHERE id=$1", - run_id, result.error, + INGESTION_JOBS_TOTAL.labels(source_type=source_type, status="error").inc() + await record_retrieval_failure( + pool, + run_id=str(run_id), + source_id=source_id, + error_message=result.error, ) return - # Store raw payload - bucket = BUCKET_MAP.get(source_type, "stonks-raw-market") - storage_path = build_storage_path(source_type, ticker, str(run_id)) - await store_raw_artifact(minio_client, bucket, storage_path, result.raw_payload) + # Store raw payload in MinIO + bucket = bucket_for_source(source_type) + artifact_type = "raw_html" if source_type == "web_scrape" else "raw_json" + storage_uri = upload_raw_artifact( + minio_client, + source_type=source_type, + ticker=ticker, + document_id=str(run_id), + data=result.raw_payload, + artifact_type=artifact_type, + ) - # Dedupe check + # Dedupe check on the overall payload hash if result.content_hash: already_seen = await rds.get(dedupe_key(result.content_hash)) if already_seen: - logger.info(f"Duplicate content for {ticker}, skipping") + logger.info("Duplicate content for %s, skipping", ticker) await pool.execute( "UPDATE ingestion_runs SET status='completed', items_fetched=$2, items_new=0, completed_at=NOW() WHERE id=$1", run_id, len(result.items), @@ -94,72 +111,126 @@ async def process_job( return await rds.set(dedupe_key(result.content_hash), "1", ex=86400) - new_items = 0 - for item in result.items: - item_json = json.dumps(item) - item_hash = hashlib.sha256(item_json.encode()).hexdigest() + # Cross-source dedupe on individual document items (news, filings, web_scrape) + items_to_persist = result.items + deduped_count = 0 + 
if source_type not in ("market_api", "broker"): + items_to_persist, dup_items = await dedupe_items(pool, rds, result.items) + deduped_count = len(dup_items) + if deduped_count: + INGESTION_ITEMS_DEDUPED.labels(source_type=source_type).inc(deduped_count) + logger.info( + "Deduped %d/%d items for %s/%s", + deduped_count, len(result.items), ticker, source_type, + ) - # Check if document already exists - exists = await pool.fetchval("SELECT 1 FROM documents WHERE content_hash = $1", item_hash) - if exists: - continue + # Persist metadata via the unified metadata module + new_items, new_ids = await persist_ingestion_items( + pool, + source_type=source_type, + ticker=ticker, + company_id=job.get("company_id"), + items=items_to_persist, + storage_ref=storage_uri, + adapter_metadata=result.metadata, + content_hash=result.content_hash, + ) - title = item.get("title", item.get("name", "")) - url = item.get("url", item.get("link", "")) - published = item.get("publishedAt", item.get("published_at")) + # Enqueue new document items for parsing (not market/broker) + if source_type not in ("market_api", "broker"): + for doc_id in new_ids: + await rds.rpush(queue_key(QUEUE_PARSING), json.dumps(inject_trace_context({ + "document_id": doc_id, + "ticker": ticker, + "source_type": source_type, + }))) - doc_id = await pool.fetchval( - """INSERT INTO documents (document_type, source_type, publisher, url, title, published_at, content_hash, raw_storage_ref, status) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'ingested') - RETURNING id""", - "article" if source_type == "news_api" else "filing" if source_type == "filings_api" else "article", - source_type, - item.get("source", {}).get("name", "") if isinstance(item.get("source"), dict) else str(item.get("source", "")), - url, title, - datetime.fromisoformat(published.replace("Z", "+00:00")) if published else None, - item_hash, - f"s3://{bucket}/{storage_path}", - ) + # Mark newly persisted documents in Redis for fast future dedupe + for item, 
doc_id in zip(items_to_persist, new_ids): + await mark_as_seen( + rds, + content_hash=item.get("content_hash", ""), + canonical_url=item.get("canonical_url"), + document_id=doc_id, + ) - # Enqueue for parsing - await rds.rpush(queue_key(QUEUE_PARSING), json.dumps({ - "document_id": str(doc_id), - "ticker": ticker, - "source_type": source_type, - "url": url, - })) - new_items += 1 + # Link duplicate documents to this company if not already linked + company_id = job.get("company_id") + if company_id and deduped_count: + from services.shared.metadata import persist_document_company_mention + for dup in dup_items: + existing_id = dup.get("_dedupe_existing_id") + if existing_id: + try: + await persist_document_company_mention( + pool, + document_id=existing_id, + company_id=company_id, + ticker=ticker, + mention_type="cross_source", + ) + except Exception: + # Duplicate mention link — safe to ignore + pass await pool.execute( "UPDATE ingestion_runs SET status='completed', items_fetched=$2, items_new=$3, completed_at=NOW() WHERE id=$1", run_id, len(result.items), new_items, ) - logger.info(f"Ingested {ticker}/{source_type}: {len(result.items)} fetched, {new_items} new") + # Clear any accumulated retry backoff after success + await reset_source_retry_state(pool, source_id) + INGESTION_ITEMS_FETCHED.labels(source_type=source_type).inc(len(result.items)) + INGESTION_ITEMS_NEW.labels(source_type=source_type).inc(new_items) + INGESTION_JOBS_TOTAL.labels(source_type=source_type, status="success").inc() + logger.info( + "Ingested %s/%s: %d fetched, %d new", + ticker, source_type, len(result.items), new_items, + extra={"ticker": ticker, "source_type": source_type, "count": new_items}, + ) except Exception as e: - logger.error(f"Ingestion error for {ticker}: {e}") - await pool.execute( - "UPDATE ingestion_runs SET status='failed', error_message=$2, completed_at=NOW() WHERE id=$1", - run_id, str(e), + INGESTION_ERRORS.labels(source_type=source_type).inc() + 
INGESTION_JOBS_TOTAL.labels(source_type=source_type, status="error").inc() + logger.error( + "Ingestion error for %s: %s", ticker, e, + extra={"ticker": ticker, "source_type": source_type, "error": str(e)}, + ) + await record_retrieval_failure( + pool, + run_id=str(run_id), + source_id=source_id, + error_message=str(e), ) async def main(): - config = load_config() - pool = await get_pg_pool(config) - rds = get_redis(config) - minio_client = get_minio(config) + cfg = load_config() + setup_logging("ingestion_worker", level=cfg.log_level, json_output=cfg.json_logs) + + pool = await get_pg_pool(cfg) + rds = get_redis(cfg) + minio_client = get_minio(cfg) + + # Ensure all required buckets exist + ensure_buckets(minio_client) adapters = { - "market_api": MarketDataAdapter( - api_key=config.broker.api_key or "", + "market_api": PolygonMarketAdapter( + api_key=cfg.market_data.api_key, + base_url=cfg.market_data.base_url, + ), + "news_api": PolygonNewsAdapter( + api_key=cfg.market_data.api_key, base_url="https://api.polygon.io", ), - "news_api": NewsApiAdapter( - api_key="", - base_url="https://newsapi.org", + "filings_api": SECEdgarAdapter(), + "web_scrape": WebScrapeAdapter(), + "broker": AlpacaBrokerAdapter( + api_key=cfg.broker.api_key or "", + api_secret=cfg.broker.api_secret or "", + mode=TradingMode.LIVE if cfg.broker.mode == "live" else TradingMode.PAPER, + base_url=cfg.broker.base_url, ), - "filings_api": FilingsAdapter(), } logger.info("Ingestion worker started") diff --git a/services/lake_publisher/__init__.py b/services/lake_publisher/__init__.py index de6226f..c2dfbc4 100644 --- a/services/lake_publisher/__init__.py +++ b/services/lake_publisher/__init__.py @@ -1 +1 @@ -# Lake Publisher - transforms operational data into analytical fact datasets +"""Lake publisher — writes partitioned Parquet facts to MinIO for Trino/Superset.""" diff --git a/services/lake_publisher/enqueue.py b/services/lake_publisher/enqueue.py new file mode 100644 index 0000000..c9282ea --- 
/dev/null +++ b/services/lake_publisher/enqueue.py @@ -0,0 +1,39 @@ +"""Helpers for enqueuing lake publish jobs from upstream workers. + +Other services import these helpers to push jobs onto the QUEUE_LAKE_PUBLISH +Redis queue. The lake publisher worker (jobs.py) consumes them. + +Usage: + await enqueue_lake_job(rds, "document", document_id) + await enqueue_lake_job(rds, "trade_order", order_id) + await enqueue_lake_job(rds, "bulk_documents", since=cutoff.isoformat()) +""" +from __future__ import annotations + +import json + +import redis.asyncio as aioredis + +from services.shared.redis_keys import QUEUE_LAKE_PUBLISH, queue_key + + +async def enqueue_lake_job( + rds: aioredis.Redis, + job_type: str, + entity_id: str = "", + since: str | None = None, +) -> None: + """Push a lake publish job onto the Redis queue. + + Args: + rds: Async Redis client. + job_type: One of the supported job types (document, document_extraction, + market_snapshot, trade_order, trade_fill, positions_snapshot, + pnl_snapshot, bulk_documents, bulk_extractions). + entity_id: UUID or identifier for the entity to publish. + since: ISO datetime string for bulk jobs (cutoff timestamp). + """ + payload: dict[str, str] = {"job_type": job_type, "entity_id": entity_id} + if since: + payload["since"] = since + await rds.rpush(queue_key(QUEUE_LAKE_PUBLISH), json.dumps(payload)) # type: ignore[misc] diff --git a/services/lake_publisher/iceberg.py b/services/lake_publisher/iceberg.py new file mode 100644 index 0000000..6b4a3d8 --- /dev/null +++ b/services/lake_publisher/iceberg.py @@ -0,0 +1,420 @@ +"""Iceberg table creation and metadata management for analytical datasets. 
+ +Manages Iceberg tables in Trino's Iceberg catalog, providing: +- Table creation with proper schemas and partition specs +- Schema synchronization between PyArrow definitions and Iceberg tables +- Table metadata inspection (existence checks, schema retrieval, partition listing) + +The Iceberg catalog complements the existing Hive-compatible partition layout. +Parquet files written by the lake publisher are stored in the same MinIO paths, +but Iceberg metadata enables schema evolution, snapshot isolation, and better +partition pruning via Trino's Iceberg connector. + +Requirements: 9.4, 9.5, 10.1, N4, N6 +Design ref: Section 5.3 (Lakehouse model), Section 4.12 (SQL Query Engine) +""" +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any + +import pyarrow as pa +from trino.dbapi import connect as trino_connect + +from services.lake_publisher.partitions import ( + LAKEHOUSE_BUCKET, + TABLE_PARTITIONS, + WAREHOUSE_PREFIX, + PartitionSpec, +) +from services.lake_publisher.worker import ( + COMPANY_EVENTS_SCHEMA, + DOCUMENTS_SCHEMA, + DOCUMENT_EXTRACTIONS_SCHEMA, + MARKET_BARS_SCHEMA, + MARKET_QUOTES_SCHEMA, + MODEL_PERFORMANCE_SCHEMA, + PNL_DAILY_SCHEMA, + POSITIONS_DAILY_SCHEMA, + PREDICTION_VS_OUTCOME_SCHEMA, + TRADE_FILLS_SCHEMA, + TRADE_ORDERS_SCHEMA, + TRADE_SIGNALS_SCHEMA, +) + +logger = logging.getLogger(__name__) + +ICEBERG_CATALOG = "iceberg" +ICEBERG_SCHEMA = "stonks" + + +def _get_iceberg_catalog() -> str: + """Return the Iceberg catalog name from env or default.""" + import os + return os.getenv("TRINO_ICEBERG_CATALOG", ICEBERG_CATALOG) + +# Map PyArrow types to Trino/Iceberg SQL types. 
_ARROW_TO_TRINO: dict[str, str] = {
    "string": "VARCHAR",
    "utf8": "VARCHAR",
    "large_string": "VARCHAR",
    "large_utf8": "VARCHAR",
    "float64": "DOUBLE",
    "double": "DOUBLE",
    "float32": "REAL",
    "float": "REAL",
    "int8": "TINYINT",
    "int16": "SMALLINT",
    "int32": "INTEGER",
    "int64": "BIGINT",
    "bool": "BOOLEAN",
    "date32": "DATE",
    "date32[day]": "DATE",
    "date64": "DATE",
}


def _arrow_type_to_trino(arrow_type: pa.DataType) -> str:
    """Convert a PyArrow data type to a Trino SQL type string.

    Resolution order:

    1. Any type whose string form starts with ``timestamp`` maps to
       ``TIMESTAMP(6)``, or ``TIMESTAMP(6) WITH TIME ZONE`` when the string
       form carries a ``tz=`` component.
    2. An exact match of the string form in ``_ARROW_TO_TRINO``.
    3. Category checks via ``pa.types`` for spellings the table does not
       list (remaining string, float, integer, boolean and date types).

    Args:
        arrow_type: The PyArrow type of a schema field.

    Returns:
        The Trino/Iceberg SQL type name to use in DDL.

    Raises:
        ValueError: If no rule above matches the type.
    """
    type_str = str(arrow_type)

    # Timestamps are decided purely from the string form, so this single
    # check is the only timestamp rule. A later is_timestamp() fallback could
    # never be reached and would contradict the naive/aware split made here.
    if type_str.startswith("timestamp"):
        if "tz=" in type_str:
            return "TIMESTAMP(6) WITH TIME ZONE"
        return "TIMESTAMP(6)"

    # Direct lookup by string form.
    mapped = _ARROW_TO_TRINO.get(type_str)
    if mapped is not None:
        return mapped

    # Category fallbacks: widen to the largest Trino type of each family.
    if pa.types.is_string(arrow_type) or pa.types.is_large_string(arrow_type):
        return "VARCHAR"
    if pa.types.is_floating(arrow_type):
        return "DOUBLE"
    if pa.types.is_integer(arrow_type):
        return "BIGINT"
    if pa.types.is_boolean(arrow_type):
        return "BOOLEAN"
    if pa.types.is_date(arrow_type):
        return "DATE"

    raise ValueError(f"Unsupported PyArrow type for Iceberg DDL: {arrow_type}")



# Registry mapping table names to their PyArrow schemas.
TABLE_SCHEMAS: dict[str, pa.Schema] = {
    "market_bars": MARKET_BARS_SCHEMA,
    "market_quotes": MARKET_QUOTES_SCHEMA,
    "company_events": COMPANY_EVENTS_SCHEMA,
    "documents": DOCUMENTS_SCHEMA,
    "document_extractions": DOCUMENT_EXTRACTIONS_SCHEMA,
    "trade_signals": TRADE_SIGNALS_SCHEMA,
    "trade_orders": TRADE_ORDERS_SCHEMA,
    "trade_fills": TRADE_FILLS_SCHEMA,
    "positions_daily": POSITIONS_DAILY_SCHEMA,
    "pnl_daily": PNL_DAILY_SCHEMA,
    "prediction_vs_outcome": PREDICTION_VS_OUTCOME_SCHEMA,
    "model_performance": MODEL_PERFORMANCE_SCHEMA,
}


@dataclass(frozen=True)
class IcebergTableDef:
    """Definition for an Iceberg table derived from PyArrow schema + partition spec.

    Bundles everything needed to render the table's DDL: its name, the
    PyArrow schema the columns come from, and the partition spec whose
    ``all_keys`` become the Iceberg partitioning columns.
    """

    table_name: str
    schema: pa.Schema
    partition_spec: PartitionSpec

    @property
    def qualified_name(self) -> str:
        """Return ``catalog.schema.table`` built from the module-level defaults."""
        # NOTE(review): this uses the ICEBERG_CATALOG / ICEBERG_SCHEMA
        # constants, not a manager's configured catalog/schema — confirm that
        # is intended when a non-default catalog is in use.
        return f"{ICEBERG_CATALOG}.{ICEBERG_SCHEMA}.{self.table_name}"

    @property
    def location(self) -> str:
        """Return the ``s3a://`` warehouse location for this table's files."""
        return f"s3a://{LAKEHOUSE_BUCKET}/{WAREHOUSE_PREFIX}/{self.table_name}/"

    def column_defs_sql(self) -> list[str]:
        """Generate SQL column definitions from the PyArrow schema.

        Partition columns are included in the column list (Iceberg stores them
        in the data files, unlike Hive external tables).

        Returns:
            One ``<name> <trino type>`` fragment per schema field, in schema
            order.

        Raises:
            ValueError: If a field's type has no Trino mapping.
        """
        # Walk names and types together instead of indexing field-by-field.
        return [
            f" {name} {_arrow_type_to_trino(arrow_type)}"
            for name, arrow_type in zip(self.schema.names, self.schema.types)
        ]

    def partition_keys_sql(self) -> str:
        """Generate the partitioning clause for CREATE TABLE.

        Returns:
            ``partitioning = ARRAY['k1', 'k2', ...]`` for the spec's keys, or
            an empty string when the spec has no keys.
        """
        keys = list(self.partition_spec.all_keys)
        if not keys:
            return ""
        quoted = ", ".join(f"'{k}'" for k in keys)
        return f"partitioning = ARRAY[{quoted}]"

    def create_table_sql(self) -> str:
        """Generate a CREATE TABLE IF NOT EXISTS statement for Trino's Iceberg catalog."""
        col_lines = ",\n".join(self.column_defs_sql())
        with_clauses = [
            "format = 'PARQUET'",
            f"location = '{self.location}'",
        ]
        # The partitioning property is only emitted when there are keys.
        part_sql = self.partition_keys_sql()
        if part_sql:
            with_clauses.append(part_sql)

        with_block = ",\n ".join(with_clauses)

        return (
            f"CREATE TABLE IF NOT EXISTS {self.qualified_name} (\n"
            f"{col_lines}\n"
            f") WITH (\n"
            f" {with_block}\n"
            f")"
        )


def get_all_table_defs() -> list[IcebergTableDef]:
    """Build IcebergTableDef for every registered analytical table.

    Iterates ``TABLE_PARTITIONS``; a table with no entry in ``TABLE_SCHEMAS``
    is logged at warning level and skipped rather than raising.
    """
    defs: list[IcebergTableDef] = []
    for table_name, partition_spec in TABLE_PARTITIONS.items():
        schema = TABLE_SCHEMAS.get(table_name)
        if schema is None:
            logger.warning("No PyArrow schema for table %s, skipping", table_name)
            continue
        defs.append(
            IcebergTableDef(
                table_name=table_name,
                schema=schema,
                partition_spec=partition_spec,
            )
        )
    return defs


def get_table_def(table_name: str) -> IcebergTableDef:
    """Get the IcebergTableDef for a single table by name.

    Raises:
        ValueError: If the table has no partition spec in ``TABLE_PARTITIONS``
            or no schema in ``TABLE_SCHEMAS``.
    """
    if table_name not in TABLE_PARTITIONS:
        raise ValueError(f"Unknown table: {table_name}")
    schema = TABLE_SCHEMAS.get(table_name)
    if schema is None:
        raise ValueError(f"No PyArrow schema registered for table: {table_name}")
    return IcebergTableDef(
        table_name=table_name,
        schema=schema,
        partition_spec=TABLE_PARTITIONS[table_name],
    )



+@dataclass +class IcebergManager: + """Manages Iceberg tables via Trino's Iceberg catalog. + + Provides table creation, existence checks, schema inspection, + and metadata operations against the Trino Iceberg connector. + """ + + host: str = "localhost" + port: int = 8080 + user: str = "stonks" + catalog: str = ICEBERG_CATALOG + schema: str = ICEBERG_SCHEMA + + def _get_connection(self) -> Any: + """Create a Trino DBAPI connection.""" + return trino_connect( + host=self.host, + port=self.port, + user=self.user, + catalog=self.catalog, + schema=self.schema, + ) + + def _execute(self, sql: str) -> list[list[Any]]: + """Execute a SQL statement and return all rows.""" + conn = self._get_connection() + try: + cursor = conn.cursor() + cursor.execute(sql) + return cursor.fetchall() + finally: + conn.close() + + def _execute_no_fetch(self, sql: str) -> None: + """Execute a DDL statement that returns no rows.""" + conn = self._get_connection() + try: + cursor = conn.cursor() + cursor.execute(sql) + # DDL statements in Trino still need fetchall to complete + try: + cursor.fetchall() + except Exception: + pass + finally: + conn.close() + + def ensure_schema(self) -> None: + """Create the Iceberg schema if it doesn't exist.""" + sql = f"CREATE SCHEMA IF NOT EXISTS {self.catalog}.{self.schema}" + logger.info("Ensuring Iceberg schema: %s.%s", self.catalog, self.schema) + self._execute_no_fetch(sql) + + def table_exists(self, table_name: str) -> bool: + """Check if an Iceberg table exists.""" + sql = ( + f"SELECT table_name FROM {self.catalog}.information_schema.tables " + f"WHERE table_schema = '{self.schema}' AND table_name = '{table_name}'" + ) + rows = self._execute(sql) + return len(rows) > 0 + + def create_table(self, table_name: str) -> bool: + """Create a single Iceberg table if it doesn't exist. + + Returns True if the table was created, False if it already existed. 
+ """ + table_def = get_table_def(table_name) + ddl = table_def.create_table_sql() + logger.info("Creating Iceberg table: %s", table_def.qualified_name) + self._execute_no_fetch(ddl) + logger.info("Iceberg table ready: %s", table_def.qualified_name) + return True + + def create_all_tables(self) -> dict[str, bool]: + """Create all registered Iceberg tables. + + Returns a dict mapping table_name -> True (created) or False (error). + """ + self.ensure_schema() + results: dict[str, bool] = {} + for table_def in get_all_table_defs(): + try: + self.create_table(table_def.table_name) + results[table_def.table_name] = True + except Exception: + logger.exception("Failed to create Iceberg table: %s", table_def.table_name) + results[table_def.table_name] = False + return results + + def get_table_schema(self, table_name: str) -> list[dict[str, str]]: + """Retrieve the column schema of an Iceberg table from Trino. + + Returns a list of dicts with 'column_name', 'data_type', and 'is_nullable'. + """ + sql = ( + f"SELECT column_name, data_type, is_nullable " + f"FROM {self.catalog}.information_schema.columns " + f"WHERE table_schema = '{self.schema}' AND table_name = '{table_name}' " + f"ORDER BY ordinal_position" + ) + rows = self._execute(sql) + return [ + {"column_name": r[0], "data_type": r[1], "is_nullable": r[2]} + for r in rows + ] + + def get_table_snapshots(self, table_name: str) -> list[dict[str, Any]]: + """List Iceberg snapshots for a table (useful for auditing and rollback). + + Returns snapshot metadata from Trino's $snapshots metadata table. 
+ """ + qualified = f"{self.catalog}.{self.schema}.{table_name}" + sql = f'SELECT * FROM "{qualified}$snapshots"' + try: + rows = self._execute(sql) + return [{"snapshot_id": r[0], "parent_id": r[1], "operation": r[2], + "manifest_list": r[3], "summary": r[4]} for r in rows] + except Exception: + logger.debug("Could not read snapshots for %s (table may be empty)", table_name) + return [] + + def get_table_partitions(self, table_name: str) -> list[dict[str, Any]]: + """List partition values for an Iceberg table. + + Returns partition metadata from Trino's $partitions metadata table. + """ + qualified = f"{self.catalog}.{self.schema}.{table_name}" + sql = f'SELECT * FROM "{qualified}$partitions"' + try: + rows = self._execute(sql) + return [{"row": r} for r in rows] + except Exception: + logger.debug("Could not read partitions for %s (table may be empty)", table_name) + return [] + + def list_tables(self) -> list[str]: + """List all tables in the Iceberg schema.""" + sql = ( + f"SELECT table_name FROM {self.catalog}.information_schema.tables " + f"WHERE table_schema = '{self.schema}' ORDER BY table_name" + ) + rows = self._execute(sql) + return [r[0] for r in rows] + + def drop_table(self, table_name: str) -> None: + """Drop an Iceberg table (for testing/reset purposes).""" + qualified = f"{self.catalog}.{self.schema}.{table_name}" + logger.warning("Dropping Iceberg table: %s", qualified) + self._execute_no_fetch(f"DROP TABLE IF EXISTS {qualified}") + + def sync_table_schema(self, table_name: str) -> list[str]: + """Compare the expected PyArrow schema with the actual Iceberg table schema. + + If columns are missing from the Iceberg table, adds them via ALTER TABLE. + Returns a list of columns that were added. + + This supports forward-only schema evolution — columns are never dropped. 
+ """ + table_def = get_table_def(table_name) + existing = self.get_table_schema(table_name) + existing_names = {col["column_name"] for col in existing} + + added: list[str] = [] + qualified = table_def.qualified_name + + for i in range(len(table_def.schema)): + col_name = table_def.schema.field(i).name + if col_name not in existing_names: + trino_type = _arrow_type_to_trino(table_def.schema.field(i).type) + alter_sql = f"ALTER TABLE {qualified} ADD COLUMN {col_name} {trino_type}" + logger.info("Adding column %s to %s", col_name, qualified) + self._execute_no_fetch(alter_sql) + added.append(col_name) + + return added + + def sync_all_schemas(self) -> dict[str, list[str]]: + """Sync schemas for all registered tables. Returns table_name -> added columns.""" + results: dict[str, list[str]] = {} + for table_def in get_all_table_defs(): + try: + if self.table_exists(table_def.table_name): + added = self.sync_table_schema(table_def.table_name) + results[table_def.table_name] = added + else: + logger.info("Table %s doesn't exist yet, skipping sync", table_def.table_name) + results[table_def.table_name] = [] + except Exception: + logger.exception("Failed to sync schema for %s", table_def.table_name) + results[table_def.table_name] = [] + return results + + +def create_iceberg_manager_from_config( + host: str = "localhost", + port: int = 8080, + user: str = "stonks", +) -> IcebergManager: + """Factory that creates an IcebergManager from explicit connection params.""" + return IcebergManager(host=host, port=port, user=user) diff --git a/services/lake_publisher/jobs.py b/services/lake_publisher/jobs.py new file mode 100644 index 0000000..5c1c6c0 --- /dev/null +++ b/services/lake_publisher/jobs.py @@ -0,0 +1,673 @@ +"""Lake publisher async job runner — transforms operational data into analytical facts. 
+ +Reads jobs from the QUEUE_LAKE_PUBLISH Redis queue, queries PostgreSQL for +operational records, and publishes them as partitioned Parquet files to MinIO +via the existing publish_* functions in worker.py. + +Job message format: + {"job_type": "<table_name>", "entity_id": "<uuid or ticker>", "dt": "2026-04-11T..."} + +Supported job types: + - document: publish a single document metadata fact + - document_extraction: publish extraction facts for a document + - market_snapshot: publish market bars/quotes from a snapshot + - trade_order: publish an order fact + - trade_fill: publish fill facts for an order + - positions_snapshot: publish daily position snapshots for a broker account + - pnl_snapshot: publish daily PnL for a broker account + - company_event: publish a company event fact + - bulk_documents: publish all unpublished documents since a cutoff + - bulk_extractions: publish all unpublished extractions since a cutoff + +Requirements: 9.4, 9.5, 10.1 +Design ref: Section 4.10 (Lake Publisher), Section 8.4 (Lake publication flow) +""" +from __future__ import annotations + +import asyncio +import json +import logging +from datetime import datetime, timezone + +import asyncpg +import redis.asyncio as aioredis +from minio import Minio + +from services.lake_publisher.worker import ( + publish_document_extraction, + publish_document_fact, + publish_market_bar, + publish_market_quote, + publish_trade_order, + publish_trade_fill, + publish_pnl_daily, + publish_documents_batch, + publish_document_extractions_batch, + publish_positions_daily_batch, +) +from services.lake_publisher.partitions import partition_values +from services.shared.config import load_config +from services.shared.db import get_minio, get_pg_pool, get_redis +from services.shared.logging import setup_logging +from services.shared.redis_keys import QUEUE_LAKE_PUBLISH, queue_key + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# 
SQL queries for fetching operational data +# --------------------------------------------------------------------------- + +_FETCH_DOCUMENT = """ +SELECT + d.id, d.document_type, d.source_type, d.publisher, d.title, + d.url, d.canonical_url, d.language, d.published_at, d.retrieved_at, + d.content_hash, d.parse_quality_score, + COALESCE( + (SELECT dcm.ticker FROM document_company_mentions dcm + WHERE dcm.document_id = d.id LIMIT 1), + '' + ) AS ticker +FROM documents d +WHERE d.id = $1::uuid +""" + +_FETCH_EXTRACTIONS = """ +SELECT + di.document_id, dir.ticker, dir.relevance, dir.sentiment, + dir.impact_score, dir.impact_horizon, dir.catalyst_type, + di.confidence, di.novelty_score, di.source_credibility, + dir.key_facts, dir.risks, di.macro_themes, + di.model_name, di.prompt_version, di.schema_version, + di.created_at AS extraction_at, + COALESCE(c.legal_name, '') AS company_name +FROM document_intelligence di +JOIN document_impact_records dir ON dir.intelligence_id = di.id +LEFT JOIN companies c ON c.id = dir.company_id +WHERE di.document_id = $1::uuid + AND di.validation_status = 'valid' +""" + +_FETCH_MARKET_SNAPSHOT = """ +SELECT + ms.ticker, ms.snapshot_type, ms.data, ms.source_provider, ms.captured_at +FROM market_snapshots ms +WHERE ms.id = $1::uuid +""" + +_FETCH_ORDER = """ +SELECT + o.id, o.recommendation_id, o.ticker, o.side, o.order_type, + o.quantity, o.limit_price, o.status, o.submitted_at, + o.fill_price, o.fill_quantity, o.filled_at, + COALESCE(ba.account_id, '') AS broker_account, + COALESCE(ba.mode, 'paper') AS execution_mode +FROM orders o +LEFT JOIN broker_accounts ba ON ba.id = o.broker_account_id +WHERE o.id = $1::uuid +""" + +_FETCH_ORDER_FILLS = """ +SELECT + oe.id AS fill_id, oe.order_id, oe.data, oe.broker_timestamp, + o.ticker, o.side, + COALESCE(ba.account_id, '') AS broker_account +FROM order_events oe +JOIN orders o ON o.id = oe.order_id +LEFT JOIN broker_accounts ba ON ba.id = o.broker_account_id +WHERE oe.order_id = $1::uuid AND 
oe.event_type = 'fill' +""" + +_FETCH_POSITIONS = """ +SELECT + p.ticker, p.quantity, p.avg_entry_price, p.current_price, + p.unrealized_pnl, p.realized_pnl, + COALESCE(ba.account_id, '') AS broker_account, + COALESCE(ba.mode, 'paper') AS execution_mode +FROM positions p +LEFT JOIN broker_accounts ba ON ba.id = p.broker_account_id +WHERE p.broker_account_id = $1::uuid AND p.quantity != 0 +""" + +_FETCH_BULK_DOCUMENTS = """ +SELECT + d.id, d.document_type, d.source_type, d.publisher, d.title, + d.url, d.canonical_url, d.language, d.published_at, d.retrieved_at, + d.content_hash, d.parse_quality_score, + COALESCE( + (SELECT dcm.ticker FROM document_company_mentions dcm + WHERE dcm.document_id = d.id LIMIT 1), + '' + ) AS ticker +FROM documents d +WHERE d.created_at >= $1 + AND d.status IN ('parsed', 'extracted') +ORDER BY d.created_at +LIMIT 500 +""" + +_FETCH_BULK_EXTRACTIONS = """ +SELECT + di.document_id, dir.ticker, dir.relevance, dir.sentiment, + dir.impact_score, dir.impact_horizon, dir.catalyst_type, + di.confidence, di.novelty_score, di.source_credibility, + dir.key_facts, dir.risks, di.macro_themes, + di.model_name, di.prompt_version, di.schema_version, + di.created_at AS extraction_at, + COALESCE(c.legal_name, '') AS company_name +FROM document_intelligence di +JOIN document_impact_records dir ON dir.intelligence_id = di.id +LEFT JOIN companies c ON c.id = dir.company_id +WHERE di.created_at >= $1 + AND di.validation_status = 'valid' +ORDER BY di.created_at +LIMIT 500 +""" + + +# --------------------------------------------------------------------------- +# Job handlers — each transforms operational rows into lake facts +# --------------------------------------------------------------------------- + + +def _jsonb_to_str(val: object) -> str: + """Convert a JSONB column value (list or str) to a comma-separated string.""" + if val is None: + return "" + if isinstance(val, str): + try: + parsed = json.loads(val) + if isinstance(parsed, list): + return ", 
".join(str(x) for x in parsed) + return val + except (json.JSONDecodeError, TypeError): + return val + if isinstance(val, list): + return ", ".join(str(x) for x in val) + return str(val) + + +async def publish_document_job( + pool: asyncpg.Pool, + minio_client: Minio, + entity_id: str, +) -> str: + """Publish a single document metadata fact from PostgreSQL to the lake.""" + row = await pool.fetchrow(_FETCH_DOCUMENT, entity_id) + if row is None: + logger.warning("Document %s not found, skipping lake publish", entity_id) + return "" + + published_at = row["published_at"] or row["retrieved_at"] + return publish_document_fact( + client=minio_client, + document_id=str(row["id"]), + document_type=row["document_type"], + source_type=row["source_type"], + ticker=row["ticker"] or "", + publisher=row["publisher"] or "", + title=row["title"] or "", + published_at=published_at, + content_hash=row["content_hash"], + url=row["url"] or "", + canonical_url=row["canonical_url"] or "", + language=row["language"] or "en", + confidence=float(row["parse_quality_score"] or 0.0), + retrieved_at=row["retrieved_at"], + ) + + +async def publish_extraction_job( + pool: asyncpg.Pool, + minio_client: Minio, + entity_id: str, +) -> list[str]: + """Publish document extraction facts for a document from PostgreSQL to the lake.""" + rows = await pool.fetch(_FETCH_EXTRACTIONS, entity_id) + if not rows: + logger.info("No valid extractions for document %s", entity_id) + return [] + + refs: list[str] = [] + for row in rows: + ref = publish_document_extraction( + client=minio_client, + document_id=str(row["document_id"]), + ticker=row["ticker"], + sentiment=row["sentiment"] or "neutral", + impact_score=float(row["impact_score"] or 0.0), + catalyst_type=row["catalyst_type"] or "other", + confidence=float(row["confidence"] or 0.0), + extraction_at=row["extraction_at"], + model_name=row["model_name"] or "", + prompt_version=row["prompt_version"] or "", + company_name=row["company_name"] or "", + 
relevance=float(row["relevance"] or 0.0), + impact_horizon=row["impact_horizon"] or "", + novelty_score=float(row["novelty_score"] or 0.0), + source_credibility=float(row["source_credibility"] or 0.0), + key_facts=_jsonb_to_str(row["key_facts"]), + risks=_jsonb_to_str(row["risks"]), + macro_themes=_jsonb_to_str(row["macro_themes"]), + schema_version=row["schema_version"] or "", + ) + refs.append(ref) + return refs + + +async def publish_market_snapshot_job( + pool: asyncpg.Pool, + minio_client: Minio, + entity_id: str, +) -> list[str]: + """Publish market bar/quote facts from a market_snapshots row.""" + row = await pool.fetchrow(_FETCH_MARKET_SNAPSHOT, entity_id) + if row is None: + logger.warning("Market snapshot %s not found", entity_id) + return [] + + ticker = row["ticker"] + data = row["data"] if isinstance(row["data"], dict) else json.loads(row["data"]) + source = row["source_provider"] or "" + captured_at = row["captured_at"] + snapshot_type = row["snapshot_type"] + refs: list[str] = [] + + if snapshot_type == "bar" or snapshot_type == "bars": + # Single bar or list of bars + bars = data.get("bars", [data]) if "bars" in data else [data] + for bar in bars: + ref = publish_market_bar( + client=minio_client, + ticker=ticker, + open_price=float(bar.get("open", bar.get("o", 0))), + high_price=float(bar.get("high", bar.get("h", 0))), + low_price=float(bar.get("low", bar.get("l", 0))), + close_price=float(bar.get("close", bar.get("c", 0))), + volume=int(bar.get("volume", bar.get("v", 0))), + bar_timestamp=captured_at, + source=source, + vwap=float(bar.get("vwap", bar.get("vw", 0))), + trade_count=int(bar.get("trade_count", bar.get("n", 0))), + bar_interval=bar.get("interval", "1d"), + ) + refs.append(ref) + elif snapshot_type == "quote" or snapshot_type == "quotes": + ref = publish_market_quote( + client=minio_client, + ticker=ticker, + bid_price=float(data.get("bid_price", data.get("bp", 0))), + ask_price=float(data.get("ask_price", data.get("ap", 0))), + 
last_price=float(data.get("last_price", data.get("lp", 0))), + quote_at=captured_at, + source=source, + bid_size=int(data.get("bid_size", data.get("bs", 0))), + ask_size=int(data.get("ask_size", data.get("as", 0))), + last_size=int(data.get("last_size", data.get("ls", 0))), + ) + refs.append(ref) + + return refs + + +async def publish_order_job( + pool: asyncpg.Pool, + minio_client: Minio, + entity_id: str, +) -> str: + """Publish a trade order fact from PostgreSQL to the lake.""" + row = await pool.fetchrow(_FETCH_ORDER, entity_id) + if row is None: + logger.warning("Order %s not found", entity_id) + return "" + + submitted_at = row["submitted_at"] or datetime.now(timezone.utc) + return publish_trade_order( + client=minio_client, + order_id=str(row["id"]), + ticker=row["ticker"], + side=row["side"], + order_type=row["order_type"], + quantity=float(row["quantity"]), + limit_price=float(row["limit_price"]) if row["limit_price"] else None, + status=row["status"], + broker_account=row["broker_account"], + submitted_at=submitted_at, + recommendation_id=str(row["recommendation_id"]) if row["recommendation_id"] else "", + execution_mode=row["execution_mode"], + ) + + +async def publish_fills_job( + pool: asyncpg.Pool, + minio_client: Minio, + entity_id: str, +) -> list[str]: + """Publish trade fill facts for an order from PostgreSQL to the lake.""" + rows = await pool.fetch(_FETCH_ORDER_FILLS, entity_id) + if not rows: + logger.info("No fill events for order %s", entity_id) + return [] + + refs: list[str] = [] + for row in rows: + data = row["data"] if isinstance(row["data"], dict) else json.loads(row["data"] or "{}") + filled_at = row["broker_timestamp"] or datetime.now(timezone.utc) + ref = publish_trade_fill( + client=minio_client, + fill_id=str(row["fill_id"]), + order_id=str(row["order_id"]), + ticker=row["ticker"], + side=row["side"], + fill_price=float(data.get("fill_price", data.get("price", 0))), + fill_quantity=float(data.get("fill_quantity", data.get("qty", 
0))), + broker_account=row["broker_account"], + filled_at=filled_at, + commission=float(data.get("commission", 0)), + ) + refs.append(ref) + return refs + + +async def publish_positions_job( + pool: asyncpg.Pool, + minio_client: Minio, + entity_id: str, +) -> str: + """Publish daily position snapshots for a broker account.""" + rows = await pool.fetch(_FETCH_POSITIONS, entity_id) + if not rows: + logger.info("No open positions for account %s", entity_id) + return "" + + snapshot_at = datetime.now(timezone.utc) + positions = [ + { + "ticker": row["ticker"], + "quantity": float(row["quantity"]), + "avg_entry_price": float(row["avg_entry_price"] or 0), + "close_price": float(row["current_price"] or 0), + "unrealized_pnl": float(row["unrealized_pnl"] or 0), + } + for row in rows + ] + broker_account = rows[0]["broker_account"] if rows else "" + return publish_positions_daily_batch( + client=minio_client, + positions=positions, + broker_account=broker_account, + snapshot_at=snapshot_at, + ) + + +async def publish_pnl_job( + pool: asyncpg.Pool, + minio_client: Minio, + entity_id: str, +) -> list[str]: + """Publish daily PnL facts for a broker account's positions.""" + rows = await pool.fetch(_FETCH_POSITIONS, entity_id) + if not rows: + logger.info("No positions for PnL snapshot, account %s", entity_id) + return [] + + now = datetime.now(timezone.utc) + refs: list[str] = [] + for row in rows: + realized = float(row["realized_pnl"] or 0) + unrealized = float(row["unrealized_pnl"] or 0) + total = realized + unrealized + ref = publish_pnl_daily( + client=minio_client, + ticker=row["ticker"], + realized_pnl=realized, + unrealized_pnl=unrealized, + total_pnl=total, + broker_account=row["broker_account"], + dt=now, + execution_mode=row["execution_mode"], + ) + refs.append(ref) + return refs + + +async def publish_bulk_documents_job( + pool: asyncpg.Pool, + minio_client: Minio, + since: datetime, +) -> list[str]: + """Publish all documents created since a cutoff as a batch.""" 
async def publish_bulk_extractions_job(
    pool: asyncpg.Pool,
    minio_client: Minio,
    since: datetime,
) -> list[str]:
    """Publish all extractions created since a cutoff, one batch per model version.

    Each row's ``model_version`` is ``schema_version``, falling back to
    ``prompt_version`` and then ``""``. Rows are grouped by that value before
    publishing so that every batch is written with the ``model_version`` its
    rows actually carry. Previously the whole result set was published under
    the first row's version, which mislabelled rows of any other version.

    Args:
        pool: PostgreSQL connection pool used to run ``_FETCH_BULK_EXTRACTIONS``.
        minio_client: MinIO client handed to the batch publisher.
        since: Cutoff passed to the query and to the batch publisher.

    Returns:
        The non-empty references returned by the batch publisher, in order of
        first appearance of each model version. Empty when nothing matched.
    """
    rows = await pool.fetch(_FETCH_BULK_EXTRACTIONS, since)
    if not rows:
        logger.info("No extractions to bulk-publish since %s", since)
        return []

    # dict preserves insertion order, so batches are emitted deterministically.
    rows_by_version: dict[str, list[dict[str, object]]] = {}
    for row in rows:
        model_ver = str(row["schema_version"] or row["prompt_version"] or "")
        rows_by_version.setdefault(model_ver, []).append({
            "document_id": str(row["document_id"]),
            "ticker": row["ticker"],
            "company_name": row["company_name"] or "",
            "relevance": float(row["relevance"] or 0.0),
            "sentiment": row["sentiment"] or "neutral",
            "impact_score": float(row["impact_score"] or 0.0),
            "impact_horizon": row["impact_horizon"] or "",
            "catalyst_type": row["catalyst_type"] or "other",
            "confidence": float(row["confidence"] or 0.0),
            "novelty_score": float(row["novelty_score"] or 0.0),
            "source_credibility": float(row["source_credibility"] or 0.0),
            "key_facts": _jsonb_to_str(row["key_facts"]),
            "risks": _jsonb_to_str(row["risks"]),
            "macro_themes": _jsonb_to_str(row["macro_themes"]),
            "model_name": row["model_name"] or "",
            "prompt_version": row["prompt_version"] or "",
            "schema_version": row["schema_version"] or "",
            "extraction_at": row["extraction_at"],
            **partition_values(row["extraction_at"], {"model_version": model_ver}),
        })

    refs: list[str] = []
    for model_ver, extraction_rows in rows_by_version.items():
        ref = publish_document_extractions_batch(
            minio_client, extraction_rows, since,
            model_version=model_ver,
        )
        if ref:
            refs.append(ref)
    return refs
+ """ + job_type = job.get("job_type", "") + entity_id = job.get("entity_id", "") + since_str = job.get("since") + + result: dict[str, object] = { + "job_type": job_type, + "entity_id": entity_id, + "refs": [], + "error": None, + } + + try: + if job_type == "document": + ref = await publish_document_job(pool, minio_client, entity_id) + result["refs"] = [ref] if ref else [] + + elif job_type == "document_extraction": + refs = await publish_extraction_job(pool, minio_client, entity_id) + result["refs"] = refs + + elif job_type == "market_snapshot": + refs = await publish_market_snapshot_job(pool, minio_client, entity_id) + result["refs"] = refs + + elif job_type == "trade_order": + ref = await publish_order_job(pool, minio_client, entity_id) + result["refs"] = [ref] if ref else [] + + elif job_type == "trade_fill": + refs = await publish_fills_job(pool, minio_client, entity_id) + result["refs"] = refs + + elif job_type == "positions_snapshot": + ref = await publish_positions_job(pool, minio_client, entity_id) + result["refs"] = [ref] if ref else [] + + elif job_type == "pnl_snapshot": + refs = await publish_pnl_job(pool, minio_client, entity_id) + result["refs"] = refs + + elif job_type == "bulk_documents": + since = datetime.fromisoformat(since_str) if since_str else datetime.now(timezone.utc) + refs = await publish_bulk_documents_job(pool, minio_client, since) + result["refs"] = refs + + elif job_type == "bulk_extractions": + since = datetime.fromisoformat(since_str) if since_str else datetime.now(timezone.utc) + refs = await publish_bulk_extractions_job(pool, minio_client, since) + result["refs"] = refs + + else: + result["error"] = f"Unknown job_type: {job_type}" + logger.warning("Unknown lake publish job type: %s", job_type) + + except Exception as exc: + result["error"] = str(exc) + logger.exception("Lake publish job failed: %s/%s", job_type, entity_id) + + return result + + +# --------------------------------------------------------------------------- +# Async 
async def run_worker(
    pool: asyncpg.Pool,
    rds: aioredis.Redis,
    minio_client: Minio,
    poll_interval: float = 2.0,
) -> None:
    """Main worker loop — reads jobs from Redis and dispatches them.

    Runs indefinitely until cancelled. Each job is processed sequentially
    to keep MinIO write ordering predictable.

    Args:
        pool: PostgreSQL connection pool passed through to ``dispatch_job``.
        rds: Redis client the job queue is popped from.
        minio_client: MinIO client passed through to ``dispatch_job``.
        poll_interval: Seconds to sleep when the queue is empty.

    Malformed payloads (undecodable bytes, invalid JSON, or JSON that is not
    an object) are logged and dropped; they never stop the loop.
    """
    queue = queue_key(QUEUE_LAKE_PUBLISH)
    logger.info("Lake publisher worker started, listening on %s", queue)

    while True:
        raw = await rds.lpop(queue)  # type: ignore[misc]
        if raw is None:
            await asyncio.sleep(poll_interval)
            continue

        try:
            # json.loads accepts str, bytes and bytearray. Wrapping a bytes
            # payload in str() would yield "b'...'" and never parse.
            job = json.loads(raw)
        except (ValueError, TypeError):
            # ValueError covers JSONDecodeError and UnicodeDecodeError.
            logger.error("Invalid lake publish job payload: %s", raw)
            continue

        if not isinstance(job, dict):
            # dispatch_job calls job.get(...); a list/scalar would raise there.
            logger.error("Invalid lake publish job payload: %s", raw)
            continue

        result = await dispatch_job(pool, minio_client, job)
        refs = result.get("refs") or []
        error = result.get("error")

        if error:
            logger.error(
                "Lake publish job %s/%s failed: %s",
                result["job_type"], result["entity_id"], error,
            )
        else:
            ref_count = len(refs) if isinstance(refs, list) else 0
            logger.info(
                "Lake publish job %s/%s completed: %d facts written",
                result["job_type"], result["entity_id"], ref_count,
            )
@dataclass(frozen=True)
class PartitionSpec:
    """Partition layout of one fact table: its name plus any keys beyond ``dt``."""

    # Logical fact table name.
    table_name: str
    # Partition keys that follow ``dt``; empty when the table has none.
    extra_keys: tuple[str, ...] = ()

    @property
    def all_keys(self) -> tuple[str, ...]:
        """Every partition key in order: ``dt`` leads, the extras follow."""
        return ("dt",) + tuple(self.extra_keys)
def partition_path(
    table_name: str,
    dt: datetime | date,
    extra_partitions: dict[str, str] | None = None,
    file_id: str | None = None,
) -> str:
    """Build a Hive-compatible object path for a Parquet file.

    Args:
        table_name: Logical fact table name (must be in TABLE_PARTITIONS).
        dt: Row timestamp or date used to derive the ``dt=`` partition.
        extra_partitions: Additional partition key/value pairs (e.g. model_version).
            Keys not declared by the table's spec are ignored.
        file_id: Optional override for the file suffix (defaults to 16 hex
            characters of a UUID4).

    Returns:
        Object key relative to the bucket root, e.g.
        ``warehouse/trade_signals/dt=2026-04-11/part-<uuid>.parquet``

    Raises:
        ValueError: If ``table_name`` is not registered in TABLE_PARTITIONS.
    """
    spec = TABLE_PARTITIONS.get(table_name)
    if spec is None:
        raise ValueError(f"Unknown table: {table_name}. Register it in TABLE_PARTITIONS.")

    # datetime is a subclass of date: reduce it to its calendar date first so
    # isoformat() yields YYYY-MM-DD without a time component.
    day = dt.date() if isinstance(dt, datetime) else dt
    segments = [WAREHOUSE_PREFIX, table_name, f"dt={day.isoformat()}"]

    # Append extra partition directories in the order declared by the spec.
    extras = extra_partitions or {}
    for key in spec.extra_keys:
        value = extras.get(key)
        # An absent, None or empty value would otherwise produce a bare
        # "key=" directory; use the same sentinel as for a missing key.
        if value is None or value == "":
            value = "__NONE__"
        segments.append(f"{key}={value}")

    suffix = file_id or uuid.uuid4().hex[:16]
    segments.append(f"part-{suffix}.parquet")

    return "/".join(segments)
# Arrow schema for market_bars rows: one price/volume bar per ticker, with
# ``dt`` as the partition column.
MARKET_BARS_SCHEMA = pa.schema([
    pa.field(column, dtype)
    for column, dtype in (
        ("ticker", pa.string()),
        ("open_price", pa.float64()),
        ("high_price", pa.float64()),
        ("low_price", pa.float64()),
        ("close_price", pa.float64()),
        ("volume", pa.int64()),
        ("vwap", pa.float64()),
        ("trade_count", pa.int64()),
        ("bar_timestamp", pa.timestamp("us", tz="UTC")),
        ("bar_interval", pa.string()),
        ("source", pa.string()),
        ("dt", pa.date32()),
    )
])
def build_trade_signal_row(
    rec: Recommendation,
    trend_direction: str = "",
    trend_strength: float = 0.0,
    contradiction_score: float = 0.0,
    dominant_catalysts: str = "",
    material_risks: str = "",
) -> dict[str, object]:
    """Assemble one trade_signals row from a recommendation plus trend context.

    Identity, confidence, action, horizon and timestamp come from ``rec``;
    the trend fields are taken verbatim from the keyword arguments. The
    partition columns derived from ``rec.generated_at`` are appended last.
    """
    signal: dict[str, object] = {
        # The recommendation id doubles as the signal id.
        "signal_id": rec.recommendation_id,
        "ticker": rec.ticker,
        "trend_direction": trend_direction,
        "trend_strength": trend_strength,
        "confidence": rec.confidence,
        "contradiction_score": contradiction_score,
        "dominant_catalysts": dominant_catalysts,
        "material_risks": material_risks,
        "action": rec.action.value,
        "time_horizon": rec.time_horizon,
        "recommendation_id": rec.recommendation_id,
        "generated_at": rec.generated_at,
    }
    signal.update(partition_values(rec.generated_at))
    return signal
+ """ + return partition_path(table_name, dt, extra_partitions) + + +def publish_trade_signal( + client: Minio, + rec: Recommendation, + trend_direction: str = "", + trend_strength: float = 0.0, + contradiction_score: float = 0.0, + dominant_catalysts: str = "", + material_risks: str = "", +) -> str: + """Publish a single recommendation as a trade_signals fact to MinIO. + + Writes a Parquet file to the Hive-compatible partition layout: + s3://stonks-lakehouse/warehouse/trade_signals/dt={date}/part-{ts}.parquet + + Returns the s3:// URI of the written object. + """ + row = build_trade_signal_row( + rec, trend_direction, trend_strength, + contradiction_score, dominant_catalysts, material_risks, + ) + table = pa.Table.from_pylist([row], schema=TRADE_SIGNALS_SCHEMA) + parquet_bytes = _write_parquet_bytes(table) + + path = _partition_path("trade_signals", rec.generated_at) + _put_lakehouse_object(client, "trade_signals", path, parquet_bytes) + + ref = s3_uri(path) + logger.info("Published trade_signal fact for %s: %s", rec.ticker, ref) + return ref + + +# --- prediction_vs_outcome fact table (skeleton for Phase 10+) --- + +PREDICTION_VS_OUTCOME_SCHEMA = pa.schema([ + ("recommendation_id", pa.string()), + ("ticker", pa.string()), + ("predicted_action", pa.string()), + ("predicted_confidence", pa.float64()), + ("actual_move_pct", pa.float64()), + ("outcome", pa.string()), + ("horizon_days", pa.int32()), + ("predicted_at", pa.timestamp("us", tz="UTC")), + ("evaluated_at", pa.timestamp("us", tz="UTC")), + ("model_version", pa.string()), + ("dt", pa.date32()), +]) + + +def publish_prediction_fact( + client: Minio, + rec: Recommendation, + trend_direction: str = "", + trend_strength: float = 0.0, +) -> str: + """Publish a prediction fact for a recommendation. + + This writes the prediction side of the prediction_vs_outcome table. + The outcome fields (actual_move_pct, outcome, evaluated_at) are left + as placeholders — they get backfilled when market outcomes are known. 
+ + Returns the s3:// URI of the written Parquet file. + """ + # Parse horizon days from time_horizon string (e.g. "swing_1d_10d" -> 10) + horizon_days = _parse_horizon_days(rec.time_horizon) + + model_ver = getattr(rec.model_metadata, "model_name", "") if rec.model_metadata else "" + extra = {"model_version": model_ver} + + row = { + "recommendation_id": rec.recommendation_id, + "ticker": rec.ticker, + "predicted_action": rec.action.value, + "predicted_confidence": rec.confidence, + "actual_move_pct": None, + "outcome": "pending", + "horizon_days": horizon_days, + "predicted_at": rec.generated_at, + "evaluated_at": None, + **partition_values(rec.generated_at, extra), + } + + table = pa.Table.from_pylist([row], schema=PREDICTION_VS_OUTCOME_SCHEMA) + parquet_bytes = _write_parquet_bytes(table) + + path = _partition_path("prediction_vs_outcome", rec.generated_at, extra) + _put_lakehouse_object(client, "prediction_vs_outcome", path, parquet_bytes) + + ref = s3_uri(path) + logger.info("Published prediction_vs_outcome fact for %s: %s", rec.ticker, ref) + return ref + + +def _parse_horizon_days(time_horizon: str) -> int: + """Extract the max horizon days from a time_horizon string. + + Examples: + "swing_1d_10d" -> 10 + "position_10d_30d" -> 30 + "scalp_intraday" -> 1 + "" -> 0 + """ + if not time_horizon: + return 0 + if "intraday" in time_horizon: + return 1 + numbers = re.findall(r"(\d+)", time_horizon) + if numbers: + return max(int(n) for n in numbers) + return 0 + + +def publish_recommendation_facts( + client: Minio, + rec: Recommendation, + trend_direction: str = "", + trend_strength: float = 0.0, + contradiction_score: float = 0.0, + dominant_catalysts: str = "", + material_risks: str = "", +) -> dict[str, str]: + """Publish all analytical facts for a recommendation. + + Writes both trade_signals and prediction_vs_outcome facts. + Returns a dict mapping table name to s3:// URI. 
+ """ + refs: dict[str, str] = {} + + refs["trade_signals"] = publish_trade_signal( + client, rec, trend_direction, trend_strength, + contradiction_score, dominant_catalysts, material_risks, + ) + refs["prediction_vs_outcome"] = publish_prediction_fact( + client, rec, trend_direction, trend_strength, + ) + + return refs + + +# --- trade_orders fact table --- + +TRADE_ORDERS_SCHEMA = pa.schema([ + ("order_id", pa.string()), + ("recommendation_id", pa.string()), + ("ticker", pa.string()), + ("side", pa.string()), + ("order_type", pa.string()), + ("quantity", pa.float64()), + ("limit_price", pa.float64()), + ("status", pa.string()), + ("execution_mode", pa.string()), + ("broker_account", pa.string()), + ("submitted_at", pa.timestamp("us", tz="UTC")), + ("dt", pa.date32()), +]) + + +def build_trade_order_row( + order_id: str, + ticker: str, + side: str, + order_type: str, + quantity: float, + limit_price: float | None, + status: str, + broker_account: str, + submitted_at: datetime, + recommendation_id: str = "", + execution_mode: str = "paper", +) -> dict[str, object]: + """Build a single trade_orders fact row.""" + return { + "order_id": order_id, + "recommendation_id": recommendation_id, + "ticker": ticker, + "side": side, + "order_type": order_type, + "quantity": quantity, + "limit_price": limit_price, + "status": status, + "execution_mode": execution_mode, + "broker_account": broker_account, + "submitted_at": submitted_at, + **partition_values(submitted_at), + } + + +def publish_trade_order( + client: Minio, + order_id: str, + ticker: str, + side: str, + order_type: str, + quantity: float, + limit_price: float | None, + status: str, + broker_account: str, + submitted_at: datetime, + recommendation_id: str = "", + execution_mode: str = "paper", +) -> str: + """Publish a single order as a trade_orders fact to MinIO. + + Returns the s3:// URI of the written object. 
def build_trade_fill_row(
    fill_id: str,
    order_id: str,
    ticker: str,
    side: str,
    fill_price: float,
    fill_quantity: float,
    broker_account: str,
    filled_at: datetime,
    commission: float = 0.0,
) -> dict[str, object]:
    """Assemble one trade_fills row.

    All arguments are copied into the row unchanged; the partition columns
    derived from ``filled_at`` are appended last.
    """
    fill_row: dict[str, object] = dict(
        fill_id=fill_id,
        order_id=order_id,
        ticker=ticker,
        side=side,
        fill_price=fill_price,
        fill_quantity=fill_quantity,
        commission=commission,
        broker_account=broker_account,
        filled_at=filled_at,
    )
    fill_row.update(partition_values(filled_at))
    return fill_row
def build_position_daily_row(
    ticker: str,
    quantity: float,
    avg_entry_price: float,
    close_price: float,
    unrealized_pnl: float,
    broker_account: str,
    snapshot_at: datetime,
    market_value: float = 0.0,
    execution_mode: str = "paper",
) -> dict[str, object]:
    """Assemble one positions_daily row.

    All arguments are copied into the row unchanged; the partition columns
    derived from ``snapshot_at`` are appended last.
    """
    position_row: dict[str, object] = {}
    position_row["ticker"] = ticker
    position_row["quantity"] = quantity
    position_row["avg_entry_price"] = avg_entry_price
    position_row["close_price"] = close_price
    position_row["market_value"] = market_value
    position_row["unrealized_pnl"] = unrealized_pnl
    position_row["broker_account"] = broker_account
    position_row["execution_mode"] = execution_mode
    position_row["snapshot_at"] = snapshot_at
    position_row.update(partition_values(snapshot_at))
    return position_row
def publish_positions_daily_batch(
    client: Minio,
    positions: list[dict],
    broker_account: str,
    snapshot_at: datetime,
) -> str:
    """Publish a batch of position snapshots as a single Parquet file.

    Each dict in positions must have: ticker, quantity, avg_entry_price,
    close_price, unrealized_pnl. It may additionally carry ``market_value``
    and ``execution_mode``; when a key is absent or ``None`` the row
    builder's default is used, so existing callers are unaffected.

    Args:
        client: MinIO client used for the upload.
        positions: Per-ticker position dicts as described above.
        broker_account: Account label written on every row.
        snapshot_at: Snapshot timestamp; also selects the ``dt`` partition.

    Returns:
        The s3:// URI of the written object, or ``""`` when there is nothing
        to publish.
    """
    # Schema columns the batch format previously had no way to populate.
    optional_keys = ("market_value", "execution_mode")

    rows = []
    for p in positions:
        optional = {
            key: p[key] for key in optional_keys if p.get(key) is not None
        }
        rows.append(
            build_position_daily_row(
                ticker=p["ticker"],
                quantity=p["quantity"],
                avg_entry_price=p["avg_entry_price"],
                close_price=p["close_price"],
                unrealized_pnl=p["unrealized_pnl"],
                broker_account=broker_account,
                snapshot_at=snapshot_at,
                **optional,
            )
        )
    if not rows:
        logger.info("No positions to publish for positions_daily")
        return ""

    table = pa.Table.from_pylist(rows, schema=POSITIONS_DAILY_SCHEMA)
    parquet_bytes = _write_parquet_bytes(table)

    path = _partition_path("positions_daily", snapshot_at)
    _put_lakehouse_object(client, "positions_daily", path, parquet_bytes)

    ref = s3_uri(path)
    logger.info("Published %d positions_daily facts: %s", len(rows), ref)
    return ref
("dt", pa.date32()), +]) + + +def build_pnl_daily_row( + ticker: str, + realized_pnl: float, + unrealized_pnl: float, + total_pnl: float, + broker_account: str, + dt: datetime | None = None, + fees: float = 0.0, + net_pnl: float | None = None, + execution_mode: str = "paper", +) -> dict[str, object]: + """Build a single pnl_daily fact row.""" + row_dt = dt or datetime.now(timezone.utc) + return { + "ticker": ticker, + "realized_pnl": realized_pnl, + "unrealized_pnl": unrealized_pnl, + "total_pnl": total_pnl, + "fees": fees, + "net_pnl": net_pnl if net_pnl is not None else total_pnl - fees, + "broker_account": broker_account, + "execution_mode": execution_mode, + **partition_values(row_dt), + } + + +def publish_pnl_daily( + client: Minio, + ticker: str, + realized_pnl: float, + unrealized_pnl: float, + total_pnl: float, + broker_account: str, + dt: datetime, + fees: float = 0.0, + net_pnl: float | None = None, + execution_mode: str = "paper", +) -> str: + """Publish a single pnl_daily fact to MinIO. + + Returns the s3:// URI of the written object. + + Requirements: 9.4, 9.5 + Design ref: Section 7 (lake.pnl_daily) + """ + row = build_pnl_daily_row( + ticker, realized_pnl, unrealized_pnl, total_pnl, + broker_account, dt=dt, fees=fees, net_pnl=net_pnl, execution_mode=execution_mode, + ) + table = pa.Table.from_pylist([row], schema=PNL_DAILY_SCHEMA) + parquet_bytes = _write_parquet_bytes(table) + + path = _partition_path("pnl_daily", dt) + _put_lakehouse_object(client, "pnl_daily", path, parquet_bytes) + + ref = s3_uri(path) + logger.info("Published pnl_daily fact for %s: %s", ticker, ref) + return ref + + +# --- market_bars publisher --- + +def publish_market_bar( + client: Minio, + ticker: str, + open_price: float, + high_price: float, + low_price: float, + close_price: float, + volume: int, + bar_timestamp: datetime, + source: str, + vwap: float = 0.0, + trade_count: int = 0, + bar_interval: str = "1d", +) -> str: + """Publish a single market bar fact to MinIO. 
def publish_market_quote(
    client: Minio,
    ticker: str,
    bid_price: float,
    ask_price: float,
    last_price: float,
    quote_at: datetime,
    source: str,
    bid_size: int = 0,
    ask_size: int = 0,
    last_size: int = 0,
) -> str:
    """Write one market_quotes fact to MinIO and return its s3:// URI.

    The quote is serialized as a single-row Parquet file and stored under the
    market_quotes partition path derived from ``quote_at``.

    Requirements: 2.1, 9.4, 9.5
    Design ref: Section 7 (lake.market_quotes)
    """
    quote_row: dict[str, object] = {
        "ticker": ticker,
        "bid_price": bid_price,
        "ask_price": ask_price,
        "bid_size": bid_size,
        "ask_size": ask_size,
        "last_price": last_price,
        "last_size": last_size,
        "source": source,
        "quote_at": quote_at,
    }
    quote_row.update(partition_values(quote_at))

    payload = _write_parquet_bytes(
        pa.Table.from_pylist([quote_row], schema=MARKET_QUOTES_SCHEMA)
    )
    object_key = _partition_path("market_quotes", quote_at)
    _put_lakehouse_object(client, "market_quotes", object_key, payload)

    uri = s3_uri(object_key)
    logger.info("Published market_quote fact for %s: %s", ticker, uri)
    return uri
+ + Requirements: 2.3, 9.4, 9.5 + Design ref: Section 7 (lake.company_events) + """ + row: dict[str, object] = { + "event_id": event_id, + "ticker": ticker, + "event_type": event_type, + "event_subtype": event_subtype, + "title": title, + "description": description, + "source": source, + "source_url": source_url, + "event_at": event_at, + "ingested_at": ingested_at or datetime.now(timezone.utc), + **partition_values(event_at), + } + table = pa.Table.from_pylist([row], schema=COMPANY_EVENTS_SCHEMA) + parquet_bytes = _write_parquet_bytes(table) + + path = _partition_path("company_events", event_at) + _put_lakehouse_object(client, "company_events", path, parquet_bytes) + ref = s3_uri(path) + logger.info("Published company_event fact for %s: %s", ticker, ref) + return ref + + +# --- documents publisher --- + +def publish_document_fact( + client: Minio, + document_id: str, + document_type: str, + source_type: str, + ticker: str, + publisher: str, + title: str, + published_at: datetime, + content_hash: str, + url: str = "", + canonical_url: str = "", + language: str = "en", + confidence: float = 0.0, + retrieved_at: datetime | None = None, +) -> str: + """Publish a single document metadata fact to MinIO. 
def publish_document_extraction(
    client: Minio,
    document_id: str,
    ticker: str,
    sentiment: str,
    impact_score: float,
    catalyst_type: str,
    confidence: float,
    extraction_at: datetime,
    model_name: str,
    prompt_version: str,
    company_name: str = "",
    relevance: float = 0.0,
    impact_horizon: str = "",
    novelty_score: float = 0.0,
    source_credibility: float = 0.0,
    key_facts: str = "",
    risks: str = "",
    macro_themes: str = "",
    schema_version: str = "",
) -> str:
    """Write one document-extraction row to the lakehouse as a Parquet object.

    The row is partitioned by ``extraction_at`` plus a ``model_version``
    partition, which is ``schema_version`` when set and ``prompt_version``
    otherwise.

    Returns:
        The ``s3_uri`` reference of the object that was written.

    Requirements: 5.3, 5.5, 9.4, 9.5
    Design ref: Section 6.3, Section 7 (lake.document_extractions)
    """
    table_name = "document_extractions"
    version_partition = {"model_version": schema_version or prompt_version}

    fact: dict[str, object] = dict(
        document_id=document_id,
        ticker=ticker,
        company_name=company_name,
        relevance=relevance,
        sentiment=sentiment,
        impact_score=impact_score,
        impact_horizon=impact_horizon,
        catalyst_type=catalyst_type,
        confidence=confidence,
        novelty_score=novelty_score,
        source_credibility=source_credibility,
        key_facts=key_facts,
        risks=risks,
        macro_themes=macro_themes,
        model_name=model_name,
        prompt_version=prompt_version,
        schema_version=schema_version,
        extraction_at=extraction_at,
    )
    # Partition columns are merged last, so they win over any clashing key.
    fact.update(partition_values(extraction_at, version_partition))

    payload = _write_parquet_bytes(
        pa.Table.from_pylist([fact], schema=DOCUMENT_EXTRACTIONS_SCHEMA)
    )
    object_path = _partition_path(
        table_name,
        extraction_at,
        extra_partitions=version_partition,
    )
    _put_lakehouse_object(client, table_name, object_path, payload)

    uri = s3_uri(object_path)
    logger.info(
        "Published document_extraction fact for %s/%s: %s",
        ticker,
        document_id,
        uri,
    )
    return uri
def build_model_performance_row(
    document_id: str,
    model_name: str,
    success: bool,
    total_duration_ms: int,
    recorded_at: datetime,
    ticker: str = "",
    prompt_version: str = "",
    schema_version: str = "",
    attempt_count: int = 1,
    first_attempt_duration_ms: int = 0,
    final_attempt_duration_ms: int = 0,
    confidence: float = 0.0,
    validation_status: str = "unknown",
    validation_error_count: int = 0,
    validation_warning_count: int = 0,
    retry_count: int = 0,
    input_token_estimate: int = 0,
    output_token_estimate: int = 0,
    company_count: int = 0,
) -> dict[str, object]:
    """Assemble one model_performance fact row as a plain dict.

    The ``model_version`` partition value is the first non-empty value of
    ``schema_version``, ``prompt_version`` and ``model_name``. Partition
    columns come from ``partition_values(recorded_at, ...)`` and are merged
    after the metric fields.
    """
    version_label = schema_version or prompt_version or model_name

    fact: dict[str, object] = dict(
        document_id=document_id,
        ticker=ticker,
        model_name=model_name,
        prompt_version=prompt_version,
        schema_version=schema_version,
        success=success,
        attempt_count=attempt_count,
        total_duration_ms=total_duration_ms,
        first_attempt_duration_ms=first_attempt_duration_ms,
        final_attempt_duration_ms=final_attempt_duration_ms,
        confidence=confidence,
        validation_status=validation_status,
        validation_error_count=validation_error_count,
        validation_warning_count=validation_warning_count,
        retry_count=retry_count,
        input_token_estimate=input_token_estimate,
        output_token_estimate=output_token_estimate,
        company_count=company_count,
        recorded_at=recorded_at,
    )
    fact.update(partition_values(recorded_at, {"model_version": version_label}))
    return fact
def _publish_batch(
    client: Minio,
    table_name: str,
    rows: list[dict[str, object]],
    schema: pa.Schema,
    dt: datetime,
    extra_partitions: dict[str, str] | None = None,
) -> str:
    """Write many row dicts as one Parquet object and record publish metrics.

    Each row is copied, and any partition column it does not already carry
    is filled in from ``partition_values(dt, extra_partitions)``. Values
    already present on a row are left untouched.

    Returns:
        The ``s3_uri`` reference of the written object, or ``""`` when
        ``rows`` is empty (nothing is uploaded in that case).
    """
    if not rows:
        logger.info("No rows to publish for %s", table_name)
        return ""

    partition_defaults = partition_values(dt, extra_partitions)
    enriched: list[dict[str, object]] = []
    for source_row in rows:
        filled = dict(source_row)  # copy so the caller's dict is not mutated
        for column, value in partition_defaults.items():
            filled.setdefault(column, value)
        enriched.append(filled)

    payload = _write_parquet_bytes(pa.Table.from_pylist(enriched, schema=schema))
    object_path = _partition_path(table_name, dt, extra_partitions)

    # Only the upload itself is timed; serialization is excluded.
    started = time.monotonic()
    client.put_object(
        LAKEHOUSE_BUCKET,
        object_path,
        io.BytesIO(payload),
        length=len(payload),
        content_type="application/octet-stream",
    )
    LAKE_PUBLISH_DURATION.labels(table_name=table_name).observe(
        time.monotonic() - started
    )
    LAKE_FACTS_PUBLISHED.labels(table_name=table_name).inc(len(enriched))
    LAKE_PUBLISH_BYTES.labels(table_name=table_name).inc(len(payload))

    uri = s3_uri(object_path)
    logger.info("Published %d %s facts: %s", len(enriched), table_name, uri)
    return uri
def publish_document_extractions_batch(
    client: Minio,
    extractions: list[dict[str, object]],
    dt: datetime,
    model_version: str = "",
) -> str:
    """Publish document-extraction rows as a single Parquet file.

    A ``model_version`` extra partition is passed to ``_publish_batch`` only
    when ``model_version`` is non-empty; otherwise no extra partition is used.
    """
    extra_partitions: dict[str, str] | None = None
    if model_version:
        extra_partitions = {"model_version": model_version}
    return _publish_batch(
        client,
        "document_extractions",
        extractions,
        DOCUMENT_EXTRACTIONS_SCHEMA,
        dt,
        extra_partitions,
    )
def publish_model_performance_batch(
    client: Minio,
    rows: list[dict[str, object]],
    dt: datetime,
    model_version: str = "",
) -> str:
    """Publish model-performance rows as a single Parquet file.

    A ``model_version`` extra partition is passed to ``_publish_batch`` only
    when ``model_version`` is non-empty; otherwise no extra partition is used.
    """
    extra_partitions: dict[str, str] | None = None
    if model_version:
        extra_partitions = {"model_version": model_version}
    return _publish_batch(
        client,
        "model_performance",
        rows,
        MODEL_PERFORMANCE_SCHEMA,
        dt,
        extra_partitions,
    )
# Regex patterns for residual boilerplate in extracted text.
#
# Each pattern is applied with ``pattern.sub("", text)`` over the whole
# multi-line document, so every pattern that anchors on a line start or line
# end must carry ``re.MULTILINE``. Without it, ``^`` and ``$`` only match at
# the very start or end of the document and the pattern never fires on
# interior lines.
BOILERPLATE_TEXT_PATTERNS = [
    re.compile(r"(?i)subscribe to our newsletter.*?(?:\n|$)"),
    re.compile(r"(?i)click here to read more.*?(?:\n|$)"),
    re.compile(r"(?i)advertisement\s*\n?"),
    re.compile(r"(?i)copyright ©.*?(?:\n|$)"),
    re.compile(r"(?i)all rights reserved.*?(?:\n|$)"),
    re.compile(r"(?i)terms of (use|service).*?(?:\n|$)"),
    re.compile(r"(?i)privacy policy.*?(?:\n|$)"),
    # Bracketed ad markers such as "[Ad]" or "[Advertisement]". The marker
    # must be a whole word and the match stays inside one bracket pair, so
    # ordinary bracketed text that merely contains the letters "ad"
    # (e.g. "[Broadcom]", "[Read the filing]") is left alone.
    re.compile(
        r"\s*\[[^\]\n]*?\b(?:ads?|adverts?|advertisements?)\b[^\]\n]*\]\s*",
        re.IGNORECASE,
    ),
    re.compile(r"(?i)sign up for .*?(?:\n|$)"),
    re.compile(r"(?i)follow us on .*?(?:\n|$)"),
    re.compile(r"(?i)share this (article|story|post).*?(?:\n|$)"),
    # Trailing "Read more" link text at the end of any line.
    re.compile(r"(?i)read more:?[ \t]*$", re.MULTILINE),
    re.compile(r"(?i)recommended for you.*?(?:\n|$)"),
    re.compile(r"(?i)you may also like.*?(?:\n|$)"),
    re.compile(r"(?i)trending now.*?(?:\n|$)"),
    re.compile(r"(?i)most (popular|read).*?(?:\n|$)"),
    # A line that consists only of an empty "Tags:" label.
    re.compile(r"(?i)^tags:[ \t]*$", re.MULTILINE),
    # Photo / image credit lines.
    re.compile(r"(?i)^\s*photo\s*:.*?(?:\n|$)", re.MULTILINE),
    re.compile(
        r"(?i)^\s*image\s*(credit|source|courtesy)\s*:.*?(?:\n|$)",
        re.MULTILINE,
    ),
]
@dataclass
class QualitySignals:
    """Per-dimension quality signals behind the overall parse score.

    Every field is a float expected to lie in [0, 1]; a larger value means
    the parsed content did better on that dimension. All fields default
    to 0.0.

    Requirements: 4.3
    """
    word_count_signal: float = 0.0
    diversity_signal: float = 0.0
    sentence_signal: float = 0.0
    paragraph_signal: float = 0.0
    body_found_signal: float = 0.0
    metadata_signal: float = 0.0

    def as_dict(self) -> dict[str, float]:
        """Return the signals keyed by short name (field name minus ``_signal``)."""
        short_names = (
            "word_count",
            "diversity",
            "sentence",
            "paragraph",
            "body_found",
            "metadata",
        )
        return {name: getattr(self, f"{name}_signal") for name in short_names}
def _strip_boilerplate_tags(soup: BeautifulSoup) -> None:
    """Remove known non-content tags and boilerplate containers in-place.

    Two passes are made over ``soup``:

    1. Every tag whose name is listed in ``STRIP_TAGS`` is destroyed.
    2. Every remaining tag whose class/id matches a boilerplate signal
       (see ``_is_boilerplate_container``) is destroyed.

    ``decompose()`` destroys a tag together with all of its descendants, and
    a destroyed tag can no longer be queried for attributes. The second pass
    therefore evaluates the class/id heuristic for all tags *before* anything
    is removed, and tags already destroyed via an ancestor are skipped.
    Nested boilerplate containers (e.g. a widget inside a sidebar) would
    otherwise raise ``AttributeError`` during the attribute lookup.

    The document root elements ``<html>`` and ``<body>`` are never removed,
    even if their class/id contains a boilerplate signal (e.g.
    ``<body class="has-sidebar">``), because that would discard the whole
    document.

    Args:
        soup: Parsed document; modified in place.
    """
    protected_roots = ("html", "body")

    for tag_name in STRIP_TAGS:
        for tag in soup.find_all(tag_name):
            # Nested same-name tags are already gone once the outer one is.
            if getattr(tag, "decomposed", False):
                continue
            tag.decompose()

    # Decide first, remove second: the predicate reads tag attributes, which
    # is only safe while the tree is still intact.
    doomed = [
        tag
        for tag in soup.find_all(True)
        if tag.name not in protected_roots and _is_boilerplate_container(tag)
    ]
    for tag in doomed:
        if getattr(tag, "decomposed", False):
            continue  # already destroyed together with an ancestor
        tag.decompose()
def _block_score(tag: Tag) -> float:
    """Rate how likely a block element is to be the article body.

    Blocks with fewer than ``_MIN_BLOCK_WORDS`` words score 0.0. Otherwise
    the score starts at text density scaled down by link density, gains 0.1
    per ``<p>`` (capped at 10 paragraphs) when the block holds at least two
    paragraphs, and gains a log-scaled word-count bonus. Higher means a more
    likely body candidate.

    Requirements: 4.2
    """
    n_words = len(tag.get_text(strip=True).split())
    if n_words < _MIN_BLOCK_WORDS:
        return 0.0

    density = _text_density(tag)
    link_share = _link_density(tag)
    n_paragraphs = len(tag.find_all("p"))

    # Dense text is good; text that is mostly links is not.
    total = density * (1.0 - link_share)

    # Structured, paragraph-rich blocks look like article content.
    if n_paragraphs >= 2:
        total += 0.1 * min(n_paragraphs, 10)

    # Log scaling keeps very long blocks from dominating the score.
    total += 0.05 * math.log(max(n_words, 1))
    return total
+ + Requirements: 4.2 + """ + # Priority 1: semantic selectors + for selector in ARTICLE_SELECTORS: + result = soup.select_one(selector) + if result: + text = result.get_text(strip=True) + if len(text.split()) >= _MIN_BLOCK_WORDS: + return result + + # Priority 2: text-density scoring on block-level containers + candidates: list[tuple[float, Tag]] = [] + for tag in soup.find_all(["div", "section", "td"]): + score = _block_score(tag) + if score > 0: + candidates.append((score, tag)) + + if candidates: + candidates.sort(key=lambda x: x[0], reverse=True) + return candidates[0][1] + + return None + + +def _collapse_whitespace(text: str) -> str: + """Collapse runs of blank lines into single separators.""" + lines = [line.strip() for line in text.splitlines()] + result: list[str] = [] + prev_blank = False + for line in lines: + if not line: + if not prev_blank: + result.append("") + prev_blank = True + else: + result.append(line) + prev_blank = False + return "\n".join(result).strip() + + +def _remove_short_orphan_lines(text: str, min_words: int = 3) -> str: + """Remove very short orphan lines that are likely UI fragments or captions. + + Lines shorter than min_words that don't end with sentence punctuation + are stripped. This catches leftover button labels, image captions, + and navigation fragments. + + Requirements: 4.2 + """ + lines = text.splitlines() + kept: list[str] = [] + for line in lines: + stripped = line.strip() + words = stripped.split() + if len(words) < min_words and not stripped.endswith((".", "!", "?", ":")): + continue + kept.append(line) + return "\n".join(kept) + + +def _detect_repeated_blocks(text: str, min_len: int = 40) -> str: + """Remove repeated text blocks that appear more than once. + + Template text (disclaimers, repeated footers) often appears verbatim + in multiple places. This strips exact duplicate blocks. 
def extract_body_text(html: str) -> str:
    """Return the cleaned main body text of an HTML document.

    Steps:
      1. Parse the markup and strip non-content tags and boilerplate
         containers.
      2. Pick the article body element; fall back to ``<body>``, then to the
         whole document, when none is found.
      3. Extract newline-separated text from that element.
      4. Run the text through the cleaning stages in order: regex boilerplate
         removal, short-orphan-line removal, repeated-block removal, and
         whitespace collapsing.

    Requirements: 4.1, 4.2
    """
    soup = BeautifulSoup(html, "html.parser")
    _strip_boilerplate_tags(soup)

    # Fallback chain: article body -> <body> -> entire document.
    container = _find_article_body(soup) or soup.find("body") or soup
    cleaned = container.get_text(separator="\n", strip=True)

    cleaning_stages = (
        _reduce_boilerplate_text,
        _remove_short_orphan_lines,
        _detect_repeated_blocks,
        _collapse_whitespace,
    )
    for stage in cleaning_stages:
        cleaned = stage(cleaned)
    return cleaned
+ + Requirements: 4.1 + """ + soup = BeautifulSoup(html, "html.parser") + meta: dict[str, str | None] = {} + + # Title: og:title > <title> + og_title = soup.find("meta", property="og:title") + if og_title and og_title.get("content"): + content = og_title["content"] + meta["title"] = content.strip() if isinstance(content, str) else "" + elif soup.title and soup.title.string: + meta["title"] = soup.title.string.strip() + else: + meta["title"] = "" + + # Author + author_tag = soup.find("meta", attrs={"name": "author"}) + if author_tag and author_tag.get("content"): + content = author_tag["content"] + meta["author"] = content.strip() if isinstance(content, str) else "" + else: + meta["author"] = "" + + # Publisher: og:site_name > hostname + site_name = soup.find("meta", property="og:site_name") + if site_name and site_name.get("content"): + content = site_name["content"] + meta["publisher"] = content.strip() if isinstance(content, str) else "" + else: + meta["publisher"] = urlparse(url).hostname or "" if url else "" + + # Published date: article:published_time > JSON-LD datePublished + pub_time = soup.find("meta", property="article:published_time") + if pub_time and pub_time.get("content"): + content = pub_time["content"] + meta["published_at"] = content.strip() if isinstance(content, str) else None + else: + meta["published_at"] = _extract_jsonld_date(soup) + + # Canonical URL + canonical = soup.find("link", rel="canonical") + if canonical and canonical.get("href"): + meta["canonical_url"] = str(canonical["href"]) + else: + og_url = soup.find("meta", property="og:url") + if og_url and og_url.get("content"): + meta["canonical_url"] = str(og_url["content"]) + else: + meta["canonical_url"] = url or None + + # Language + html_tag = soup.find("html") + if html_tag and html_tag.get("lang"): + lang = html_tag["lang"] + meta["language"] = str(lang)[:5] if lang else "en" + else: + meta["language"] = "en" + + # Description + desc = soup.find("meta", property="og:description") 
def _extract_jsonld_date(soup: BeautifulSoup) -> str | None:
    """Try to extract ``datePublished`` from JSON-LD script tags.

    Each ``<script type="application/ld+json">`` block that mentions
    ``datePublished`` is parsed and searched level by level (breadth-first).
    A date on the top-level object, or on a top-level list item, is
    therefore preferred over one nested deeper. Nested containers such as
    ``@graph`` are searched too, so a date that only appears on a node
    inside the graph is still found.

    Empty or null ``datePublished`` values are skipped rather than being
    stringified (``str(None)`` would yield the bogus date ``"None"``).

    Args:
        soup: Parsed document to search.

    Returns:
        The first non-empty ``datePublished`` value as a string, or ``None``
        when no block provides one.
    """
    for script in soup.find_all("script", type="application/ld+json"):
        payload = script.string
        if not payload or "datePublished" not in payload:
            continue
        try:
            document = json.loads(payload)
        except (json.JSONDecodeError, TypeError):
            # Best effort: malformed JSON-LD is common; try the next block.
            continue

        # Breadth-first walk without recursion, one nesting level at a time.
        level: list[object] = [document]
        while level:
            next_level: list[object] = []
            for node in level:
                if isinstance(node, dict):
                    value = node.get("datePublished")
                    if value:
                        return str(value)
                    next_level.extend(node.values())
                elif isinstance(node, list):
                    next_level.extend(node)
            level = next_level
    return None
+ + Requirements: 4.1 + """ + soup = BeautifulSoup(html, "html.parser") + base_host = urlparse(base_url).hostname or "" if base_url else "" + links: list[str] = [] + + for a_tag in soup.find_all("a", href=True): + href = str(a_tag["href"]).strip() + if not href or href.startswith("#") or href.startswith("javascript:"): + continue + parsed = urlparse(href) + # Only include absolute URLs that point to different hosts + if parsed.scheme in ("http", "https") and parsed.hostname: + if parsed.hostname != base_host: + links.append(href) + + # Dedupe while preserving order + seen: set[str] = set() + unique: list[str] = [] + for link in links: + if link not in seen: + seen.add(link) + unique.append(link) + return unique + + +def _count_sentences(text: str) -> int: + """Count approximate sentence count by terminal punctuation.""" + return len(re.findall(r"[.!?]+(?:\s|$)", text)) + + +def _count_paragraphs(text: str) -> int: + """Count non-empty paragraph blocks separated by blank lines.""" + blocks = re.split(r"\n\s*\n", text.strip()) + return sum(1 for b in blocks if len(b.strip().split()) >= 5) + + +def score_parse_quality( + text: str, + *, + body_found: bool = True, + has_title: bool = False, + has_author: bool = False, + has_publisher: bool = False, + has_published_at: bool = False, +) -> tuple[float, str, QualitySignals, list[str]]: + """Score parse quality using multiple content and metadata signals. + + Returns (score, confidence_label, signals, warnings). 
+ + Signals considered: + - word_count_signal: length of extracted text + - diversity_signal: vocabulary richness (unique/total words) + - sentence_signal: presence of proper sentence structure + - paragraph_signal: multi-paragraph structure + - body_found_signal: whether a semantic article body was located + - metadata_signal: presence of title, author, publisher, date + + Requirements: 4.3 + """ + warnings: list[str] = [] + words = text.split() + word_count = len(words) + + # --- word count signal --- + if word_count < 20: + wc_sig = 0.1 + warnings.append("very_short_text") + elif word_count < 50: + wc_sig = 0.3 + warnings.append("short_text") + elif word_count < 150: + wc_sig = 0.6 + elif word_count < 300: + wc_sig = 0.8 + else: + wc_sig = 1.0 + + # --- diversity signal --- + if word_count > 0: + unique = len(set(w.lower() for w in words)) + diversity = unique / word_count + else: + diversity = 0.0 + if diversity < 0.2: + div_sig = 0.2 + if word_count >= 20: + warnings.append("low_vocabulary_diversity") + elif diversity < 0.4: + div_sig = 0.5 + else: + div_sig = 1.0 + + # --- sentence signal --- + sentence_count = _count_sentences(text) + if sentence_count == 0: + sent_sig = 0.1 + if word_count >= 20: + warnings.append("no_sentence_structure") + elif sentence_count < 3: + sent_sig = 0.5 + else: + sent_sig = 1.0 + + # --- paragraph signal --- + para_count = _count_paragraphs(text) + if para_count == 0: + para_sig = 0.2 + elif para_count == 1: + para_sig = 0.5 + else: + para_sig = 1.0 + + # --- body found signal --- + body_sig = 1.0 if body_found else 0.3 + if not body_found: + warnings.append("no_article_body_found") + + # --- metadata signal --- + meta_hits = sum([has_title, has_author, has_publisher, has_published_at]) + meta_sig = meta_hits / 4.0 + + signals = QualitySignals( + word_count_signal=wc_sig, + diversity_signal=div_sig, + sentence_signal=sent_sig, + paragraph_signal=para_sig, + body_found_signal=body_sig, + metadata_signal=meta_sig, + ) + + # 
def infer_document_type(html: str, url: str = "") -> str:
    """Classify a document from keywords found in its URL.

    The lower-cased URL is checked against keyword groups in priority order
    (filing, then transcript, then press release); the first group with a
    matching substring wins. Anything else is an ``"article"``. ``html`` is
    accepted but not inspected yet.

    Requirements: 4.1
    """
    # html reserved for future content-based inference
    _ = html

    # Ordered: earlier entries take precedence over later ones.
    url_rules = (
        ("filing", ("sec.gov", "edgar", "filing", "10-k", "10-q", "8-k")),
        ("transcript", ("transcript", "earnings-call", "earnings_call")),
        ("press_release", ("press-release", "press_release", "newsroom")),
    )
    haystack = url.lower()
    for doc_type, keywords in url_rules:
        for keyword in keywords:
            if keyword in haystack:
                return doc_type
    return "article"
+ + Requirements: 1.3, 4.1, 4.2, 4.3 + """ + soup = BeautifulSoup(html, "html.parser") + _strip_boilerplate_tags(soup) + + article = _find_article_body(soup) + body_found = article is not None + if article: + raw_text = article.get_text(separator="\n", strip=True) + else: + body = soup.find("body") + raw_text = (body or soup).get_text(separator="\n", strip=True) + + # Multi-stage text cleaning + text = _reduce_boilerplate_text(raw_text) + text = _remove_short_orphan_lines(text) + text = _detect_repeated_blocks(text) + text = _collapse_whitespace(text) + + metadata = extract_metadata(html, url) + outbound_links = extract_outbound_links(html, url) + doc_type = infer_document_type(html, url) + word_count = len(text.split()) + + tags_raw = metadata.get("tags", "") or "" + tags = [t.strip() for t in tags_raw.split(",") if t.strip()] if tags_raw else [] + + # Rich quality scoring with all available signals + quality, confidence, signals, warnings = score_parse_quality( + text, + body_found=body_found, + has_title=bool(metadata.get("title")), + has_author=bool(metadata.get("author")), + has_publisher=bool(metadata.get("publisher")), + has_published_at=bool(metadata.get("published_at")), + ) + + low_quality_flag = confidence == "low" + + # Company mention detection + mentioned: list[CompanyMention] = [] + if aliases and text: + # Search title + body for mentions + search_text = f"{metadata.get('title', '')} {text}" + raw_mentions = detect_company_mentions(search_text, aliases) + for m in raw_mentions: + mentioned.append(CompanyMention( + company_id=str(m["company_id"]), + ticker=str(m["ticker"]), + mention_type=str(m["mention_type"]), + confidence=float(m["confidence"]), + match_count=int(m["match_count"]), + )) + + return ParsedDocument( + body_text=text, + title=metadata.get("title", "") or "", + author=metadata.get("author", "") or "", + publisher=metadata.get("publisher", "") or "", + published_at=metadata.get("published_at"), + 
canonical_url=metadata.get("canonical_url"), + language=metadata.get("language", "en") or "en", + description=metadata.get("description", "") or "", + document_type=doc_type, + outbound_links=outbound_links, + tags=tags, + mentioned_companies=mentioned, + quality_score=quality, + confidence=confidence, + word_count=word_count, + quality_signals=signals, + low_quality_flag=low_quality_flag, + quality_warnings=warnings, + ) + + + +@dataclass +class AliasEntry: + """A company alias used for mention detection.""" + company_id: str + alias: str + alias_type: str = "alias" + ticker: str = "" + + +# Confidence by alias type — tickers are most precise, brands least +_CONFIDENCE_BY_TYPE: dict[str, float] = { + "ticker": 0.9, + "legal_name": 0.85, + "alias": 0.7, + "brand": 0.6, +} + + +def _build_alias_entries(aliases: list[dict[str, str]]) -> list[AliasEntry]: + """Convert raw alias dicts to typed AliasEntry objects.""" + entries: list[AliasEntry] = [] + for a in aliases: + alias_val = a.get("alias", "") + if not alias_val: + continue + entries.append(AliasEntry( + company_id=a.get("company_id", ""), + alias=alias_val, + alias_type=a.get("alias_type", "alias"), + ticker=a.get("ticker", ""), + )) + return entries + + +def _count_matches(text: str, pattern: re.Pattern[str]) -> int: + """Count non-overlapping matches of pattern in text.""" + return len(pattern.findall(text)) + + +def detect_company_mentions( + text: str, + aliases: list[dict[str, str]], +) -> list[dict[str, str | float | int]]: + """Detect company mentions using ticker, alias, and name matching. + + Matching strategy by alias length: + - 1-2 chars: case-sensitive word-boundary match (avoids "A" matching "a") + - 3-4 chars: case-insensitive word-boundary match (standard tickers) + - 5+ chars: case-insensitive substring match (company names, brands) + + Confidence varies by alias_type: ticker > legal_name > alias > brand. 
+ Multiple alias hits for the same company are deduplicated, keeping the + highest-confidence match and summing match counts. + + Requirements: 1.3, 4.1 + """ + if not text: + return [] + + entries = _build_alias_entries(aliases) + text_upper = text.upper() + + # Track best match per company: company_id -> (confidence, ticker, mention_type, count) + best: dict[str, tuple[float, str, str, int]] = {} + + for entry in entries: + alias = entry.alias + alias_type = entry.alias_type + base_confidence = _CONFIDENCE_BY_TYPE.get(alias_type, 0.7) + + match_count = 0 + + if len(alias) <= 2: + # Very short: case-sensitive word boundary + pattern = re.compile(r"\b" + re.escape(alias) + r"\b") + match_count = _count_matches(text, pattern) + elif len(alias) <= 4: + # Standard ticker length: case-insensitive word boundary + pattern = re.compile(r"\b" + re.escape(alias.upper()) + r"\b") + match_count = _count_matches(text_upper, pattern) + else: + # Longer names: case-insensitive substring + alias_up = alias.upper() + match_count = text_upper.count(alias_up) + + if match_count == 0: + continue + + cid = entry.company_id + existing = best.get(cid) + if existing is None: + best[cid] = (base_confidence, entry.ticker, alias_type, match_count) + else: + # Keep highest confidence, accumulate match count + prev_conf, prev_ticker, prev_type, prev_count = existing + if base_confidence > prev_conf: + best[cid] = (base_confidence, entry.ticker, alias_type, prev_count + match_count) + else: + best[cid] = (prev_conf, prev_ticker, prev_type, prev_count + match_count) + + mentions: list[dict[str, str | float | int]] = [] + for cid, (confidence, ticker, mention_type, count) in best.items(): + mentions.append({ + "company_id": cid, + "ticker": ticker, + "mention_type": mention_type, + "confidence": confidence, + "match_count": count, + }) + + return mentions diff --git a/services/parser/worker.py b/services/parser/worker.py index a4a3b93..b30efab 100644 --- a/services/parser/worker.py +++ 
b/services/parser/worker.py @@ -1,84 +1,41 @@ -"""Parser worker - HTML-to-text, boilerplate reduction, quality scoring.""" +"""Parser worker - HTML-to-text, boilerplate reduction, quality scoring. + +Uses BeautifulSoup-based parsing pipeline for structured HTML extraction, +metadata extraction, outbound link extraction, and quality scoring. +Persists normalized text and structured parser output to MinIO, +and updates document metadata in PostgreSQL. + +Requirements: 4.1, 4.2, 4.3, 9.1, 9.2 +""" import asyncio -import io import json import logging -import re -from datetime import datetime -from typing import List, Optional, Tuple +import time +from datetime import datetime, timezone +from typing import Any, Optional import asyncpg import httpx import redis.asyncio as aioredis from minio import Minio +from services.parser.html_parser import ParsedDocument, detect_company_mentions, parse_html from services.shared.config import load_config from services.shared.db import get_minio, get_pg_pool, get_redis +from services.shared.logging import Span, extract_trace_context, inject_trace_context, new_trace_id, set_trace_context, setup_logging +from services.shared.metrics import ( + ACTIVE_JOBS, + PARSE_DURATION, + PARSE_JOBS_TOTAL, + PARSE_LOW_QUALITY_TOTAL, + PARSE_QUALITY_SCORE, +) +from services.shared.metadata import update_document_parse_results from services.shared.redis_keys import QUEUE_EXTRACTION, QUEUE_PARSING, queue_key +from services.shared.storage import upload_normalized_text, upload_parser_output -logging.basicConfig(level=logging.INFO) logger = logging.getLogger("parser_worker") -# Simple boilerplate patterns to strip -BOILERPLATE_PATTERNS = [ - re.compile(r"(?i)subscribe to our newsletter.*?(?:\n|$)"), - re.compile(r"(?i)click here to read more.*?(?:\n|$)"), - re.compile(r"(?i)advertisement\s*\n"), - re.compile(r"(?i)copyright ©.*?(?:\n|$)"), - re.compile(r"(?i)all rights reserved.*?(?:\n|$)"), - re.compile(r"(?i)terms of (use|service).*?(?:\n|$)"), - 
re.compile(r"(?i)privacy policy.*?(?:\n|$)"), - re.compile(r"\s*\[.*?ad.*?\]\s*", re.IGNORECASE), -] - - -def strip_html_tags(html: str) -> str: - """Basic HTML tag removal.""" - text = re.sub(r"<script[^>]*>.*?</script>", "", html, flags=re.DOTALL | re.IGNORECASE) - text = re.sub(r"<style[^>]*>.*?</style>", "", text, flags=re.DOTALL | re.IGNORECASE) - text = re.sub(r"<[^>]+>", " ", text) - text = re.sub(r" ", " ", text) - text = re.sub(r"&", "&", text) - text = re.sub(r"<", "<", text) - text = re.sub(r">", ">", text) - text = re.sub(r"&#\d+;", "", text) - text = re.sub(r"\s+", " ", text).strip() - return text - - -def reduce_boilerplate(text: str) -> str: - for pattern in BOILERPLATE_PATTERNS: - text = pattern.sub("", text) - return text.strip() - - -def score_quality(text: str) -> Tuple[float, str]: - """Score parse quality. Returns (score, confidence_label).""" - word_count = len(text.split()) - if word_count < 20: - return 0.1, "low" - if word_count < 50: - return 0.3, "low" - if word_count < 150: - return 0.6, "medium" - return 0.85, "high" - - -def detect_company_mentions(text: str, aliases: List[dict]) -> List[dict]: - """Detect company mentions using ticker, alias, and name matching.""" - mentions = [] - text_upper = text.upper() - for alias_info in aliases: - alias = alias_info["alias"] - if alias.upper() in text_upper: - mentions.append({ - "company_id": alias_info["company_id"], - "ticker": alias_info.get("ticker", ""), - "mention_type": alias_info.get("alias_type", "alias"), - "confidence": 0.7, - }) - return mentions - async def fetch_html(url: str) -> Optional[str]: """Fetch article HTML for scraping.""" @@ -94,48 +51,65 @@ async def fetch_html(url: str) -> Optional[str]: return None +def build_parser_output_json(parsed: ParsedDocument, mentions: list[dict[str, Any]]) -> dict[str, Any]: + """Build a structured JSON dict from ParsedDocument and detected mentions. 
+ + This captures the full parser output for audit and downstream use: + metadata, quality signals, warnings, outbound links, tags, and mentions. + """ + return { + "title": parsed.title, + "author": parsed.author, + "publisher": parsed.publisher, + "published_at": parsed.published_at, + "canonical_url": parsed.canonical_url, + "language": parsed.language, + "description": parsed.description, + "document_type": parsed.document_type, + "word_count": parsed.word_count, + "outbound_links": parsed.outbound_links, + "tags": parsed.tags, + "quality_score": parsed.quality_score, + "confidence": parsed.confidence, + "low_quality_flag": parsed.low_quality_flag, + "quality_warnings": parsed.quality_warnings, + "quality_signals": parsed.quality_signals.as_dict(), + "mentioned_companies": mentions, + } + + async def process_job( - job: dict, + job: dict[str, Any], pool: asyncpg.Pool, rds: aioredis.Redis, minio_client: Minio, -): +) -> None: doc_id = job["document_id"] ticker = job["ticker"] url = job.get("url", "") + now = datetime.now(timezone.utc) + _parse_start = time.monotonic() + + set_trace_context(trace_id=job.get("_trace_id") or new_trace_id()) # Fetch HTML if we have a URL html = await fetch_html(url) if url else None if html: - # Store raw HTML - html_bytes = html.encode("utf-8") - now = datetime.utcnow() - html_path = f"scrape/{ticker}/{now.year}/{now.month:02d}/{now.day:02d}/{doc_id}/raw.html" - minio_client.put_object( - "stonks-raw-news", html_path, io.BytesIO(html_bytes), len(html_bytes), - content_type="text/html", - ) - - # Parse - text = strip_html_tags(html) - text = reduce_boilerplate(text) + # Parse using BeautifulSoup pipeline + parsed = parse_html(html, url) else: - text = "" + parsed = ParsedDocument() - quality_score, confidence = score_quality(text) + text = parsed.body_text - # Store normalized text + # Upload normalized text to MinIO + norm_ref: str | None = None if text: - text_bytes = text.encode("utf-8") - now = datetime.utcnow() - norm_path = 
f"parsed/{ticker}/{now.year}/{now.month:02d}/{now.day:02d}/{doc_id}/normalized.txt" - minio_client.put_object( - "stonks-normalized", norm_path, io.BytesIO(text_bytes), len(text_bytes), - content_type="text/plain", + norm_ref = upload_normalized_text( + minio_client, ticker, doc_id, + text.encode("utf-8"), timestamp=now, ) - else: - norm_path = None # Detect company mentions aliases = await pool.fetch( @@ -150,14 +124,24 @@ async def process_job( ) mentions = detect_company_mentions(text, [dict(a) for a in aliases]) if text else [] - # Update document - status = "parsed" if confidence != "low" else "low_quality" - await pool.execute( - """UPDATE documents SET - normalized_storage_ref=$2, parse_quality_score=$3, parse_confidence=$4, status=$5, updated_at=NOW() - WHERE id=$1""", - doc_id, f"s3://stonks-normalized/{norm_path}" if norm_path else None, - quality_score, confidence, status, + # Build and upload structured parser output JSON + output_json = build_parser_output_json(parsed, mentions) + output_bytes = json.dumps(output_json, default=str, indent=2).encode("utf-8") + parser_output_ref = upload_parser_output( + minio_client, ticker, doc_id, + output_bytes, timestamp=now, + ) + + # Update document in PostgreSQL + status = "parsed" if parsed.confidence != "low" else "low_quality" + await update_document_parse_results( + pool, + document_id=doc_id, + normalized_storage_ref=norm_ref, + parser_output_ref=parser_output_ref, + parse_quality_score=parsed.quality_score, + parse_confidence=parsed.confidence, + status=status, ) # Insert company mentions @@ -169,19 +153,36 @@ async def process_job( ) # Only enqueue for extraction if quality is acceptable - if confidence != "low": - await rds.rpush(queue_key(QUEUE_EXTRACTION), json.dumps({ + if parsed.confidence != "low": + await rds.rpush(queue_key(QUEUE_EXTRACTION), json.dumps(inject_trace_context({ "document_id": doc_id, "ticker": ticker, - "normalized_text": text[:8000], # Truncate for prompt - })) - 
logger.info(f"Parsed doc {doc_id} for {ticker}: quality={quality_score:.2f}, confidence={confidence}") + "normalized_text": text[:8000], + }))) + PARSE_JOBS_TOTAL.labels(status="parsed").inc() + PARSE_QUALITY_SCORE.observe(parsed.quality_score) + PARSE_DURATION.observe(time.monotonic() - _parse_start) + logger.info( + "Parsed doc %s for %s: quality=%.2f, confidence=%s", + doc_id, ticker, parsed.quality_score, parsed.confidence, + extra={"ticker": ticker, "document_id": doc_id}, + ) else: - logger.warning(f"Low quality parse for doc {doc_id}, skipping extraction") + PARSE_JOBS_TOTAL.labels(status="low_quality").inc() + PARSE_LOW_QUALITY_TOTAL.inc() + PARSE_QUALITY_SCORE.observe(parsed.quality_score) + PARSE_DURATION.observe(time.monotonic() - _parse_start) + logger.warning( + "Low quality parse for doc %s, skipping extraction", + doc_id, + extra={"ticker": ticker, "document_id": doc_id}, + ) -async def main(): +async def main() -> None: config = load_config() + setup_logging("parser_worker", level=config.log_level, json_output=config.json_logs) + pool = await get_pg_pool(config) rds = get_redis(config) minio_client = get_minio(config) @@ -197,7 +198,7 @@ async def main(): try: await process_job(job, pool, rds, minio_client) except Exception as e: - logger.error(f"Parse error: {e}") + logger.error("Parse error: %s", e, exc_info=True) else: await asyncio.sleep(2) finally: diff --git a/services/recommendation/eligibility.py b/services/recommendation/eligibility.py new file mode 100644 index 0000000..cb83cd9 --- /dev/null +++ b/services/recommendation/eligibility.py @@ -0,0 +1,354 @@ +"""Deterministic recommendation eligibility logic. 
+ +Evaluates trend summaries against configurable thresholds to decide: +- Whether a recommendation should be generated at all +- What action type (buy/sell/hold/watch) is appropriate +- What execution mode (informational/paper_eligible/live_eligible) is allowed +- Position sizing guidance based on portfolio rules + +All decisions are rule-based with no model involvement. The LLM is only +used downstream for optional thesis wording (a separate task). + +Requirements: 7.1, 7.2, 7.3, 7.4 +""" +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum + +from services.shared.schemas import ( + ActionType, + PositionSizing, + RecommendationMode, + TrendDirection, + TrendSummary, +) + + +class RejectionReason(str, Enum): + """Why a trend summary was deemed ineligible for a recommendation.""" + + LOW_CONFIDENCE = "low_confidence" + LOW_TREND_STRENGTH = "low_trend_strength" + HIGH_CONTRADICTION = "high_contradiction" + INSUFFICIENT_EVIDENCE = "insufficient_evidence" + NEUTRAL_DIRECTION = "neutral_direction" + + +@dataclass(frozen=True) +class EligibilityConfig: + """Tunable thresholds for recommendation eligibility. + + All thresholds are deterministic — no model inference involved. 
+ """ + + # --- Gate thresholds (below these → no recommendation) --- + min_confidence: float = 0.35 + min_trend_strength: float = 0.10 + max_contradiction_score: float = 0.60 + min_evidence_count: int = 2 # combined supporting + opposing + + # --- Action mapping thresholds --- + # Trend strength above this → buy/sell; below → hold/watch + action_strength_threshold: float = 0.25 + # Confidence above this → hold (rather than watch) for weak signals + hold_confidence_threshold: float = 0.50 + + # --- Mode escalation thresholds --- + # Confidence required for paper_eligible (below → informational) + paper_confidence_threshold: float = 0.50 + # Confidence required for live_eligible (below → paper at most) + live_confidence_threshold: float = 0.70 + # Contradiction must be below this for live eligibility + live_max_contradiction: float = 0.25 + # Minimum evidence count for live eligibility + live_min_evidence: int = 5 + + # --- Position sizing rules (Requirement 7.3) --- + # Base portfolio allocation percentage + base_portfolio_pct: float = 0.02 + # Maximum portfolio allocation percentage + max_portfolio_pct: float = 0.05 + # Base max loss percentage + base_max_loss_pct: float = 0.005 + # Maximum max loss percentage + max_max_loss_pct: float = 0.01 + # Confidence scaling: higher confidence → larger position (linear) + confidence_sizing_weight: float = 0.5 + # Contradiction penalty: higher contradiction → smaller position + contradiction_sizing_penalty: float = 0.3 + + +DEFAULT_ELIGIBILITY_CONFIG = EligibilityConfig() + + +@dataclass +class EligibilityResult: + """Output of the deterministic eligibility evaluation. + + Captures the decision, the reasoning, and all inputs used so the + full decision trace is reproducible (Requirement 8.3). 
+ """ + + eligible: bool + action: ActionType + mode: RecommendationMode + position_sizing: PositionSizing + rejection_reasons: list[RejectionReason] = field(default_factory=list) + time_horizon: str = "" + invalidation_conditions: list[str] = field(default_factory=list) + + +# --------------------------------------------------------------------------- +# Gate checks +# --------------------------------------------------------------------------- + + +def _check_gates( + summary: TrendSummary, + config: EligibilityConfig, +) -> list[RejectionReason]: + """Apply hard gate checks. Returns a list of rejection reasons (empty = pass).""" + reasons: list[RejectionReason] = [] + + if summary.confidence < config.min_confidence: + reasons.append(RejectionReason.LOW_CONFIDENCE) + + if summary.trend_strength < config.min_trend_strength: + reasons.append(RejectionReason.LOW_TREND_STRENGTH) + + if summary.contradiction_score > config.max_contradiction_score: + reasons.append(RejectionReason.HIGH_CONTRADICTION) + + evidence_count = len(summary.top_supporting_evidence) + len(summary.top_opposing_evidence) + if evidence_count < config.min_evidence_count: + reasons.append(RejectionReason.INSUFFICIENT_EVIDENCE) + + if summary.trend_direction == TrendDirection.NEUTRAL: + reasons.append(RejectionReason.NEUTRAL_DIRECTION) + + return reasons + + +# --------------------------------------------------------------------------- +# Action mapping +# --------------------------------------------------------------------------- + + +def _determine_action( + summary: TrendSummary, + config: EligibilityConfig, +) -> ActionType: + """Map trend direction and strength to an action type. + + Strong bullish → BUY, strong bearish → SELL. + Weak but directional → HOLD if confidence is decent, else WATCH. + Mixed → WATCH. 
+ """ + direction = summary.trend_direction + strength = summary.trend_strength + + if direction == TrendDirection.MIXED: + return ActionType.WATCH + + if direction == TrendDirection.NEUTRAL: + return ActionType.WATCH + + strong_signal = strength >= config.action_strength_threshold + + if direction == TrendDirection.BULLISH: + if strong_signal: + return ActionType.BUY + return ActionType.HOLD if summary.confidence >= config.hold_confidence_threshold else ActionType.WATCH + + if direction == TrendDirection.BEARISH: + if strong_signal: + return ActionType.SELL + return ActionType.HOLD if summary.confidence >= config.hold_confidence_threshold else ActionType.WATCH + + return ActionType.WATCH + + +# --------------------------------------------------------------------------- +# Mode escalation +# --------------------------------------------------------------------------- + + +def _determine_mode( + summary: TrendSummary, + action: ActionType, + config: EligibilityConfig, +) -> RecommendationMode: + """Determine the highest execution mode allowed. + + WATCH and HOLD actions are always informational — they don't trigger trades. + BUY/SELL can escalate to paper_eligible or live_eligible based on + confidence, contradiction, and evidence thresholds. 
+ """ + if action in (ActionType.WATCH, ActionType.HOLD): + return RecommendationMode.INFORMATIONAL + + evidence_count = len(summary.top_supporting_evidence) + len(summary.top_opposing_evidence) + + # Check live eligibility first (strictest) + if ( + summary.confidence >= config.live_confidence_threshold + and summary.contradiction_score <= config.live_max_contradiction + and evidence_count >= config.live_min_evidence + ): + return RecommendationMode.LIVE_ELIGIBLE + + # Check paper eligibility + if summary.confidence >= config.paper_confidence_threshold: + return RecommendationMode.PAPER_ELIGIBLE + + return RecommendationMode.INFORMATIONAL + + +# --------------------------------------------------------------------------- +# Position sizing (Requirement 7.3) +# --------------------------------------------------------------------------- + + +def _compute_position_sizing( + summary: TrendSummary, + config: EligibilityConfig, +) -> PositionSizing: + """Compute position sizing guidance from portfolio rules and signal quality. + + Higher confidence → larger allocation (up to max). + Higher contradiction → smaller allocation (penalty). 
+ """ + # Start from base allocation + confidence_scale = config.base_portfolio_pct + ( + config.confidence_sizing_weight + * summary.confidence + * (config.max_portfolio_pct - config.base_portfolio_pct) + ) + + # Apply contradiction penalty + contradiction_penalty = config.contradiction_sizing_penalty * summary.contradiction_score + portfolio_pct = confidence_scale * (1.0 - contradiction_penalty) + + # Clamp to bounds + portfolio_pct = max(config.base_portfolio_pct * 0.5, min(portfolio_pct, config.max_portfolio_pct)) + + # Max loss scales similarly + loss_scale = config.base_max_loss_pct + ( + config.confidence_sizing_weight + * summary.confidence + * (config.max_max_loss_pct - config.base_max_loss_pct) + ) + max_loss_pct = loss_scale * (1.0 - contradiction_penalty) + max_loss_pct = max(config.base_max_loss_pct * 0.5, min(max_loss_pct, config.max_max_loss_pct)) + + return PositionSizing( + portfolio_pct=round(portfolio_pct, 6), + max_loss_pct=round(max_loss_pct, 6), + ) + + +# --------------------------------------------------------------------------- +# Time horizon mapping +# --------------------------------------------------------------------------- + +_WINDOW_TO_HORIZON: dict[str, str] = { + "intraday": "intraday", + "1d": "swing_1d_3d", + "7d": "swing_1d_10d", + "30d": "position_10d_30d", + "90d": "position_30d_90d", +} + + +def _map_time_horizon(window: str) -> str: + """Map a trend window to a human-readable time horizon label.""" + return _WINDOW_TO_HORIZON.get(window, f"window_{window}") + + +# --------------------------------------------------------------------------- +# Invalidation conditions +# --------------------------------------------------------------------------- + + +def _derive_invalidation_conditions( + summary: TrendSummary, + action: ActionType, +) -> list[str]: + """Generate deterministic invalidation conditions for the recommendation. + + These describe when the recommendation should be considered stale or wrong. 
+ """ + conditions: list[str] = [] + + if action == ActionType.BUY: + conditions.append( + f"Trend direction for {summary.entity_id} reverses to bearish" + ) + elif action == ActionType.SELL: + conditions.append( + f"Trend direction for {summary.entity_id} reverses to bullish" + ) + + if summary.contradiction_score > 0.0: + conditions.append( + f"Contradiction score exceeds 0.60 (currently {summary.contradiction_score:.2f})" + ) + + if summary.confidence > 0.0: + conditions.append( + f"Confidence drops below {summary.confidence * 0.7:.2f}" + ) + + if summary.material_risks: + conditions.append( + f"Material risk materialises: {summary.material_risks[0]}" + ) + + return conditions + + +# --------------------------------------------------------------------------- +# Main entry point +# --------------------------------------------------------------------------- + + +def evaluate_eligibility( + summary: TrendSummary, + config: EligibilityConfig = DEFAULT_ELIGIBILITY_CONFIG, +) -> EligibilityResult: + """Evaluate a trend summary for recommendation eligibility. + + This is the single deterministic entry point. It: + 1. Applies gate checks (confidence, strength, contradiction, evidence) + 2. Maps trend direction + strength to an action type + 3. Determines the highest allowed execution mode + 4. Computes position sizing from portfolio rules + 5. Derives invalidation conditions + + Returns an EligibilityResult with the full decision trace. 
+ """ + rejection_reasons = _check_gates(summary, config) + + # Even if rejected, we still compute action/mode for the trace + action = _determine_action(summary, config) + mode = _determine_mode(summary, action, config) + sizing = _compute_position_sizing(summary, config) + horizon = _map_time_horizon(summary.window.value) + invalidation = _derive_invalidation_conditions(summary, action) + + eligible = len(rejection_reasons) == 0 + + # If not eligible, force mode to informational (Requirement 7.4) + if not eligible: + mode = RecommendationMode.INFORMATIONAL + + return EligibilityResult( + eligible=eligible, + action=action, + mode=mode, + position_sizing=sizing, + rejection_reasons=rejection_reasons, + time_horizon=horizon, + invalidation_conditions=invalidation, + ) diff --git a/services/recommendation/main.py b/services/recommendation/main.py new file mode 100644 index 0000000..84f8f9f --- /dev/null +++ b/services/recommendation/main.py @@ -0,0 +1,71 @@ +"""Recommendation worker entrypoint - polls Redis for recommendation jobs.""" +from __future__ import annotations + +import asyncio +import json +import logging + +import asyncpg +from minio import Minio + +from services.recommendation.worker import generate_recommendation +from services.shared.config import load_config +from services.shared.logging import setup_logging +from services.shared.redis_keys import QUEUE_RECOMMENDATION, queue_key + +logger = logging.getLogger("recommendation_main") + + +async def main() -> None: + config = load_config() + setup_logging("recommendation", level=config.log_level, json_output=config.json_logs) + + pool = await asyncpg.create_pool(dsn=config.postgres.dsn, min_size=2, max_size=8) + minio_client = Minio( + config.minio.endpoint, + access_key=config.minio.access_key, + secret_key=config.minio.secret_key, + secure=config.minio.secure, + ) + + import redis.asyncio as aioredis + + redis_client = aioredis.from_url(config.redis.url) + queue = queue_key(QUEUE_RECOMMENDATION) + 
logger.info("Recommendation worker started, polling %s", queue) + + try: + while True: + raw = await redis_client.lpop(queue) + if raw is None: + await asyncio.sleep(1) + continue + + payload = raw + job = json.loads(payload) + ticker = job.get("ticker", "") + window = job.get("window", "7d") + + logger.info("Processing recommendation job for %s/%s", ticker, window) + + try: + rec = await generate_recommendation( + pool, ticker, window, + minio_client=minio_client, + ) + if rec: + logger.info( + "Recommendation generated for %s: %s %s", + ticker, rec.action.value, rec.mode.value, + ) + else: + logger.info("No recommendation generated for %s (no trend data)", ticker) + except Exception: + logger.exception("Recommendation failed for %s", ticker) + finally: + await pool.close() + await redis_client.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/services/recommendation/suppression.py b/services/recommendation/suppression.py new file mode 100644 index 0000000..d508e87 --- /dev/null +++ b/services/recommendation/suppression.py @@ -0,0 +1,241 @@ +"""Suppression logic for low-quality data or low confidence. + +Evaluates the quality of the underlying data feeding a trend summary +and suppresses automated trade eligibility when data quality is poor. +Suppressed recommendations are marked as informational only. + +This layer runs *before* the eligibility engine and acts as a pre-filter +on data quality. The eligibility engine handles signal-level thresholds +(confidence, strength, contradiction); this module handles data-level +quality concerns (stale evidence, low extraction quality, poor source +diversity, insufficient valid documents). 
+ +Requirements: 7.4 +""" +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from datetime import datetime, timezone +from enum import Enum + +from services.shared.schemas import TrendSummary + +logger = logging.getLogger(__name__) + + +class SuppressionReason(str, Enum): + """Why a recommendation was suppressed due to data quality.""" + + LOW_DATA_CONFIDENCE = "low_data_confidence" + STALE_EVIDENCE = "stale_evidence" + LOW_SOURCE_DIVERSITY = "low_source_diversity" + HIGH_EXTRACTION_FAILURE_RATE = "high_extraction_failure_rate" + INSUFFICIENT_VALID_DOCUMENTS = "insufficient_valid_documents" + + +@dataclass(frozen=True) +class SuppressionConfig: + """Tunable thresholds for data quality suppression. + + These thresholds focus on the quality of the *input data* rather + than the trend signal itself (which is handled by EligibilityConfig). + """ + + # Minimum average extraction confidence across evidence documents. + # Below this, the underlying data is too unreliable for trade decisions. + min_avg_extraction_confidence: float = 0.40 + + # Maximum age (hours) of the most recent evidence document. + # If the freshest evidence is older than this, the trend is stale. + max_evidence_staleness_hours: float = 168.0 # 7 days + + # Minimum number of distinct source types (e.g. news, filings, market) + # represented in the evidence. Low diversity means the signal may be + # driven by a single unreliable source class. + min_source_types: int = 1 + + # Maximum tolerable extraction failure rate (0-1). + # If more than this fraction of documents failed extraction, + # the data pipeline is unreliable for this ticker. + max_extraction_failure_rate: float = 0.50 + + # Minimum number of valid (non-failed) documents that contributed + # to the trend. Below this, there isn't enough data to act on. + min_valid_documents: int = 2 + + # Overall data quality confidence threshold. 
+ # The computed data quality score must exceed this for the + # recommendation to be eligible for automated trading. + min_data_quality_score: float = 0.30 + + +DEFAULT_SUPPRESSION_CONFIG = SuppressionConfig() + + +@dataclass +class DataQualityContext: + """Quality metrics about the data underlying a trend summary. + + Populated by querying document and extraction metadata for the + ticker and window. When not available from the database, callers + can construct this from the trend summary itself. + """ + + total_documents: int = 0 + valid_documents: int = 0 + failed_documents: int = 0 + avg_extraction_confidence: float = 0.0 + newest_evidence_at: datetime | None = None + source_types: set[str] = field(default_factory=set) + + +@dataclass +class SuppressionResult: + """Output of the suppression evaluation.""" + + suppressed: bool + reasons: list[SuppressionReason] = field(default_factory=list) + data_quality_score: float = 0.0 + context: DataQualityContext | None = None + + +def build_quality_context_from_summary( + summary: TrendSummary, +) -> DataQualityContext: + """Build a minimal DataQualityContext from a TrendSummary. + + This is a fallback when full document-level quality metrics aren't + available. It uses the trend summary's evidence counts and confidence + as proxies. + """ + total = len(summary.top_supporting_evidence) + len(summary.top_opposing_evidence) + return DataQualityContext( + total_documents=total, + valid_documents=total, + failed_documents=0, + avg_extraction_confidence=summary.confidence, + newest_evidence_at=summary.generated_at, + source_types=set(), + ) + + +def _compute_data_quality_score( + ctx: DataQualityContext, + config: SuppressionConfig, + reference_time: datetime, +) -> float: + """Compute an overall data quality score from the context. + + Returns a value in [0, 1] where higher is better quality. 
+ Components: + - Extraction confidence (40% weight) + - Evidence freshness (30% weight) + - Document coverage (30% weight) + """ + # Extraction confidence component + conf_component = min(ctx.avg_extraction_confidence / 0.8, 1.0) + + # Freshness component + if ctx.newest_evidence_at is not None: + if ctx.newest_evidence_at.tzinfo is None: + newest = ctx.newest_evidence_at.replace(tzinfo=timezone.utc) + else: + newest = ctx.newest_evidence_at + age_hours = (reference_time - newest).total_seconds() / 3600.0 + max_hours = config.max_evidence_staleness_hours + freshness_component = max(0.0, 1.0 - (age_hours / max_hours)) + else: + freshness_component = 0.0 + + # Document coverage component + if ctx.total_documents > 0: + valid_ratio = ctx.valid_documents / ctx.total_documents + count_factor = min(ctx.valid_documents / 10.0, 1.0) + coverage_component = valid_ratio * count_factor + else: + coverage_component = 0.0 + + score = (0.4 * conf_component) + (0.3 * freshness_component) + (0.3 * coverage_component) + return round(max(0.0, min(1.0, score)), 4) + + +def evaluate_suppression( + summary: TrendSummary, + quality_ctx: DataQualityContext | None = None, + config: SuppressionConfig = DEFAULT_SUPPRESSION_CONFIG, + reference_time: datetime | None = None, +) -> SuppressionResult: + """Evaluate whether a recommendation should be suppressed due to data quality. + + Checks multiple data quality dimensions and returns a SuppressionResult + indicating whether the recommendation should be suppressed and why. + + Args: + summary: The trend summary to evaluate. + quality_ctx: Data quality context. If None, a minimal context is + built from the trend summary itself. + config: Suppression thresholds. + reference_time: Reference time for staleness checks. + + Returns: + SuppressionResult with suppression decision and reasons. 
+ """ + if reference_time is None: + reference_time = datetime.now(timezone.utc) + + ctx = quality_ctx or build_quality_context_from_summary(summary) + reasons: list[SuppressionReason] = [] + + # Check average extraction confidence + if ctx.avg_extraction_confidence < config.min_avg_extraction_confidence: + reasons.append(SuppressionReason.LOW_DATA_CONFIDENCE) + + # Check evidence staleness + if ctx.newest_evidence_at is not None: + newest = ctx.newest_evidence_at + if newest.tzinfo is None: + newest = newest.replace(tzinfo=timezone.utc) + age_hours = (reference_time - newest).total_seconds() / 3600.0 + if age_hours > config.max_evidence_staleness_hours: + reasons.append(SuppressionReason.STALE_EVIDENCE) + elif ctx.total_documents > 0: + # Have documents but no timestamp — treat as stale + reasons.append(SuppressionReason.STALE_EVIDENCE) + + # Check source diversity + if len(ctx.source_types) < config.min_source_types and ctx.total_documents > 0: + reasons.append(SuppressionReason.LOW_SOURCE_DIVERSITY) + + # Check extraction failure rate + if ctx.total_documents > 0: + failure_rate = ctx.failed_documents / ctx.total_documents + if failure_rate > config.max_extraction_failure_rate: + reasons.append(SuppressionReason.HIGH_EXTRACTION_FAILURE_RATE) + + # Check minimum valid documents + if ctx.valid_documents < config.min_valid_documents: + reasons.append(SuppressionReason.INSUFFICIENT_VALID_DOCUMENTS) + + # Compute overall data quality score + quality_score = _compute_data_quality_score(ctx, config, reference_time) + + # If quality score is below threshold, add a general suppression reason + if quality_score < config.min_data_quality_score and SuppressionReason.LOW_DATA_CONFIDENCE not in reasons: + reasons.append(SuppressionReason.LOW_DATA_CONFIDENCE) + + suppressed = len(reasons) > 0 + + if suppressed: + logger.info( + "Recommendation suppressed for %s/%s: reasons=%s quality_score=%.3f", + summary.entity_id, summary.window.value, + [r.value for r in reasons], 
quality_score, + ) + + return SuppressionResult( + suppressed=suppressed, + reasons=reasons, + data_quality_score=quality_score, + context=ctx, + ) diff --git a/services/recommendation/thesis_llm.py b/services/recommendation/thesis_llm.py new file mode 100644 index 0000000..4716aa8 --- /dev/null +++ b/services/recommendation/thesis_llm.py @@ -0,0 +1,175 @@ +"""Optional LLM wording layer for thesis generation. + +Takes a deterministic thesis string (built from trend data and eligibility +rules) and rewrites it into natural, analyst-quality prose using a local +Ollama model. The deterministic thesis is always preserved as the fallback +and audit reference. + +This module is opt-in: callers must explicitly request LLM rewriting. +If the LLM call fails or is disabled, the original deterministic thesis +is returned unchanged. + +Requirements: 7.1, 7.2 +""" +from __future__ import annotations + +import logging +import time + +import httpx + +from services.shared.config import OllamaConfig +from services.shared.schemas import TrendSummary + +logger = logging.getLogger(__name__) + +THESIS_PROMPT_VERSION = "thesis-rewrite-v1" + +THESIS_SYSTEM_PROMPT = """\ +You are a concise financial analyst. You rewrite structured trade thesis \ +summaries into clear, professional prose suitable for an internal research note. + +STRICT RULES: +1. Do NOT add any information that is not present in the input. +2. Do NOT fabricate numbers, dates, company names, or analyst opinions. +3. Keep the rewrite under 150 words. +4. Preserve all factual claims, risk notes, and evidence counts from the input. +5. Use a neutral, professional tone. Avoid hype or marketing language. +6. Return ONLY the rewritten thesis text. No JSON, no markdown, no commentary.""" + + +def build_thesis_rewrite_prompt( + deterministic_thesis: str, + summary: TrendSummary, +) -> dict[str, str]: + """Build system and user prompts for thesis rewriting. 
+ + Provides the model with the deterministic thesis and key trend + context so it can produce a natural-language version. + """ + context_parts = [ + f"Ticker: {summary.entity_id}", + f"Window: {summary.window.value}", + f"Direction: {summary.trend_direction.value}", + f"Strength: {summary.trend_strength:.2f}", + f"Confidence: {summary.confidence:.2f}", + f"Contradiction score: {summary.contradiction_score:.2f}", + ] + if summary.dominant_catalysts: + context_parts.append(f"Catalysts: {', '.join(summary.dominant_catalysts[:3])}") + if summary.material_risks: + context_parts.append(f"Risks: {'; '.join(summary.material_risks[:2])}") + + context_block = "\n".join(context_parts) + + user_prompt = f"""\ +Rewrite the following structured thesis into clear, professional analyst prose. + +--- STRUCTURED THESIS --- +{deterministic_thesis} +--- END STRUCTURED THESIS --- + +--- CONTEXT --- +{context_block} +--- END CONTEXT --- + +Return ONLY the rewritten thesis. No other text.""" + + return { + "system": THESIS_SYSTEM_PROMPT, + "user": user_prompt, + } + + +async def rewrite_thesis_with_llm( + deterministic_thesis: str, + summary: TrendSummary, + config: OllamaConfig, + http_client: httpx.AsyncClient | None = None, +) -> str: + """Rewrite a deterministic thesis using a local Ollama model. + + If the LLM call fails for any reason, returns the original + deterministic thesis unchanged. This ensures the LLM layer is + purely additive and never blocks recommendation generation. + + Args: + deterministic_thesis: The rule-based thesis string. + summary: The trend summary that produced the thesis. + config: Ollama connection and model configuration. + http_client: Optional shared HTTP client for connection reuse. + + Returns: + The LLM-rewritten thesis on success, or the original on failure. 
async def _call_ollama_thesis(
    client: httpx.AsyncClient,
    config: OllamaConfig,
    prompts: dict[str, str],
) -> str:
    """Make a single non-streaming Ollama chat call for thesis rewriting.

    Args:
        client: HTTP client used for the request; not closed here.
        config: Supplies ``base_url`` and ``model``.
        prompts: Mapping with ``"system"`` and ``"user"`` prompt strings.

    Returns:
        The model's text response with surrounding whitespace removed.
        An empty string is returned when the response body carries no
        usable ``message.content`` string (missing, null, or non-text).

    Raises:
        httpx.HTTPStatusError: If Ollama answers with a 4xx/5xx status.
        Transport errors from the client and JSON decoding errors also
        propagate; the caller is expected to handle them.
    """
    start = time.monotonic()

    payload = {
        "model": config.model,
        "messages": [
            {"role": "system", "content": prompts["system"]},
            {"role": "user", "content": prompts["user"]},
        ],
        "stream": False,
    }

    # Strip a trailing slash so "http://host:11434/" does not produce
    # "http://host:11434//api/chat".
    base_url = str(config.base_url).rstrip("/")
    resp = await client.post(f"{base_url}/api/chat", json=payload)
    resp.raise_for_status()

    duration_ms = int((time.monotonic() - start) * 1000)

    body = resp.json()
    msg = body.get("message") if isinstance(body, dict) else None
    raw_content = msg.get("content") if isinstance(msg, dict) else None
    # A null or non-string content would otherwise blow up on len()/strip()
    # and be reported upstream as a generic failure.
    content = raw_content if isinstance(raw_content, str) else ""

    logger.debug(
        "Ollama thesis call completed in %dms, response length=%d",
        duration_ms,
        len(content),
    )

    return content.strip()
# The evidence id lists are JSONB arrays (the COALESCE fallbacks below are
# '[]'::jsonb), so they must be expanded with jsonb_array_elements_text();
# UNNEST() only accepts SQL arrays. The latest trend row is selected first
# (ORDER BY ... LIMIT 1) and only then expanded, so LIMIT applies to trend
# rows rather than to individual evidence ids.
# NOTE(review): assumes the array elements are UUID strings — confirm
# against the trend_windows writer.
_DATA_QUALITY_QUERY = """
SELECT
    COUNT(*) AS total_documents,
    COUNT(*) FILTER (WHERE di.validation_status = 'valid') AS valid_documents,
    COUNT(*) FILTER (WHERE di.validation_status = 'failed') AS failed_documents,
    AVG(di.confidence) FILTER (WHERE di.validation_status = 'valid') AS avg_extraction_confidence,
    MAX(d.published_at) AS newest_evidence_at,
    ARRAY_AGG(DISTINCT s.source_class) FILTER (WHERE s.source_class IS NOT NULL) AS source_types
FROM documents d
JOIN document_intelligence di ON di.document_id = d.id
LEFT JOIN sources s ON d.source_id = s.id
WHERE d.id IN (
    SELECT evidence.doc_id::uuid
    FROM (
        SELECT
            COALESCE(tw.top_supporting_evidence, '[]'::jsonb)
            || COALESCE(tw.top_opposing_evidence, '[]'::jsonb) AS evidence_ids
        FROM trend_windows tw
        WHERE tw.entity_id = $1 AND tw.window = $2
        ORDER BY tw.generated_at DESC
        LIMIT 1
    ) AS latest
    CROSS JOIN LATERAL jsonb_array_elements_text(latest.evidence_ids) AS evidence(doc_id)
)
"""


async def fetch_data_quality_context(
    pool: asyncpg.Pool,
    ticker: str,
    window: str,
) -> DataQualityContext | None:
    """Fetch data quality metrics for the documents underlying a trend.

    Runs ``_DATA_QUALITY_QUERY`` for the most recent trend window of
    ``ticker``/``window`` and aggregates document counts, average
    extraction confidence, newest publication time and source classes.

    Args:
        pool: Connection pool used to run the query.
        ticker: Entity id of the trend (bound to ``$1``).
        window: Trend window value (bound to ``$2``).

    Returns:
        A populated ``DataQualityContext``, or ``None`` when no documents
        match or the query fails. ``None`` tells the suppression module to
        fall back to summary-based estimation, so failures are logged as
        warnings rather than raised.
    """
    try:
        row = await pool.fetchrow(_DATA_QUALITY_QUERY, ticker, window)
        if row is None or row["total_documents"] == 0:
            return None

        source_types_raw = row["source_types"]
        source_types = set(source_types_raw) if source_types_raw else set()

        return DataQualityContext(
            total_documents=int(row["total_documents"]),
            valid_documents=int(row["valid_documents"] or 0),
            failed_documents=int(row["failed_documents"] or 0),
            avg_extraction_confidence=float(row["avg_extraction_confidence"] or 0.0),
            newest_evidence_at=row["newest_evidence_at"],
            source_types=source_types,
        )
    except Exception:
        # Deliberate best-effort: quality metrics are an enrichment and
        # must never block recommendation generation.
        logger.warning(
            "Failed to fetch data quality context for %s/%s — will use summary fallback",
            ticker, window, exc_info=True,
        )
        return None
def build_thesis(
    summary: TrendSummary,
    result: EligibilityResult,
) -> str:
    """Render the rule-based thesis sentence-by-sentence from trend data.

    The text is fully deterministic (no LLM involved): descriptive
    sentences about the trend come first, and the prescriptive
    recommendation is always the final sentence so the two stay clearly
    separated (Requirement 7.2).

    Args:
        summary: Trend summary supplying direction, strength, confidence,
            catalysts, risks, contradiction score and evidence lists.
        result: Eligibility outcome supplying the action and mode.

    Returns:
        The sentences joined by single spaces.
    """
    # Descriptive opening: direction, window, strength and confidence.
    sentences: list[str] = [
        f"{summary.entity_id} shows a {summary.trend_direction.value} trend over the {summary.window.value} window "
        f"with strength {summary.trend_strength:.2f} and confidence {summary.confidence:.2f}."
    ]

    # At most the first three catalysts are mentioned.
    top_catalysts = summary.dominant_catalysts[:3] if summary.dominant_catalysts else []
    if top_catalysts:
        sentences.append(f"Dominant catalysts: {', '.join(top_catalysts)}.")

    # Signal disagreement is only called out above the 0.15 threshold.
    if summary.contradiction_score > 0.15:
        sentences.append(
            "Notable signal disagreement detected "
            f"(contradiction score: {summary.contradiction_score:.2f})."
        )

    # At most the first two risks are mentioned.
    top_risks = summary.material_risks[:2] if summary.material_risks else []
    if top_risks:
        sentences.append(f"Key risks: {'; '.join(top_risks)}.")

    # Evidence volume on each side of the trend.
    n_supporting = len(summary.top_supporting_evidence)
    n_opposing = len(summary.top_opposing_evidence)
    sentences.append(
        f"Based on {n_supporting} supporting and "
        f"{n_opposing} opposing evidence documents."
    )

    # Prescriptive part, kept last and separate from the description.
    action_label = result.action.value.upper()
    mode_label = result.mode.value.replace("_", " ")
    sentences.append(f"Recommendation: {action_label} ({mode_label}).")

    return " ".join(sentences)
+ """ + score = 0.0 + + # Contradiction raises risk + score += summary.contradiction_score * 2.0 + + # Low confidence raises risk + score += (1.0 - summary.confidence) * 1.5 + + # Low evidence count raises risk + evidence_count = len(summary.top_supporting_evidence) + len(summary.top_opposing_evidence) + if evidence_count < 3: + score += 1.0 + elif evidence_count < 5: + score += 0.5 + + # Rejection reasons raise risk + score += len(result.rejection_reasons) * 0.5 + + if score >= 3.0: + return "very_high" + if score >= 2.0: + return "high" + if score >= 1.0: + return "moderate" + return "low" + + +# --------------------------------------------------------------------------- +# Build Recommendation from eligibility result +# --------------------------------------------------------------------------- + + +def build_recommendation( + summary: TrendSummary, + result: EligibilityResult, + reference_time: datetime | None = None, + llm_thesis: str | None = None, + suppression_result: SuppressionResult | None = None, +) -> Recommendation: + """Assemble a Recommendation object from a trend summary and eligibility result. + + Combines all evidence refs (supporting + opposing) into the recommendation + so the full decision trace is available (Requirement 8.3). + + If ``llm_thesis`` is provided (from the optional LLM wording layer), + it replaces the deterministic thesis text while preserving the risk + classification prefix. + + If ``suppression_result`` indicates suppression, a suppression note + is appended to the thesis for audit visibility (Requirement 7.4). 
+ """ + if reference_time is None: + reference_time = datetime.now(timezone.utc) + + # Combine evidence refs — supporting first, then opposing + evidence_refs = list(summary.top_supporting_evidence) + list(summary.top_opposing_evidence) + + deterministic_thesis = build_thesis(summary, result) + risk_class = classify_risk(summary, result) + + # Use LLM-rewritten thesis if available, otherwise deterministic + thesis_body = llm_thesis if llm_thesis else deterministic_thesis + + # Append suppression note if suppressed (Requirement 7.4) + if suppression_result and suppression_result.suppressed: + reason_strs = [r.value for r in suppression_result.reasons] + thesis_body += ( + f" [SUPPRESSED: data quality below threshold " + f"(score={suppression_result.data_quality_score:.2f}, " + f"reasons={', '.join(reason_strs)})]" + ) + + # Track whether the thesis was LLM-generated for audit + if llm_thesis: + provider = "ollama" + model_name = "thesis-rewrite" + prompt_version = THESIS_PROMPT_VERSION + else: + provider = "deterministic" + model_name = "eligibility-v1" + prompt_version = "" + + return Recommendation( + ticker=summary.entity_id, + action=result.action, + mode=result.mode, + confidence=summary.confidence, + time_horizon=result.time_horizon, + thesis=f"[risk:{risk_class}] {thesis_body}", + invalidation_conditions=result.invalidation_conditions, + position_sizing=PositionSizing( + portfolio_pct=result.position_sizing.portfolio_pct, + max_loss_pct=result.position_sizing.max_loss_pct, + ), + evidence_refs=evidence_refs, + model_metadata=ModelMetadata( + provider=provider, + model_name=model_name, + prompt_version=prompt_version, + schema_version="1.0.0", + ), + generated_at=reference_time, + ) + + +# --------------------------------------------------------------------------- +# Persist recommendation to PostgreSQL +# --------------------------------------------------------------------------- + +_INSERT_RECOMMENDATION = """ +INSERT INTO recommendations ( + ticker, action, 
async def persist_recommendation(
    pool: asyncpg.Pool,
    rec: Recommendation,
    supporting_ids: list[str],
    opposing_ids: list[str],
    eligibility_result: EligibilityResult | None = None,
) -> str:
    """Insert a recommendation, evidence citations, and risk evaluation.

    Persists the full model metadata and risk classification for audit
    trail (Requirement 8.3). Also writes the eligibility decision to
    the risk_evaluations table when provided.

    All rows are written on one connection inside a single transaction,
    so a failure part-way through cannot leave a recommendation without
    its evidence citations or risk evaluation.

    Args:
        pool: Connection pool; one connection is acquired for the write.
        rec: The recommendation to store.
        supporting_ids: Supporting document ids, best-ranked first.
        opposing_ids: Opposing document ids, best-ranked first.
        eligibility_result: Optional eligibility decision to record in
            ``risk_evaluations``.

    Returns:
        The recommendation UUID as a string.
    """
    risk_class = _extract_risk_classification(rec.thesis)

    async with pool.acquire() as conn:
        async with conn.transaction():
            row = await conn.fetchrow(
                _INSERT_RECOMMENDATION,
                rec.ticker,
                rec.action.value,
                rec.mode.value,
                rec.confidence,
                rec.time_horizon,
                rec.thesis,
                json.dumps(rec.invalidation_conditions),
                rec.position_sizing.portfolio_pct,
                rec.position_sizing.max_loss_pct,
                rec.model_metadata.model_name,
                rec.model_metadata.provider,
                rec.model_metadata.prompt_version,
                rec.model_metadata.schema_version,
                risk_class,
                rec.generated_at,
            )
            rec_id = str(row["id"])

            # Position-based weighting: the weight decays with rank within
            # each evidence list (1.0, 0.9091, 0.8333, ...).
            evidence_rows: list[tuple[str, str, str, float]] = [
                (rec_id, doc_id, evidence_type, round(1.0 / (1.0 + idx * 0.1), 4))
                for evidence_type, doc_ids in (
                    ("supporting", supporting_ids),
                    ("opposing", opposing_ids),
                )
                for idx, doc_id in enumerate(doc_ids)
            ]
            if evidence_rows:
                await conn.executemany(_INSERT_REC_EVIDENCE, evidence_rows)

            # Persist the eligibility/risk evaluation for audit trail
            if eligibility_result is not None:
                rejection_reasons_json = json.dumps(
                    [r.value for r in eligibility_result.rejection_reasons]
                )
                risk_checks = {
                    "time_horizon": eligibility_result.time_horizon,
                    "position_sizing": {
                        "portfolio_pct": eligibility_result.position_sizing.portfolio_pct,
                        "max_loss_pct": eligibility_result.position_sizing.max_loss_pct,
                    },
                    "invalidation_conditions": eligibility_result.invalidation_conditions,
                    "risk_classification": risk_class,
                }
                await conn.execute(
                    _INSERT_RISK_EVALUATION,
                    rec_id,
                    eligibility_result.eligible,
                    eligibility_result.mode.value,
                    rejection_reasons_json,
                    json.dumps(risk_checks),
                    rec.generated_at,
                )

    return rec_id
+ """ + rows = await pool.fetch(_FETCH_LATEST_RECS_FOR_TICKER, ticker, limit) + results = [] + for row in rows: + rec_dict = dict(row) + if isinstance(rec_dict.get("invalidation_conditions"), str): + rec_dict["invalidation_conditions"] = json.loads(rec_dict["invalidation_conditions"]) + results.append(rec_dict) + return results + + +# --------------------------------------------------------------------------- +# Main entry point: generate recommendation for a ticker +# --------------------------------------------------------------------------- + + +async def generate_recommendation( + pool: asyncpg.Pool, + ticker: str, + window: str = TrendWindow.SEVEN_DAY.value, + config: EligibilityConfig | None = None, + reference_time: datetime | None = None, + ollama_config: OllamaConfig | None = None, + suppression_config: SuppressionConfig | None = None, + minio_client: Minio | None = None, +) -> Recommendation | None: + """Generate and persist a recommendation for a ticker from its latest trend. + + Steps: + 1. Fetch the latest trend summary for the ticker + window. + 2. Evaluate data quality suppression (Requirement 7.4). + 3. Evaluate eligibility using deterministic rules. + 4. Build a Recommendation object with thesis and evidence. + - If ``ollama_config`` is provided, the deterministic thesis is + rewritten into analyst-quality prose via the LLM wording layer. + 5. Persist the recommendation and evidence citations. + + Returns the Recommendation, or None if no trend data exists. + """ + if reference_time is None: + reference_time = datetime.now(timezone.utc) + + cfg = config or EligibilityConfig() + sup_cfg = suppression_config or SuppressionConfig() + + # 1. Fetch latest trend + summary = await fetch_latest_trend(pool, ticker, window) + if summary is None: + logger.info("No trend data for %s/%s — skipping recommendation", ticker, window) + return None + + # 2. 
Evaluate data quality suppression (Requirement 7.4) + quality_ctx = await fetch_data_quality_context(pool, ticker, window) + suppression = evaluate_suppression( + summary, quality_ctx=quality_ctx, config=sup_cfg, reference_time=reference_time, + ) + + # 3. Evaluate eligibility + result = evaluate_eligibility(summary, cfg) + + # Apply suppression: force mode to informational if suppressed + if suppression.suppressed: + result = EligibilityResult( + eligible=False, + action=result.action, + mode=RecommendationMode.INFORMATIONAL, + position_sizing=result.position_sizing, + rejection_reasons=result.rejection_reasons, + time_horizon=result.time_horizon, + invalidation_conditions=result.invalidation_conditions, + ) + + # 4. Optional LLM thesis rewrite + llm_thesis: str | None = None + if ollama_config is not None: + deterministic_thesis = build_thesis(summary, result) + llm_thesis = await rewrite_thesis_with_llm( + deterministic_thesis=deterministic_thesis, + summary=summary, + config=ollama_config, + ) + # If the LLM returned the same text as the deterministic thesis, + # treat it as a no-op (fallback was used). + if llm_thesis == deterministic_thesis: + llm_thesis = None + + # 5. Build recommendation + rec = build_recommendation( + summary, result, reference_time, llm_thesis=llm_thesis, + suppression_result=suppression, + ) + + # 6. Persist recommendation, evidence citations, and risk evaluation + rec_id = await persist_recommendation( + pool, + rec, + supporting_ids=list(summary.top_supporting_evidence), + opposing_ids=list(summary.top_opposing_evidence), + eligibility_result=result, + ) + + # 7. 
async def generate_recommendations_batch(
    pool: asyncpg.Pool,
    tickers: list[str],
    window: str = TrendWindow.SEVEN_DAY.value,
    config: EligibilityConfig | None = None,
    ollama_config: OllamaConfig | None = None,
    suppression_config: SuppressionConfig | None = None,
    minio_client: Minio | None = None,
) -> list[Recommendation]:
    """Generate recommendations for a list of tickers.

    Processes each ticker sequentially, using one shared reference time
    for the whole batch. Returns only the successfully generated
    recommendations: tickers with no trend data are skipped, and a ticker
    whose generation raises is logged and skipped so that it cannot abort
    the batch or discard recommendations already persisted for earlier
    tickers.

    If ``ollama_config`` is provided, each recommendation's thesis will
    be rewritten using the LLM wording layer.

    Args:
        pool: Connection pool passed through to each generation.
        tickers: Tickers to process, in order.
        window: Trend window value to evaluate.
        config: Optional eligibility thresholds.
        ollama_config: Optional LLM configuration for thesis rewriting.
        suppression_config: Optional data-quality suppression thresholds.
        minio_client: Optional client for publishing analytical facts.

    Returns:
        The recommendations that were generated, in ticker order.
    """
    results: list[Recommendation] = []
    failed: list[str] = []
    reference_time = datetime.now(timezone.utc)

    for ticker in tickers:
        try:
            rec = await generate_recommendation(
                pool, ticker, window, config, reference_time,
                ollama_config=ollama_config,
                suppression_config=suppression_config,
                minio_client=minio_client,
            )
        except Exception:
            # Isolate per-ticker failures; cancellation is a BaseException
            # and still propagates.
            logger.exception(
                "Recommendation generation failed for %s/%s — continuing batch",
                ticker, window,
            )
            failed.append(ticker)
            continue
        if rec is not None:
            results.append(rec)

    if failed:
        logger.warning(
            "Batch recommendation: %d/%d tickers failed: %s",
            len(failed), len(tickers), failed,
        )
    logger.info(
        "Batch recommendation: %d/%d tickers produced recommendations",
        len(results), len(tickers),
    )
    return results
PortfolioRiskConfig() + return evaluate_order(req.order, risk_config, req.state) + + +@app.get("/health") +async def health(): + return {"status": "ok"} + + +class ReviewRequest(BaseModel): + approved: bool + reviewed_by: str = "operator" + review_note: str = "" + + +@app.get("/approvals/pending") +async def list_pending(): + if not pool: + raise HTTPException(503, "Database not ready") + requests = await get_pending_approvals(pool) + return [r.to_dict() for r in requests] + + +@app.get("/approvals/{approval_id}") +async def get_approval(approval_id: str): + if not pool: + raise HTTPException(503, "Database not ready") + req = await get_approval_by_id(pool, approval_id) + if not req: + raise HTTPException(404, "Approval not found") + return req.to_dict() + + +@app.post("/approvals/{approval_id}/review") +async def review(approval_id: str, body: ReviewRequest): + if not pool: + raise HTTPException(503, "Database not ready") + status = await review_approval( + pool, approval_id, body.approved, body.reviewed_by, body.review_note, + ) + if status is None: + raise HTTPException(404, "Approval not found or no longer pending") + return {"approval_id": approval_id, "status": status.value} + + +@app.post("/approvals/expire") +async def expire(): + if not pool: + raise HTTPException(503, "Database not ready") + expired = await expire_stale_approvals(pool) + return {"expired": expired} diff --git a/services/risk/approval.py b/services/risk/approval.py new file mode 100644 index 0000000..ff2c672 --- /dev/null +++ b/services/risk/approval.py @@ -0,0 +1,300 @@ +"""Operator approval workflow for live trading mode. + +When live trading is enabled and operator approval is required, +orders are held in a pending state until an operator explicitly +approves or rejects them. Expired approvals are treated as rejections. 
+ +Requirements: 8.2 +Design: Section 4.8 - Risk Engine (operator approval rules) +""" +from __future__ import annotations + +import json +import logging +import uuid +from datetime import datetime, timedelta, timezone +from enum import Enum +from typing import Any + +import asyncpg + +from services.risk.engine import ( + OperatorApproval, + PortfolioRiskConfig, + TradingMode, +) + +logger = logging.getLogger("operator_approval") + + +# --------------------------------------------------------------------------- +# Enums +# --------------------------------------------------------------------------- + + +class ApprovalStatus(str, Enum): + PENDING = "pending" + APPROVED = "approved" + REJECTED = "rejected" + EXPIRED = "expired" + + +# --------------------------------------------------------------------------- +# Core logic: does this order need approval? +# --------------------------------------------------------------------------- + + +def requires_approval( + config: PortfolioRiskConfig, + trading_mode: TradingMode | None = None, +) -> bool: + """Determine whether an order requires operator approval. + + Paper orders are auto-approved when auto_approve_paper is True. + Live orders require approval when require_approval_for_live is True. + Disabled mode always returns False (orders are blocked upstream). 
+ """ + mode = trading_mode or config.trading_mode + + if mode == TradingMode.DISABLED: + return False + + if mode == TradingMode.PAPER: + return not config.operator_approval.auto_approve_paper + + # Live mode + return config.operator_approval.require_approval_for_live + + +def compute_expiry( + config: PortfolioRiskConfig, + now: datetime | None = None, +) -> datetime: + """Compute the expiry timestamp for a new approval request.""" + now = now or datetime.now(timezone.utc) + return now + timedelta(minutes=config.operator_approval.approval_timeout_minutes) + + +# --------------------------------------------------------------------------- +# Approval request model (in-memory representation) +# --------------------------------------------------------------------------- + + +class ApprovalRequest: + """Represents a pending operator approval request.""" + + def __init__( + self, + approval_id: str | None = None, + order_job: dict[str, Any] | None = None, + recommendation_id: str | None = None, + ticker: str = "", + side: str = "buy", + quantity: float = 0.0, + estimated_value: float = 0.0, + risk_evaluation_id: str | None = None, + status: ApprovalStatus = ApprovalStatus.PENDING, + requested_by: str = "system", + reviewed_by: str | None = None, + review_note: str | None = None, + expires_at: datetime | None = None, + requested_at: datetime | None = None, + reviewed_at: datetime | None = None, + ) -> None: + self.approval_id = approval_id or str(uuid.uuid4()) + self.order_job = order_job or {} + self.recommendation_id = recommendation_id + self.ticker = ticker + self.side = side + self.quantity = quantity + self.estimated_value = estimated_value + self.risk_evaluation_id = risk_evaluation_id + self.status = status + self.requested_by = requested_by + self.reviewed_by = reviewed_by + self.review_note = review_note + self.expires_at = expires_at or (datetime.now(timezone.utc) + timedelta(minutes=30)) + self.requested_at = requested_at or datetime.now(timezone.utc) + 
self.reviewed_at = reviewed_at + + @property + def is_pending(self) -> bool: + return self.status == ApprovalStatus.PENDING + + @property + def is_expired(self) -> bool: + if self.status == ApprovalStatus.EXPIRED: + return True + if self.status == ApprovalStatus.PENDING: + return datetime.now(timezone.utc) >= self.expires_at + return False + + def to_dict(self) -> dict[str, Any]: + return { + "approval_id": self.approval_id, + "recommendation_id": self.recommendation_id, + "ticker": self.ticker, + "side": self.side, + "quantity": self.quantity, + "estimated_value": self.estimated_value, + "risk_evaluation_id": self.risk_evaluation_id, + "status": self.status.value, + "requested_by": self.requested_by, + "reviewed_by": self.reviewed_by, + "review_note": self.review_note, + "expires_at": self.expires_at.isoformat() if self.expires_at else None, + "requested_at": self.requested_at.isoformat() if self.requested_at else None, + "reviewed_at": self.reviewed_at.isoformat() if self.reviewed_at else None, + } + + +# --------------------------------------------------------------------------- +# DB persistence +# --------------------------------------------------------------------------- + +_INSERT_APPROVAL = """ +INSERT INTO operator_approvals ( + id, order_job, recommendation_id, ticker, side, quantity, + estimated_value, status, risk_evaluation_id, requested_by, + expires_at, requested_at +) VALUES ( + $1::uuid, $2::jsonb, $3, $4, $5, $6, + $7, $8, $9, $10, + $11, $12 +) +""" + +_UPDATE_APPROVAL_STATUS = """ +UPDATE operator_approvals +SET status = $2, reviewed_by = $3, review_note = $4, reviewed_at = $5, updated_at = NOW() +WHERE id = $1::uuid AND status = 'pending' +RETURNING id, status +""" + +_EXPIRE_STALE_APPROVALS = """ +UPDATE operator_approvals +SET status = 'expired', updated_at = NOW() +WHERE status = 'pending' AND expires_at <= $1 +RETURNING id, ticker +""" + +_FETCH_PENDING_APPROVALS = """ +SELECT id, order_job, recommendation_id, ticker, side, quantity, + 
estimated_value, status, risk_evaluation_id, requested_by, + reviewed_by, review_note, expires_at, requested_at, reviewed_at +FROM operator_approvals +WHERE status = 'pending' +ORDER BY requested_at ASC +""" + +_FETCH_APPROVAL_BY_ID = """ +SELECT id, order_job, recommendation_id, ticker, side, quantity, + estimated_value, status, risk_evaluation_id, requested_by, + reviewed_by, review_note, expires_at, requested_at, reviewed_at +FROM operator_approvals +WHERE id = $1::uuid +""" + + +def _row_to_request(row: Any) -> ApprovalRequest: + """Convert a DB row to an ApprovalRequest.""" + order_job = row["order_job"] + if isinstance(order_job, str): + order_job = json.loads(order_job) + return ApprovalRequest( + approval_id=str(row["id"]), + order_job=order_job, + recommendation_id=str(row["recommendation_id"]) if row["recommendation_id"] else None, + ticker=row["ticker"], + side=row["side"], + quantity=float(row["quantity"]), + estimated_value=float(row["estimated_value"]), + risk_evaluation_id=str(row["risk_evaluation_id"]) if row.get("risk_evaluation_id") else None, + status=ApprovalStatus(row["status"]), + requested_by=row["requested_by"], + reviewed_by=row["reviewed_by"], + review_note=row["review_note"], + expires_at=row["expires_at"], + requested_at=row["requested_at"], + reviewed_at=row["reviewed_at"], + ) + + +async def create_approval_request( + pool: asyncpg.Pool, + request: ApprovalRequest, +) -> str: + """Persist a new approval request. 
Returns the approval ID.""" + await pool.execute( + _INSERT_APPROVAL, + request.approval_id, + json.dumps(request.order_job, default=str), + request.recommendation_id, + request.ticker, + request.side, + request.quantity, + request.estimated_value, + request.status.value, + request.risk_evaluation_id, + request.requested_by, + request.expires_at, + request.requested_at, + ) + return request.approval_id + + +async def review_approval( + pool: asyncpg.Pool, + approval_id: str, + approved: bool, + reviewed_by: str = "operator", + review_note: str = "", +) -> ApprovalStatus | None: + """Approve or reject a pending approval request. + + Returns the new status, or None if the approval was not found + or was no longer pending (already expired/reviewed). + """ + now = datetime.now(timezone.utc) + new_status = ApprovalStatus.APPROVED if approved else ApprovalStatus.REJECTED + + row = await pool.fetchrow( + _UPDATE_APPROVAL_STATUS, + approval_id, + new_status.value, + reviewed_by, + review_note, + now, + ) + if row: + return ApprovalStatus(row["status"]) + return None + + +async def expire_stale_approvals( + pool: asyncpg.Pool, + now: datetime | None = None, +) -> list[dict[str, str]]: + """Mark all expired pending approvals. 
Returns list of expired items.""" + now = now or datetime.now(timezone.utc) + rows = await pool.fetch(_EXPIRE_STALE_APPROVALS, now) + return [{"id": str(r["id"]), "ticker": r["ticker"]} for r in rows] + + +async def get_pending_approvals( + pool: asyncpg.Pool, +) -> list[ApprovalRequest]: + """Fetch all pending approval requests, oldest first.""" + rows = await pool.fetch(_FETCH_PENDING_APPROVALS) + return [_row_to_request(r) for r in rows] + + +async def get_approval_by_id( + pool: asyncpg.Pool, + approval_id: str, +) -> ApprovalRequest | None: + """Fetch a single approval request by ID.""" + row = await pool.fetchrow(_FETCH_APPROVAL_BY_ID, approval_id) + if row: + return _row_to_request(row) + return None diff --git a/services/risk/engine.py b/services/risk/engine.py index a4c9c76..a7d7837 100644 --- a/services/risk/engine.py +++ b/services/risk/engine.py @@ -1 +1,616 @@ -"""Risk engine - enforces guardrails, position limits, and trade eligibility checks.""" +"""Risk engine - portfolio and account risk configuration and enforcement. + +Defines the configuration and state models used to enforce guardrails +on trade execution: max position size, sector exposure, daily loss limits, +news-shock lockouts, and operator approval rules. + +Also implements the hard-block evaluation logic that decides whether a +proposed order is allowed before it reaches the broker adapter. 
+ +Requirements: 8.1, 8.2, 8.3, 8.4, 8.5 +Design: Section 4.8 - Risk Engine +""" +from __future__ import annotations + +import uuid +from datetime import datetime, timedelta, timezone +from enum import Enum +from typing import Any + +from pydantic import BaseModel, Field + + +# --------------------------------------------------------------------------- +# Enums +# --------------------------------------------------------------------------- + + +class TradingMode(str, Enum): + """Execution environment separation (Requirement 8.1).""" + PAPER = "paper" + LIVE = "live" + DISABLED = "disabled" + + +class RiskCheckResult(str, Enum): + """Outcome of a single risk check.""" + PASS = "pass" + FAIL = "fail" + WARN = "warn" + + +# --------------------------------------------------------------------------- +# Portfolio-level risk configuration (Requirement 8.2, 8.4) +# --------------------------------------------------------------------------- + + +class PositionLimits(BaseModel): + """Per-position size constraints.""" + max_position_pct: float = Field( + default=0.05, ge=0, le=1, + description="Maximum portfolio percentage for a single position", + ) + max_position_value: float = Field( + default=10_000.0, ge=0, + description="Maximum dollar value for a single position", + ) + max_shares_per_order: float = Field( + default=1000.0, ge=0, + description="Maximum shares in a single order", + ) + + +class SectorExposureLimits(BaseModel): + """Sector-level concentration limits.""" + max_sector_pct: float = Field( + default=0.25, ge=0, le=1, + description="Maximum portfolio percentage exposed to one sector", + ) + max_sectors: int = Field( + default=10, ge=1, + description="Maximum number of sectors with open positions", + ) + + +class DailyLossLimits(BaseModel): + """Daily drawdown controls.""" + max_daily_loss_pct: float = Field( + default=0.02, ge=0, le=1, + description="Maximum portfolio loss percentage in a single day before halting", + ) + max_daily_loss_value: float = Field( 
+ default=1_000.0, ge=0, + description="Maximum dollar loss in a single day before halting", + ) + max_daily_trades: int = Field( + default=20, ge=0, + description="Maximum number of trades per day", + ) + + +class NewsShockLockout(BaseModel): + """News-shock lockout configuration. + + When a symbol has a high-impact news event, trading is paused + for a configurable cooldown period. + """ + enabled: bool = True + lockout_minutes: int = Field( + default=60, ge=0, + description="Minutes to lock out trading after a high-impact news event", + ) + impact_threshold: float = Field( + default=0.80, ge=0, le=1, + description="Minimum impact_score from document intelligence to trigger lockout", + ) + catalyst_types: list[str] = Field( + default_factory=lambda: ["earnings", "legal", "m_and_a"], + description="Catalyst types that trigger lockout when above threshold", + ) + + +class OperatorApproval(BaseModel): + """Operator approval workflow for live trading (Requirement 8.2).""" + require_approval_for_live: bool = Field( + default=True, + description="Whether live orders require operator approval", + ) + auto_approve_paper: bool = Field( + default=True, + description="Whether paper orders are auto-approved", + ) + approval_timeout_minutes: int = Field( + default=30, ge=1, + description="Minutes before a pending approval expires", + ) + + +class SymbolCooldown(BaseModel): + """Per-symbol cooldown after a trade.""" + cooldown_minutes: int = Field( + default=15, ge=0, + description="Minutes to wait before trading the same symbol again", + ) + max_open_positions_per_symbol: int = Field( + default=1, ge=1, + description="Maximum concurrent open positions for a single symbol", + ) + + +class PortfolioRiskConfig(BaseModel): + """Complete portfolio-level risk configuration. + + This is the top-level config that governs all risk checks. + Persisted in PostgreSQL and loaded at engine startup. 
+ """ + config_id: str = Field(default_factory=lambda: str(uuid.uuid4())) + name: str = "default" + trading_mode: TradingMode = TradingMode.PAPER + position_limits: PositionLimits = Field(default_factory=PositionLimits) + sector_exposure: SectorExposureLimits = Field(default_factory=SectorExposureLimits) + daily_loss: DailyLossLimits = Field(default_factory=DailyLossLimits) + news_shock: NewsShockLockout = Field(default_factory=NewsShockLockout) + operator_approval: OperatorApproval = Field(default_factory=OperatorApproval) + symbol_cooldown: SymbolCooldown = Field(default_factory=SymbolCooldown) + active: bool = True + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + def to_db_json(self) -> dict[str, Any]: + """Serialize the full config to a JSON-compatible dict for DB storage.""" + return self.model_dump(mode="json") + + @classmethod + def from_db_json(cls, data: dict[str, Any]) -> PortfolioRiskConfig: + """Deserialize from a DB JSON column.""" + return cls.model_validate(data) + + +# --------------------------------------------------------------------------- +# Account risk state (runtime snapshot) +# --------------------------------------------------------------------------- + + +class AccountRiskState(BaseModel): + """Runtime snapshot of an account's risk posture. + + Computed from broker positions, today's trades, and current P&L. + Used by risk checks to evaluate whether a new order is allowed. 
+ """ + account_id: str = "" + portfolio_value: float = 0.0 + cash: float = 0.0 + buying_power: float = 0.0 + daily_pnl: float = 0.0 + daily_trade_count: int = 0 + open_position_count: int = 0 + positions_by_symbol: dict[str, float] = Field( + default_factory=dict, + description="Map of ticker → current market value", + ) + positions_by_sector: dict[str, float] = Field( + default_factory=dict, + description="Map of sector → total market value", + ) + last_trade_times: dict[str, datetime] = Field( + default_factory=dict, + description="Map of ticker → last trade timestamp for cooldown checks", + ) + locked_symbols: dict[str, datetime] = Field( + default_factory=dict, + description="Map of ticker → lockout expiry for news-shock lockouts", + ) + snapshot_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +# --------------------------------------------------------------------------- +# Risk check output (Requirement 8.3 - full decision trace) +# --------------------------------------------------------------------------- + + +class RiskCheckDetail(BaseModel): + """Result of a single risk check.""" + check_name: str + result: RiskCheckResult + message: str = "" + threshold: float | None = None + actual: float | None = None + + +class RiskEvaluation(BaseModel): + """Complete risk evaluation for a proposed order. + + Captures every check performed so the full decision trace + is reproducible (Requirement 8.3). 
+ """ + evaluation_id: str = Field(default_factory=lambda: str(uuid.uuid4())) + recommendation_id: str | None = None + ticker: str = "" + eligible: bool = False + allowed_mode: TradingMode = TradingMode.DISABLED + checks: list[RiskCheckDetail] = Field(default_factory=list) + rejection_reasons: list[str] = Field(default_factory=list) + config_snapshot: PortfolioRiskConfig | None = None + state_snapshot: AccountRiskState | None = None + evaluated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + @property + def passed(self) -> bool: + return self.eligible and len(self.rejection_reasons) == 0 + + +# --------------------------------------------------------------------------- +# Default configuration +# --------------------------------------------------------------------------- + +DEFAULT_RISK_CONFIG = PortfolioRiskConfig() + + +# --------------------------------------------------------------------------- +# Proposed order (input to risk evaluation) +# --------------------------------------------------------------------------- + + +class ProposedOrder(BaseModel): + """A proposed order to be evaluated by the risk engine before submission. + + This is the input to evaluate_order(). It carries enough context + for every risk check to run without external lookups. 
+ """ + recommendation_id: str | None = None + ticker: str + sector: str = "" + action: str = "buy" # buy | sell + quantity: float = 0.0 + estimated_value: float = 0.0 + confidence: float = 0.0 + + +# --------------------------------------------------------------------------- +# Individual risk checks (Requirement 8.4) +# --------------------------------------------------------------------------- + + +def _check_trading_mode( + config: PortfolioRiskConfig, +) -> RiskCheckDetail: + """Block all orders when trading is disabled.""" + if config.trading_mode == TradingMode.DISABLED: + return RiskCheckDetail( + check_name="trading_mode", + result=RiskCheckResult.FAIL, + message="Trading is disabled", + ) + return RiskCheckDetail( + check_name="trading_mode", + result=RiskCheckResult.PASS, + message=f"Trading mode: {config.trading_mode.value}", + ) + + +def _check_max_position_size( + order: ProposedOrder, + config: PortfolioRiskConfig, + state: AccountRiskState, +) -> list[RiskCheckDetail]: + """Enforce per-position size limits (value, percentage, shares).""" + checks: list[RiskCheckDetail] = [] + limits = config.position_limits + + # Check max position value + existing_value = state.positions_by_symbol.get(order.ticker, 0.0) + new_total_value = existing_value + order.estimated_value + checks.append(RiskCheckDetail( + check_name="max_position_value", + result=( + RiskCheckResult.PASS + if new_total_value <= limits.max_position_value + else RiskCheckResult.FAIL + ), + message=( + f"Position value {new_total_value:.2f} " + f"{'within' if new_total_value <= limits.max_position_value else 'exceeds'} " + f"limit {limits.max_position_value:.2f}" + ), + threshold=limits.max_position_value, + actual=new_total_value, + )) + + # Check max position percentage of portfolio + if state.portfolio_value > 0: + position_pct = new_total_value / state.portfolio_value + else: + position_pct = 1.0 if new_total_value > 0 else 0.0 + checks.append(RiskCheckDetail( + 
check_name="max_position_pct", + result=( + RiskCheckResult.PASS + if position_pct <= limits.max_position_pct + else RiskCheckResult.FAIL + ), + message=( + f"Position {position_pct:.4f} of portfolio " + f"{'within' if position_pct <= limits.max_position_pct else 'exceeds'} " + f"limit {limits.max_position_pct:.4f}" + ), + threshold=limits.max_position_pct, + actual=position_pct, + )) + + # Check max shares per order + checks.append(RiskCheckDetail( + check_name="max_shares_per_order", + result=( + RiskCheckResult.PASS + if order.quantity <= limits.max_shares_per_order + else RiskCheckResult.FAIL + ), + message=( + f"Order quantity {order.quantity:.0f} " + f"{'within' if order.quantity <= limits.max_shares_per_order else 'exceeds'} " + f"limit {limits.max_shares_per_order:.0f}" + ), + threshold=limits.max_shares_per_order, + actual=order.quantity, + )) + + return checks + + +def _check_sector_exposure( + order: ProposedOrder, + config: PortfolioRiskConfig, + state: AccountRiskState, +) -> RiskCheckDetail: + """Enforce sector concentration limits.""" + limits = config.sector_exposure + + if not order.sector: + return RiskCheckDetail( + check_name="sector_exposure", + result=RiskCheckResult.WARN, + message="No sector provided on order; skipping sector check", + ) + + existing_sector_value = state.positions_by_sector.get(order.sector, 0.0) + new_sector_value = existing_sector_value + order.estimated_value + + if state.portfolio_value > 0: + sector_pct = new_sector_value / state.portfolio_value + else: + sector_pct = 1.0 if new_sector_value > 0 else 0.0 + + return RiskCheckDetail( + check_name="sector_exposure", + result=( + RiskCheckResult.PASS + if sector_pct <= limits.max_sector_pct + else RiskCheckResult.FAIL + ), + message=( + f"Sector '{order.sector}' exposure {sector_pct:.4f} " + f"{'within' if sector_pct <= limits.max_sector_pct else 'exceeds'} " + f"limit {limits.max_sector_pct:.4f}" + ), + threshold=limits.max_sector_pct, + actual=sector_pct, + ) + + +def 
_check_daily_loss( + config: PortfolioRiskConfig, + state: AccountRiskState, +) -> list[RiskCheckDetail]: + """Enforce daily loss and trade count limits.""" + checks: list[RiskCheckDetail] = [] + limits = config.daily_loss + + # Daily loss percentage + if state.portfolio_value > 0: + loss_pct = abs(min(state.daily_pnl, 0.0)) / state.portfolio_value + else: + loss_pct = 0.0 + + checks.append(RiskCheckDetail( + check_name="daily_loss_pct", + result=( + RiskCheckResult.PASS + if loss_pct <= limits.max_daily_loss_pct + else RiskCheckResult.FAIL + ), + message=( + f"Daily loss {loss_pct:.4f} " + f"{'within' if loss_pct <= limits.max_daily_loss_pct else 'exceeds'} " + f"limit {limits.max_daily_loss_pct:.4f}" + ), + threshold=limits.max_daily_loss_pct, + actual=loss_pct, + )) + + # Daily loss absolute value + abs_loss = abs(min(state.daily_pnl, 0.0)) + checks.append(RiskCheckDetail( + check_name="daily_loss_value", + result=( + RiskCheckResult.PASS + if abs_loss <= limits.max_daily_loss_value + else RiskCheckResult.FAIL + ), + message=( + f"Daily loss ${abs_loss:.2f} " + f"{'within' if abs_loss <= limits.max_daily_loss_value else 'exceeds'} " + f"limit ${limits.max_daily_loss_value:.2f}" + ), + threshold=limits.max_daily_loss_value, + actual=abs_loss, + )) + + # Daily trade count + checks.append(RiskCheckDetail( + check_name="daily_trade_count", + result=( + RiskCheckResult.PASS + if state.daily_trade_count < limits.max_daily_trades + else RiskCheckResult.FAIL + ), + message=( + f"Daily trades {state.daily_trade_count} " + f"{'within' if state.daily_trade_count < limits.max_daily_trades else 'at/exceeds'} " + f"limit {limits.max_daily_trades}" + ), + threshold=float(limits.max_daily_trades), + actual=float(state.daily_trade_count), + )) + + return checks + + +def _check_news_shock_lockout( + order: ProposedOrder, + config: PortfolioRiskConfig, + state: AccountRiskState, + now: datetime | None = None, +) -> RiskCheckDetail: + """Block trading on symbols under news-shock 
lockout.""" + lockout_cfg = config.news_shock + + if not lockout_cfg.enabled: + return RiskCheckDetail( + check_name="news_shock_lockout", + result=RiskCheckResult.PASS, + message="News-shock lockout is disabled", + ) + + now = now or datetime.now(timezone.utc) + lockout_expiry = state.locked_symbols.get(order.ticker) + + if lockout_expiry is not None and now < lockout_expiry: + remaining = lockout_expiry - now + return RiskCheckDetail( + check_name="news_shock_lockout", + result=RiskCheckResult.FAIL, + message=( + f"Symbol {order.ticker} locked out until " + f"{lockout_expiry.isoformat()} " + f"({remaining.total_seconds():.0f}s remaining)" + ), + ) + + return RiskCheckDetail( + check_name="news_shock_lockout", + result=RiskCheckResult.PASS, + message=f"No active lockout for {order.ticker}", + ) + + +def _check_symbol_cooldown( + order: ProposedOrder, + config: PortfolioRiskConfig, + state: AccountRiskState, + now: datetime | None = None, +) -> RiskCheckDetail: + """Enforce per-symbol cooldown between trades.""" + cooldown_cfg = config.symbol_cooldown + now = now or datetime.now(timezone.utc) + + last_trade = state.last_trade_times.get(order.ticker) + if last_trade is not None: + cooldown_end = last_trade + timedelta(minutes=cooldown_cfg.cooldown_minutes) + if now < cooldown_end: + remaining = cooldown_end - now + return RiskCheckDetail( + check_name="symbol_cooldown", + result=RiskCheckResult.FAIL, + message=( + f"Symbol {order.ticker} in cooldown until " + f"{cooldown_end.isoformat()} " + f"({remaining.total_seconds():.0f}s remaining)" + ), + ) + + return RiskCheckDetail( + check_name="symbol_cooldown", + result=RiskCheckResult.PASS, + message=f"No active cooldown for {order.ticker}", + ) + + +# --------------------------------------------------------------------------- +# Main evaluation entry point (Requirements 8.3, 8.4, 8.5) +# --------------------------------------------------------------------------- + + +def evaluate_order( + order: ProposedOrder, + 
config: PortfolioRiskConfig = DEFAULT_RISK_CONFIG, + state: AccountRiskState | None = None, + now: datetime | None = None, +) -> RiskEvaluation: + """Evaluate a proposed order against all risk controls. + + Runs every hard-block check and returns a RiskEvaluation capturing + the full decision trace (Requirement 8.3). If any check fails, + the order is rejected before broker submission (Requirement 8.4). + + The engine fails closed: if state is missing or ambiguous, the + order is rejected (Requirement 8.5). + """ + state = state or AccountRiskState() + now = now or datetime.now(timezone.utc) + + all_checks: list[RiskCheckDetail] = [] + rejection_reasons: list[str] = [] + + # 1. Trading mode gate + mode_check = _check_trading_mode(config) + all_checks.append(mode_check) + if mode_check.result == RiskCheckResult.FAIL: + rejection_reasons.append(mode_check.message) + + # 2. Position size limits + position_checks = _check_max_position_size(order, config, state) + all_checks.extend(position_checks) + for c in position_checks: + if c.result == RiskCheckResult.FAIL: + rejection_reasons.append(c.message) + + # 3. Sector exposure + sector_check = _check_sector_exposure(order, config, state) + all_checks.append(sector_check) + if sector_check.result == RiskCheckResult.FAIL: + rejection_reasons.append(sector_check.message) + + # 4. Daily loss limits + daily_checks = _check_daily_loss(config, state) + all_checks.extend(daily_checks) + for c in daily_checks: + if c.result == RiskCheckResult.FAIL: + rejection_reasons.append(c.message) + + # 5. News-shock lockout + lockout_check = _check_news_shock_lockout(order, config, state, now) + all_checks.append(lockout_check) + if lockout_check.result == RiskCheckResult.FAIL: + rejection_reasons.append(lockout_check.message) + + # 6. 
Symbol cooldown + cooldown_check = _check_symbol_cooldown(order, config, state, now) + all_checks.append(cooldown_check) + if cooldown_check.result == RiskCheckResult.FAIL: + rejection_reasons.append(cooldown_check.message) + + # Determine eligibility and allowed mode + eligible = len(rejection_reasons) == 0 + allowed_mode = config.trading_mode if eligible else TradingMode.DISABLED + + return RiskEvaluation( + recommendation_id=order.recommendation_id, + ticker=order.ticker, + eligible=eligible, + allowed_mode=allowed_mode, + checks=all_checks, + rejection_reasons=rejection_reasons, + config_snapshot=config, + state_snapshot=state, + evaluated_at=now, + ) diff --git a/services/scheduler/app.py b/services/scheduler/app.py index 5808241..cee0ba1 100644 --- a/services/scheduler/app.py +++ b/services/scheduler/app.py @@ -1,14 +1,23 @@ -"""Scheduler - triggers ingestion cycles for tracked symbols and sources.""" +"""Scheduler - triggers ingestion cycles for tracked symbols and sources. + +Polls the symbol registry for active companies and their configured sources, +respects per-source polling cadences and backoff windows, coordinates rate +limits across source types, and enqueues ingestion jobs for downstream workers. + +Requirements: 2.1, 2.2, 2.3, 2.4, 2.5 +""" import asyncio import json import logging -from datetime import datetime, timedelta +from datetime import datetime +from typing import Any, Optional import asyncpg import redis.asyncio as aioredis from services.shared.config import load_config from services.shared.db import get_pg_pool, get_redis +from services.shared.logging import setup_logging from services.shared.redis_keys import ( QUEUE_INGESTION, lock_key, @@ -16,11 +25,11 @@ from services.shared.redis_keys import ( rate_limit_key, ) -logging.basicConfig(level=logging.INFO) logger = logging.getLogger("scheduler") -# Polling cadences by source class (seconds) -CADENCES = { +# Default polling cadences by source class (seconds). 
+# Individual sources can override via config.polling_interval_seconds. +DEFAULT_CADENCES: dict[str, int] = { "market_api": 60, "news_api": 300, "filings_api": 3600, @@ -28,81 +37,267 @@ CADENCES = { "broker": 30, } +# Default rate limits per source type (requests per minute) +DEFAULT_RATE_LIMITS: dict[str, int] = { + "market_api": 30, + "news_api": 20, + "filings_api": 10, + "web_scrape": 10, + "broker": 60, +} + +# How long to wait before retrying a failed source (seconds) +DEFAULT_BACKOFF_BASE: int = 60 +MAX_BACKOFF: int = 3600 +MAX_RETRY_COUNT: int = 10 + +# Main loop interval (seconds) +SCHEDULER_TICK: int = 15 + + +def get_cadence_for_source(source_type: str, config: Optional[dict[str, Any]]) -> int: + """Return the polling interval for a source. + + Uses the source's config.polling_interval_seconds if set, + otherwise falls back to the default cadence for the source type. + """ + if config and "polling_interval_seconds" in config: + try: + return max(10, int(config["polling_interval_seconds"])) + except (ValueError, TypeError): + pass + return DEFAULT_CADENCES.get(source_type, 600) + + +def compute_backoff(retry_count: int) -> int: + """Exponential backoff with a cap. Returns seconds to wait.""" + delay = DEFAULT_BACKOFF_BASE * (2 ** min(retry_count, 8)) + return min(delay, MAX_BACKOFF) + + +def is_source_due( + source_type: str, + source_config: Optional[dict[str, Any]], + last_completed_at: Optional[datetime], + last_status: Optional[str], + retry_count: int, + next_retry_at: Optional[datetime], + now: datetime, +) -> bool: + """Determine whether a source is due for its next polling cycle. + + Checks: + - If the source has never run, it is due. + - If the last run failed and we have a next_retry_at in the future, skip. + - If the last run failed and retry_count exceeds max, skip (needs manual reset). + - Otherwise, check if enough time has elapsed since the last completed run. 
+ """ + # Never run before — always due + if last_completed_at is None and last_status is None: + return True + + # If last run failed, respect backoff + if last_status == "failed": + if retry_count >= MAX_RETRY_COUNT: + return False + if next_retry_at and now < next_retry_at.replace(tzinfo=None): + return False + # Backoff elapsed or no next_retry_at set — allow retry + return True + + # If currently running, don't double-schedule + if last_status == "running": + return False + + # Normal cadence check + if last_completed_at is None: + return True + + cadence = get_cadence_for_source(source_type, source_config) + elapsed = (now - last_completed_at.replace(tzinfo=None)).total_seconds() + return elapsed >= cadence + + +def build_job_payload( + source: Any, + aliases: list[str], + now: datetime, +) -> dict[str, Any]: + """Build the ingestion job payload for a source.""" + return { + "source_id": str(source["source_id"]), + "company_id": str(source["company_id"]), + "ticker": source["ticker"], + "legal_name": source["legal_name"], + "aliases": aliases, + "source_type": source["source_type"], + "source_name": source["source_name"], + "config": dict(source["config"]) if source["config"] else {}, + "credibility_score": float(source["credibility_score"]) if source["credibility_score"] else 0.5, + "scheduled_at": now.isoformat(), + } + async def acquire_lock(rds: aioredis.Redis, name: str, ttl: int = 60) -> bool: + """Acquire a distributed lock. 
Returns True if acquired.""" return await rds.set(lock_key(name), "1", nx=True, ex=ttl) -async def release_lock(rds: aioredis.Redis, name: str): +async def release_lock(rds: aioredis.Redis, name: str) -> None: + """Release a distributed lock.""" await rds.delete(lock_key(name)) -async def check_rate_limit(rds: aioredis.Redis, source_type: str, max_per_minute: int = 30) -> bool: - key = rate_limit_key(source_type, datetime.utcnow().strftime("%Y%m%d%H%M")) +async def check_rate_limit( + rds: aioredis.Redis, + source_type: str, + now: datetime, + max_per_minute: Optional[int] = None, +) -> bool: + """Check whether the source type is within its rate limit window. + + Returns True if the request is allowed, False if rate-limited. + """ + limit = max_per_minute or DEFAULT_RATE_LIMITS.get(source_type, 30) + window = now.strftime("%Y%m%d%H%M") + key = rate_limit_key(source_type, window) count = await rds.incr(key) if count == 1: await rds.expire(key, 120) - return count <= max_per_minute + return count <= limit -async def schedule_cycle(pool: asyncpg.Pool, rds: aioredis.Redis): - """One scheduling pass: find due sources and enqueue ingestion jobs.""" - sources = await pool.fetch( - """SELECT s.id as source_id, s.company_id, s.source_type, s.source_name, s.config, - c.ticker, c.legal_name - FROM sources s JOIN companies c ON s.company_id = c.id +async def fetch_active_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]: + """Fetch all active sources joined with their active companies.""" + return await pool.fetch( + """SELECT s.id AS source_id, + s.company_id, + s.source_type, + s.source_name, + s.config, + s.credibility_score, + c.ticker, + c.legal_name + FROM sources s + JOIN companies c ON s.company_id = c.id WHERE s.active = TRUE AND c.active = TRUE ORDER BY s.source_type, c.ticker""" ) + +async def fetch_aliases_for_company(pool: asyncpg.Pool, company_id: str) -> list[str]: + """Fetch all aliases for a company.""" + rows = await pool.fetch( + "SELECT alias FROM 
company_aliases WHERE company_id = $1", + company_id, + ) + return [r["alias"] for r in rows] + + +async def fetch_last_run( + pool: asyncpg.Pool, source_id: str +) -> Optional[asyncpg.Record]: + """Fetch the most recent ingestion run for a source.""" + return await pool.fetchrow( + """SELECT status, started_at, completed_at, retry_count, next_retry_at + FROM ingestion_runs + WHERE source_id = $1 + ORDER BY started_at DESC + LIMIT 1""", + source_id, + ) + + +async def schedule_cycle(pool: asyncpg.Pool, rds: aioredis.Redis) -> int: + """One scheduling pass: find due sources and enqueue ingestion jobs. + + Returns the number of jobs enqueued. + """ + now = datetime.utcnow() + sources = await fetch_active_sources(pool) + enqueued = 0 + skipped_rate_limit = 0 + skipped_not_due = 0 + for src in sources: + source_id = src["source_id"] source_type = src["source_type"] - cadence = CADENCES.get(source_type, 600) + source_config = dict(src["config"]) if src["config"] else None - # Check last run - last_run = await pool.fetchval( - "SELECT MAX(started_at) FROM ingestion_runs WHERE source_id = $1 AND status IN ('completed', 'running')", - src["source_id"], - ) - if last_run and (datetime.utcnow() - last_run.replace(tzinfo=None)) < timedelta(seconds=cadence): + # Check last run status and timing + last_run = await fetch_last_run(pool, source_id) + + last_completed_at = None + last_status = None + retry_count = 0 + next_retry_at = None + + if last_run: + last_status = last_run["status"] + last_completed_at = last_run["completed_at"] or last_run["started_at"] + retry_count = last_run["retry_count"] or 0 + next_retry_at = last_run["next_retry_at"] + + if not is_source_due( + source_type=source_type, + source_config=source_config, + last_completed_at=last_completed_at, + last_status=last_status, + retry_count=retry_count, + next_retry_at=next_retry_at, + now=now, + ): + skipped_not_due += 1 continue - if not await check_rate_limit(rds, source_type): - logger.warning(f"Rate limit 
hit for {source_type}") + # Rate limit check + if not await check_rate_limit(rds, source_type, now): + logger.warning( + "Rate limit hit for %s, skipping %s/%s", + source_type, src["ticker"], src["source_name"], + ) + skipped_rate_limit += 1 continue - job = { - "source_id": str(src["source_id"]), - "company_id": str(src["company_id"]), - "ticker": src["ticker"], - "source_type": source_type, - "source_name": src["source_name"], - "config": dict(src["config"]) if src["config"] else {}, - "scheduled_at": datetime.utcnow().isoformat(), - } - await rds.rpush(queue_key(QUEUE_INGESTION), json.dumps(job)) + # Fetch company aliases for downstream entity matching + aliases = await fetch_aliases_for_company(pool, src["company_id"]) + + job = build_job_payload(src, aliases, now) + await rds.rpush(queue_key(QUEUE_INGESTION), json.dumps(job)) # type: ignore[misc] enqueued += 1 - if enqueued: - logger.info(f"Enqueued {enqueued} ingestion jobs") + logger.debug( + "Enqueued %s job for %s (%s)", + source_type, src["ticker"], src["source_name"], + ) + + logger.info( + "Cycle complete: enqueued=%d skipped_not_due=%d skipped_rate_limit=%d total_sources=%d", + enqueued, skipped_not_due, skipped_rate_limit, len(sources), + ) + return enqueued -async def main(): +async def main() -> None: config = load_config() + setup_logging("scheduler", level=config.log_level, json_output=config.json_logs) + pool = await get_pg_pool(config) rds = get_redis(config) - logger.info("Scheduler started") + logger.info("Scheduler started (tick=%ds)", SCHEDULER_TICK) try: while True: try: if await acquire_lock(rds, "scheduler_cycle", ttl=30): - await schedule_cycle(pool, rds) - await release_lock(rds, "scheduler_cycle") - except Exception as e: - logger.error(f"Scheduler cycle error: {e}") - await asyncio.sleep(15) + try: + await schedule_cycle(pool, rds) + finally: + await release_lock(rds, "scheduler_cycle") + except Exception: + logger.exception("Scheduler cycle error") + await 
asyncio.sleep(SCHEDULER_TICK) finally: await pool.close() await rds.close() diff --git a/services/shared/alerting.py b/services/shared/alerting.py new file mode 100644 index 0000000..94dd5b9 --- /dev/null +++ b/services/shared/alerting.py @@ -0,0 +1,342 @@ +"""Operational alerting for Stonks Oracle pipeline health. + +Evaluates alert rules against PostgreSQL operational state and emits +structured log events and Prometheus metrics when thresholds are breached. + +Alert rules: +- source_failures: sustained source retrieval failures per source +- schema_failure_spike: extraction validation failure rate exceeds threshold +- analytical_lag: lake publication has not completed within threshold +- broker_issues: consecutive broker submission errors + +Requirements: 12.3 +Design: Section 12 (Observability and Operations) +""" +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any + +import asyncpg + +from services.shared.config import AlertingConfig +from services.shared.metrics import ( + ALERT_ACTIVE, + ALERT_CHECK_DURATION, + ALERTS_FIRED, + ALERTS_RESOLVED, +) + +logger = logging.getLogger("alerting") + + +@dataclass +class Alert: + """A single alert instance.""" + + rule: str + severity: str # "warning" | "critical" + summary: str + details: dict[str, Any] = field(default_factory=dict) + fired_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + + +@dataclass +class AlertState: + """Tracks which rules are currently firing to detect transitions.""" + + active: dict[str, Alert] = field(default_factory=dict) + + def fire(self, alert: Alert) -> bool: + """Record an alert firing. Returns True if this is a new firing.""" + key = f"{alert.rule}:{alert.details.get('key', '')}" + is_new = key not in self.active + self.active[key] = alert + return is_new + + def resolve(self, rule: str, key: str = "") -> bool: + """Resolve an alert. 
Returns True if it was previously active.""" + full_key = f"{rule}:{key}" + if full_key in self.active: + del self.active[full_key] + return True + return False + + def is_firing(self, rule: str, key: str = "") -> bool: + return f"{rule}:{key}" in self.active + + +async def check_source_failures( + pool: asyncpg.Pool, + config: AlertingConfig, +) -> list[Alert]: + """Check for sources with sustained consecutive failures. + + Queries ingestion_runs for sources where the last N runs all failed + within the lookback window. + """ + rows = await pool.fetch( + """WITH recent_runs AS ( + SELECT source_id, status, + ROW_NUMBER() OVER (PARTITION BY source_id ORDER BY started_at DESC) AS rn + FROM ingestion_runs + WHERE started_at >= NOW() - INTERVAL '1 hour' * $1 + ), + failure_streaks AS ( + SELECT source_id, + COUNT(*) FILTER (WHERE status = 'failed') AS consecutive_failures, + COUNT(*) AS total_runs + FROM recent_runs + WHERE rn <= $2 + GROUP BY source_id + HAVING COUNT(*) FILTER (WHERE status = 'failed') = COUNT(*) + AND COUNT(*) >= $2 + ) + SELECT fs.source_id, fs.consecutive_failures, + s.source_type, s.source_name, c.ticker + FROM failure_streaks fs + JOIN sources s ON s.id = fs.source_id + JOIN companies c ON c.id = s.company_id""", + config.source_failure_window_hours, + config.source_failure_threshold, + ) + + alerts = [] + for row in rows: + alerts.append(Alert( + rule="source_failures", + severity="warning", + summary=( + f"Source {row['source_name']} ({row['source_type']}) for " + f"{row['ticker']} has {row['consecutive_failures']} consecutive failures" + ), + details={ + "key": str(row["source_id"]), + "source_id": str(row["source_id"]), + "source_type": row["source_type"], + "source_name": row["source_name"], + "ticker": row["ticker"], + "consecutive_failures": row["consecutive_failures"], + }, + )) + return alerts + + +async def check_schema_failure_spike( + pool: asyncpg.Pool, + config: AlertingConfig, +) -> list[Alert]: + """Check if extraction schema 
validation failure rate exceeds threshold. + + Queries model_performance_metrics for the recent window and computes + the failure rate. + """ + row = await pool.fetchrow( + """SELECT + COUNT(*) AS total, + COUNT(*) FILTER (WHERE NOT success) AS failed + FROM model_performance_metrics + WHERE recorded_at >= NOW() - INTERVAL '1 hour' * $1""", + config.schema_failure_window_hours, + ) + + if not row or row["total"] == 0: + return [] + + total = row["total"] + failed = row["failed"] + failure_rate = failed / total + + if failure_rate >= config.schema_failure_rate_threshold: + return [Alert( + rule="schema_failure_spike", + severity="critical" if failure_rate >= 0.5 else "warning", + summary=( + f"Extraction schema failure rate is {failure_rate:.1%} " + f"({failed}/{total}) in the last {config.schema_failure_window_hours}h" + ), + details={ + "key": "global", + "total_extractions": total, + "failed_extractions": failed, + "failure_rate": round(failure_rate, 4), + "threshold": config.schema_failure_rate_threshold, + "window_hours": config.schema_failure_window_hours, + }, + )] + return [] + + +async def check_analytical_lag( + pool: asyncpg.Pool, + config: AlertingConfig, +) -> list[Alert]: + """Check if lake publication is lagging beyond threshold. + + Looks at the audit_events table for the most recent successful + lake_publish events per table, and alerts if any are stale. 
+ """ + rows = await pool.fetch( + """SELECT + details->>'table_name' AS table_name, + MAX(created_at) AS last_publish + FROM audit_events + WHERE event_type = 'lake_publish' + AND details->>'status' = 'success' + AND details->>'table_name' IS NOT NULL + GROUP BY details->>'table_name' + HAVING MAX(created_at) < NOW() - INTERVAL '1 minute' * $1""", + config.lake_lag_threshold_minutes, + ) + + alerts = [] + now = datetime.now(timezone.utc) + for row in rows: + table_name = row["table_name"] + last_publish = row["last_publish"] + if last_publish.tzinfo is None: + last_publish = last_publish.replace(tzinfo=timezone.utc) + lag_minutes = (now - last_publish).total_seconds() / 60 + + alerts.append(Alert( + rule="analytical_lag", + severity="warning", + summary=( + f"Lake table '{table_name}' last published {lag_minutes:.0f}m ago " + f"(threshold: {config.lake_lag_threshold_minutes}m)" + ), + details={ + "key": table_name, + "table_name": table_name, + "last_publish": last_publish.isoformat(), + "lag_minutes": round(lag_minutes, 1), + "threshold_minutes": config.lake_lag_threshold_minutes, + }, + )) + return alerts + + +async def check_broker_issues( + pool: asyncpg.Pool, + config: AlertingConfig, +) -> list[Alert]: + """Check for consecutive broker submission errors. + + Queries order_events for recent broker-level errors (rejections, + timeouts, connection failures) within the lookback window. 
+ """ + rows = await pool.fetch( + """WITH recent_events AS ( + SELECT order_id, event_type, created_at, + ROW_NUMBER() OVER (ORDER BY created_at DESC) AS rn + FROM order_events + WHERE created_at >= NOW() - INTERVAL '1 hour' * $1 + AND event_type IN ('broker_error', 'broker_timeout', 'connection_failed') + ) + SELECT COUNT(*) AS error_count + FROM recent_events + WHERE rn <= $2""", + config.broker_error_window_hours, + config.broker_error_threshold, + ) + + if not rows: + return [] + + error_count = rows[0]["error_count"] + if error_count >= config.broker_error_threshold: + return [Alert( + rule="broker_issues", + severity="critical", + summary=( + f"{error_count} broker errors in the last " + f"{config.broker_error_window_hours}h" + ), + details={ + "key": "global", + "error_count": error_count, + "threshold": config.broker_error_threshold, + "window_hours": config.broker_error_window_hours, + }, + )] + return [] + + +async def evaluate_alerts( + pool: asyncpg.Pool, + config: AlertingConfig, + state: AlertState, +) -> list[Alert]: + """Run all alert rules and return newly fired alerts. + + Updates AlertState to track firing/resolved transitions and emits + structured log events and Prometheus metrics for each transition. 
+ """ + all_alerts: list[Alert] = [] + + with ALERT_CHECK_DURATION.time(): + # Collect alerts from all rules + try: + all_alerts.extend(await check_source_failures(pool, config)) + except Exception: + logger.exception("Error checking source failures") + + try: + all_alerts.extend(await check_schema_failure_spike(pool, config)) + except Exception: + logger.exception("Error checking schema failure spike") + + try: + all_alerts.extend(await check_analytical_lag(pool, config)) + except Exception: + logger.exception("Error checking analytical lag") + + try: + all_alerts.extend(await check_broker_issues(pool, config)) + except Exception: + logger.exception("Error checking broker issues") + + # Track which rule+key combos are currently firing + current_keys: set[str] = set() + newly_fired: list[Alert] = [] + + for alert in all_alerts: + key = f"{alert.rule}:{alert.details.get('key', '')}" + current_keys.add(key) + + if state.fire(alert): + # New alert firing + ALERTS_FIRED.labels(rule=alert.rule, severity=alert.severity).inc() + ALERT_ACTIVE.labels(rule=alert.rule).set(1) + newly_fired.append(alert) + logger.warning( + "ALERT FIRING: [%s] %s", + alert.rule, + alert.summary, + extra={ + "alert_rule": alert.rule, + "alert_severity": alert.severity, + "alert_details": alert.details, + }, + ) + + # Check for resolved alerts + resolved_keys = set(state.active.keys()) - current_keys + for key in resolved_keys: + rule = key.split(":")[0] + detail_key = key[len(rule) + 1:] + if state.resolve(rule, detail_key): + ALERTS_RESOLVED.labels(rule=rule).inc() + # Only set gauge to 0 if no more alerts for this rule + still_firing = any(k.startswith(f"{rule}:") for k in state.active) + if not still_firing: + ALERT_ACTIVE.labels(rule=rule).set(0) + logger.info( + "ALERT RESOLVED: [%s] key=%s", + rule, + detail_key, + ) + + return newly_fired diff --git a/services/shared/audit.py b/services/shared/audit.py new file mode 100644 index 0000000..ec22f8a --- /dev/null +++ 
b/services/shared/audit.py @@ -0,0 +1,493 @@ +"""Execution audit trail - records every step from recommendation to market outcome. + +Writes structured audit events to the audit_events table so the full +decision chain is traceable: recommendation → risk evaluation → order +submission → broker response → fill/rejection/cancellation. + +Each event captures the entity type, entity ID, event type, actor, +and a JSONB data payload with stage-specific details. + +Requirements: 8.3, 11.3 +Design: Section 4.9 (Broker Adapter), Section 6.1 (PostgreSQL audit_events) +""" +from __future__ import annotations + +import json +import logging +import uuid +from datetime import datetime, timezone +from typing import Any + +import asyncpg + +logger = logging.getLogger("audit") + + +# --------------------------------------------------------------------------- +# Event type constants +# --------------------------------------------------------------------------- + +# Recommendation stage +AUDIT_RECOMMENDATION_GENERATED = "recommendation.generated" +AUDIT_RECOMMENDATION_SUPPRESSED = "recommendation.suppressed" + +# Risk evaluation stage +AUDIT_RISK_EVALUATED = "risk.evaluated" +AUDIT_RISK_REJECTED = "risk.rejected" + +# Order lifecycle +AUDIT_ORDER_SUBMITTED = "order.submitted" +AUDIT_ORDER_ACCEPTED = "order.accepted" +AUDIT_ORDER_FILLED = "order.filled" +AUDIT_ORDER_REJECTED = "order.rejected" +AUDIT_ORDER_CANCELLED = "order.cancelled" +AUDIT_ORDER_DUPLICATE = "order.duplicate_prevented" + +# Position changes +AUDIT_POSITION_OPENED = "position.opened" +AUDIT_POSITION_CLOSED = "position.closed" +AUDIT_POSITION_UPDATED = "position.updated" + +# Trading mode changes +AUDIT_TRADING_MODE_CHANGED = "trading.mode_changed" + +# Operator approval workflow +AUDIT_APPROVAL_REQUESTED = "approval.requested" +AUDIT_APPROVAL_APPROVED = "approval.approved" +AUDIT_APPROVAL_REJECTED = "approval.rejected" +AUDIT_APPROVAL_EXPIRED = "approval.expired" + + +# 
--------------------------------------------------------------------------- +# Core audit writer +# --------------------------------------------------------------------------- + +_INSERT_AUDIT_EVENT = """ +INSERT INTO audit_events (id, event_type, entity_type, entity_id, actor, data, created_at) +VALUES ($1::uuid, $2, $3, $4::uuid, $5, $6::jsonb, $7) +""" + + +async def record_audit_event( + pool: asyncpg.Pool, + event_type: str, + entity_type: str, + entity_id: str, + data: dict[str, Any], + actor: str = "system", + timestamp: datetime | None = None, +) -> str: + """Write a single audit event to PostgreSQL. + + Returns the audit event UUID. + """ + event_id = str(uuid.uuid4()) + ts = timestamp or datetime.now(timezone.utc) + + try: + await pool.execute( + _INSERT_AUDIT_EVENT, + event_id, + event_type, + entity_type, + entity_id, + actor, + json.dumps(data, default=str), + ts, + ) + except Exception: + logger.warning( + "Failed to write audit event %s for %s/%s", + event_type, entity_type, entity_id, + exc_info=True, + ) + return "" + + return event_id + + +# --------------------------------------------------------------------------- +# Convenience helpers for each execution stage +# --------------------------------------------------------------------------- + + +async def audit_recommendation_generated( + pool: asyncpg.Pool, + recommendation_id: str, + ticker: str, + action: str, + mode: str, + confidence: float, + evidence_count: int, + suppressed: bool = False, +) -> str: + """Record that a recommendation was generated.""" + event_type = AUDIT_RECOMMENDATION_SUPPRESSED if suppressed else AUDIT_RECOMMENDATION_GENERATED + return await record_audit_event( + pool, + event_type=event_type, + entity_type="recommendation", + entity_id=recommendation_id, + data={ + "ticker": ticker, + "action": action, + "mode": mode, + "confidence": confidence, + "evidence_count": evidence_count, + "suppressed": suppressed, + }, + actor="recommendation_worker", + ) + + +async def 
audit_risk_evaluated( + pool: asyncpg.Pool, + evaluation_id: str, + recommendation_id: str | None, + ticker: str, + eligible: bool, + allowed_mode: str, + rejection_reasons: list[str], + check_count: int, +) -> str: + """Record a risk evaluation result.""" + event_type = AUDIT_RISK_REJECTED if not eligible else AUDIT_RISK_EVALUATED + return await record_audit_event( + pool, + event_type=event_type, + entity_type="risk_evaluation", + entity_id=evaluation_id, + data={ + "recommendation_id": recommendation_id, + "ticker": ticker, + "eligible": eligible, + "allowed_mode": allowed_mode, + "rejection_reasons": rejection_reasons, + "check_count": check_count, + }, + actor="risk_engine", + ) + + +async def audit_order_submitted( + pool: asyncpg.Pool, + order_id: str, + ticker: str, + side: str, + quantity: float, + order_type: str, + idempotency_key: str, + recommendation_id: str | None = None, + evaluation_id: str | None = None, +) -> str: + """Record that an order was submitted to the broker.""" + return await record_audit_event( + pool, + event_type=AUDIT_ORDER_SUBMITTED, + entity_type="order", + entity_id=order_id, + data={ + "ticker": ticker, + "side": side, + "quantity": quantity, + "order_type": order_type, + "idempotency_key": idempotency_key, + "recommendation_id": recommendation_id, + "evaluation_id": evaluation_id, + }, + actor="broker_service", + ) + + +async def audit_order_filled( + pool: asyncpg.Pool, + order_id: str, + ticker: str, + side: str, + fill_quantity: float, + fill_price: float | None, + broker_order_id: str, +) -> str: + """Record that an order was filled by the broker.""" + return await record_audit_event( + pool, + event_type=AUDIT_ORDER_FILLED, + entity_type="order", + entity_id=order_id, + data={ + "ticker": ticker, + "side": side, + "fill_quantity": fill_quantity, + "fill_price": fill_price, + "broker_order_id": broker_order_id, + }, + actor="broker_service", + ) + + +async def audit_order_rejected( + pool: asyncpg.Pool, + order_id: str, + 
ticker: str, + reason: str, + source: str = "broker", +) -> str: + """Record that an order was rejected (by risk engine or broker).""" + return await record_audit_event( + pool, + event_type=AUDIT_ORDER_REJECTED, + entity_type="order", + entity_id=order_id, + data={ + "ticker": ticker, + "reason": reason, + "rejection_source": source, + }, + actor="broker_service", + ) + + +async def audit_order_cancelled( + pool: asyncpg.Pool, + order_id: str, + ticker: str, + broker_order_id: str, +) -> str: + """Record that an order was cancelled.""" + return await record_audit_event( + pool, + event_type=AUDIT_ORDER_CANCELLED, + entity_type="order", + entity_id=order_id, + data={ + "ticker": ticker, + "broker_order_id": broker_order_id, + }, + actor="broker_service", + ) + + +async def audit_duplicate_prevented( + pool: asyncpg.Pool, + order_id: str, + ticker: str, + idempotency_key: str, + detected_via: str, +) -> str: + """Record that a duplicate order was prevented.""" + return await record_audit_event( + pool, + event_type=AUDIT_ORDER_DUPLICATE, + entity_type="order", + entity_id=order_id, + data={ + "ticker": ticker, + "idempotency_key": idempotency_key, + "detected_via": detected_via, + }, + actor="broker_service", + ) + + +async def audit_position_change( + pool: asyncpg.Pool, + order_id: str, + ticker: str, + side: str, + quantity_before: float, + quantity_after: float, + avg_entry_before: float, + avg_entry_after: float, +) -> str: + """Record a position change resulting from a fill.""" + if quantity_before == 0 and quantity_after > 0: + event_type = AUDIT_POSITION_OPENED + elif quantity_after == 0: + event_type = AUDIT_POSITION_CLOSED + else: + event_type = AUDIT_POSITION_UPDATED + + return await record_audit_event( + pool, + event_type=event_type, + entity_type="position", + entity_id=order_id, + data={ + "ticker": ticker, + "side": side, + "quantity_before": quantity_before, + "quantity_after": quantity_after, + "avg_entry_before": avg_entry_before, + 
"avg_entry_after": avg_entry_after, + }, + actor="broker_service", + ) + + +async def audit_approval_requested( + pool: asyncpg.Pool, + approval_id: str, + ticker: str, + side: str, + quantity: float, + estimated_value: float, + recommendation_id: str | None = None, + expires_at: str | None = None, +) -> str: + """Record that an operator approval was requested for a live order.""" + return await record_audit_event( + pool, + event_type=AUDIT_APPROVAL_REQUESTED, + entity_type="approval", + entity_id=approval_id, + data={ + "ticker": ticker, + "side": side, + "quantity": quantity, + "estimated_value": estimated_value, + "recommendation_id": recommendation_id, + "expires_at": expires_at, + }, + actor="broker_service", + ) + + +async def audit_approval_reviewed( + pool: asyncpg.Pool, + approval_id: str, + ticker: str, + approved: bool, + reviewed_by: str = "operator", + review_note: str = "", +) -> str: + """Record that an operator reviewed an approval request.""" + event_type = AUDIT_APPROVAL_APPROVED if approved else AUDIT_APPROVAL_REJECTED + return await record_audit_event( + pool, + event_type=event_type, + entity_type="approval", + entity_id=approval_id, + data={ + "ticker": ticker, + "approved": approved, + "reviewed_by": reviewed_by, + "review_note": review_note, + }, + actor=reviewed_by, + ) + + +async def audit_approval_expired( + pool: asyncpg.Pool, + approval_id: str, + ticker: str, +) -> str: + """Record that an approval request expired without review.""" + return await record_audit_event( + pool, + event_type=AUDIT_APPROVAL_EXPIRED, + entity_type="approval", + entity_id=approval_id, + data={"ticker": ticker}, + actor="system", + ) + + +async def audit_trading_mode_changed( + pool: asyncpg.Pool, + config_id: str, + old_mode: str, + new_mode: str, + actor: str = "operator", +) -> str: + """Record a trading mode change.""" + return await record_audit_event( + pool, + event_type=AUDIT_TRADING_MODE_CHANGED, + entity_type="risk_config", + entity_id=config_id, + 
data={ + "old_mode": old_mode, + "new_mode": new_mode, + }, + actor=actor, + ) + + +# --------------------------------------------------------------------------- +# Query helpers for audit trail retrieval (Requirement 11.3) +# --------------------------------------------------------------------------- + +_FETCH_AUDIT_TRAIL_FOR_ORDER = """ +SELECT id, event_type, entity_type, entity_id, actor, data, created_at +FROM audit_events +WHERE entity_id = $1::uuid + OR data->>'recommendation_id' = $2 + OR data->>'order_id' = $2 +ORDER BY created_at ASC +""" + +_FETCH_AUDIT_TRAIL_BY_ENTITY = """ +SELECT id, event_type, entity_type, entity_id, actor, data, created_at +FROM audit_events +WHERE entity_type = $1 AND entity_id = $2::uuid +ORDER BY created_at ASC +""" + +_FETCH_FULL_EXECUTION_TRAIL = """ +SELECT id, event_type, entity_type, entity_id, actor, data, created_at +FROM audit_events +WHERE entity_id = $1::uuid + OR entity_id IN ( + SELECT entity_id FROM audit_events + WHERE data->>'recommendation_id' = $2 + ) +ORDER BY created_at ASC +""" + + +async def get_order_audit_trail( + pool: asyncpg.Pool, + order_id: str, + recommendation_id: str | None = None, +) -> list[dict[str, Any]]: + """Fetch the full audit trail for an order, including related recommendation and risk events. + + Returns events ordered chronologically so the full decision chain + is visible: recommendation → risk → order → fill/reject. 
+ """ + ref_id = recommendation_id or order_id + rows = await pool.fetch(_FETCH_AUDIT_TRAIL_FOR_ORDER, order_id, ref_id) + return [ + { + "id": str(row["id"]), + "event_type": row["event_type"], + "entity_type": row["entity_type"], + "entity_id": str(row["entity_id"]), + "actor": row["actor"], + "data": row["data"] if isinstance(row["data"], dict) else json.loads(row["data"]), + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + } + for row in rows + ] + + +async def get_entity_audit_trail( + pool: asyncpg.Pool, + entity_type: str, + entity_id: str, +) -> list[dict[str, Any]]: + """Fetch all audit events for a specific entity.""" + rows = await pool.fetch(_FETCH_AUDIT_TRAIL_BY_ENTITY, entity_type, entity_id) + return [ + { + "id": str(row["id"]), + "event_type": row["event_type"], + "entity_type": row["entity_type"], + "entity_id": str(row["entity_id"]), + "actor": row["actor"], + "data": row["data"] if isinstance(row["data"], dict) else json.loads(row["data"]), + "created_at": row["created_at"].isoformat() if row["created_at"] else None, + } + for row in rows + ] diff --git a/services/shared/config.py b/services/shared/config.py index c5bdc81..d052ade 100644 --- a/services/shared/config.py +++ b/services/shared/config.py @@ -43,6 +43,10 @@ class OllamaConfig: base_url: str = "http://localhost:11434" model: str = "llama3.1:8b" timeout: int = 120 + max_retries: int = 2 + retry_base_delay: float = 1.0 + retry_max_delay: float = 10.0 + retry_backoff_multiplier: float = 2.0 @dataclass @@ -51,16 +55,82 @@ class TrinoConfig: port: int = 8080 catalog: str = "lakehouse" schema: str = "stonks" + iceberg_catalog: str = "iceberg" + + +@dataclass +class MarketDataConfig: + api_key: str = "" + base_url: str = "https://api.polygon.io" + provider: str = "polygon" @dataclass class BrokerConfig: mode: str = "paper" # paper | live + provider: str = "alpaca" api_key: Optional[str] = None api_secret: Optional[str] = None base_url: Optional[str] = None 
+@dataclass +class RetentionConfig: + """Default retention periods (days) per bucket class. + + These can be overridden per-bucket via the retention_policies DB table. + The cleanup_interval_hours controls how often the retention worker runs. + """ + raw_market_days: int = 90 + raw_news_days: int = 180 + raw_filings_days: int = 365 + normalized_days: int = 180 + llm_prompts_days: int = 365 + llm_results_days: int = 365 + lakehouse_days: int = 730 + audit_days: int = 730 + cleanup_interval_hours: int = 24 + batch_size: int = 1000 + + +# Map bucket names to RetentionConfig field names +BUCKET_RETENTION_FIELDS: dict[str, str] = { + "stonks-raw-market": "raw_market_days", + "stonks-raw-news": "raw_news_days", + "stonks-raw-filings": "raw_filings_days", + "stonks-normalized": "normalized_days", + "stonks-llm-prompts": "llm_prompts_days", + "stonks-llm-results": "llm_results_days", + "stonks-lakehouse": "lakehouse_days", + "stonks-audit": "audit_days", +} + + +@dataclass +class AlertingConfig: + """Thresholds for operational alerting rules. 
+ + Requirements: 12.3 + """ + # Source failure alerting + source_failure_threshold: int = 3 # consecutive failures before alert + source_failure_window_hours: int = 6 # lookback window + + # Schema/extraction failure spike + schema_failure_rate_threshold: float = 0.3 # 30% failure rate triggers alert + schema_failure_window_hours: int = 1 + + # Analytical (lake publication) lag + lake_lag_threshold_minutes: int = 60 # minutes since last successful publish + + # Broker issues + broker_error_threshold: int = 3 # consecutive broker errors + broker_error_window_hours: int = 1 + + # Evaluation interval + check_interval_seconds: int = 120 + + @dataclass class AppConfig: postgres: PostgresConfig = field(default_factory=PostgresConfig) @@ -68,8 +138,12 @@ class AppConfig: minio: MinioConfig = field(default_factory=MinioConfig) ollama: OllamaConfig = field(default_factory=OllamaConfig) trino: TrinoConfig = field(default_factory=TrinoConfig) + market_data: MarketDataConfig = field(default_factory=MarketDataConfig) broker: BrokerConfig = field(default_factory=BrokerConfig) + retention: RetentionConfig = field(default_factory=RetentionConfig) + alerting: AlertingConfig = field(default_factory=AlertingConfig) log_level: str = "INFO" + json_logs: bool = True def load_config() -> AppConfig: @@ -98,18 +172,52 @@ def load_config() -> AppConfig: base_url=os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"), model=os.getenv("OLLAMA_MODEL", "llama3.1:8b"), timeout=int(os.getenv("OLLAMA_TIMEOUT", "120")), + max_retries=int(os.getenv("OLLAMA_MAX_RETRIES", "2")), + retry_base_delay=float(os.getenv("OLLAMA_RETRY_BASE_DELAY", "1.0")), + retry_max_delay=float(os.getenv("OLLAMA_RETRY_MAX_DELAY", "10.0")), + retry_backoff_multiplier=float(os.getenv("OLLAMA_RETRY_BACKOFF_MULTIPLIER", "2.0")), ), trino=TrinoConfig( host=os.getenv("TRINO_HOST", "localhost"), port=int(os.getenv("TRINO_PORT", "8080")), catalog=os.getenv("TRINO_CATALOG", "lakehouse"), schema=os.getenv("TRINO_SCHEMA", "stonks"), 
+ iceberg_catalog=os.getenv("TRINO_ICEBERG_CATALOG", "iceberg"), + ), + market_data=MarketDataConfig( + api_key=os.getenv("MARKET_DATA_API_KEY", ""), + base_url=os.getenv("MARKET_DATA_BASE_URL", "https://api.polygon.io"), + provider=os.getenv("MARKET_DATA_PROVIDER", "polygon"), ), broker=BrokerConfig( mode=os.getenv("BROKER_MODE", "paper"), + provider=os.getenv("BROKER_PROVIDER", "alpaca"), api_key=os.getenv("BROKER_API_KEY", None), api_secret=os.getenv("BROKER_API_SECRET", None), base_url=os.getenv("BROKER_BASE_URL", None), ), + retention=RetentionConfig( + raw_market_days=int(os.getenv("RETENTION_RAW_MARKET_DAYS", "90")), + raw_news_days=int(os.getenv("RETENTION_RAW_NEWS_DAYS", "180")), + raw_filings_days=int(os.getenv("RETENTION_RAW_FILINGS_DAYS", "365")), + normalized_days=int(os.getenv("RETENTION_NORMALIZED_DAYS", "180")), + llm_prompts_days=int(os.getenv("RETENTION_LLM_PROMPTS_DAYS", "365")), + llm_results_days=int(os.getenv("RETENTION_LLM_RESULTS_DAYS", "365")), + lakehouse_days=int(os.getenv("RETENTION_LAKEHOUSE_DAYS", "730")), + audit_days=int(os.getenv("RETENTION_AUDIT_DAYS", "730")), + cleanup_interval_hours=int(os.getenv("RETENTION_CLEANUP_INTERVAL_HOURS", "24")), + batch_size=int(os.getenv("RETENTION_BATCH_SIZE", "1000")), + ), + alerting=AlertingConfig( + source_failure_threshold=int(os.getenv("ALERT_SOURCE_FAILURE_THRESHOLD", "3")), + source_failure_window_hours=int(os.getenv("ALERT_SOURCE_FAILURE_WINDOW_HOURS", "6")), + schema_failure_rate_threshold=float(os.getenv("ALERT_SCHEMA_FAILURE_RATE_THRESHOLD", "0.3")), + schema_failure_window_hours=int(os.getenv("ALERT_SCHEMA_FAILURE_WINDOW_HOURS", "1")), + lake_lag_threshold_minutes=int(os.getenv("ALERT_LAKE_LAG_THRESHOLD_MINUTES", "60")), + broker_error_threshold=int(os.getenv("ALERT_BROKER_ERROR_THRESHOLD", "3")), + broker_error_window_hours=int(os.getenv("ALERT_BROKER_ERROR_WINDOW_HOURS", "1")), + check_interval_seconds=int(os.getenv("ALERT_CHECK_INTERVAL_SECONDS", "120")), + ), 
def normalize_url(url: str) -> str:
    """Return the canonical form of *url* used for comparison and dedupe.

    Normalization steps:

    - Defaults the scheme to https if missing, including bare
      ``host/path`` inputs such as ``example.com/news``
    - Lowercases scheme and host
    - Strips fragments
    - Strips trailing slashes from path (preserves root "/")
    - Strips default ports (80, 443)
    - Sorts query parameters for deterministic comparison

    Args:
        url: Absolute, protocol-relative, or scheme-less URL.

    Returns:
        The normalized URL string.
    """
    parsed = urlparse(url)
    if not parsed.scheme and not parsed.netloc and url and not url.startswith("/"):
        # urlparse only recognizes a host after "//", so "example.com/a"
        # comes back as a bare path and the result would be
        # "https:///example.com/a". Re-parse with the default scheme so the
        # leading segment is treated as the host.
        parsed = urlparse(f"https://{url}")

    scheme = (parsed.scheme or "https").lower()
    netloc = (parsed.hostname or "").lower()
    if parsed.port and parsed.port not in (80, 443):
        netloc = f"{netloc}:{parsed.port}"
    path = parsed.path.rstrip("/") or "/"
    # Sort query params for deterministic ordering
    query = urlencode(sorted(parse_qsl(parsed.query)))

    normalized = f"{scheme}://{netloc}{path}"
    return f"{normalized}?{query}" if query else normalized
def wrap_dlq_entry(
    original_payload: dict[str, Any],
    queue_name: str,
    error: str,
    attempt: int = 1,
    worker: str = "",
) -> dict[str, Any]:
    """Build the dead-letter envelope around a failed job payload.

    The envelope carries the untouched payload plus the failure metadata
    (source queue, error text, attempt number, worker name) and the
    timestamp at which the job was dead-lettered.
    """
    envelope: dict[str, Any] = dict(
        original_payload=original_payload,
        queue=queue_name,
        error=error,
        attempt=attempt,
        worker=worker,
    )
    # Stamped last so the timestamp reflects when the envelope was built.
    envelope["dead_lettered_at"] = _now_iso()
    return envelope
async def replay_all(rds: aioredis.Redis, queue_name: str) -> int:
    """Replay the items currently in the DLQ back to the source queue.

    Only the entries present when the call starts are replayed. A job that
    fails again and is dead-lettered while the replay is running stays in
    the DLQ for the next call, so the loop always terminates.

    Entries that are not valid JSON are pushed back onto the DLQ and
    skipped rather than being dropped.

    Returns the number of items replayed.
    """
    source = dlq_key(queue_name)
    target = queue_key(queue_name)

    # Snapshot the depth up front: draining "until empty" never finishes
    # when workers keep dead-lettering the jobs we have just re-enqueued.
    pending = await rds.llen(source)

    count = 0
    for _ in range(pending):
        raw = await rds.lpop(source)
        if raw is None:
            # Someone else drained the DLQ concurrently.
            break
        try:
            entry = json.loads(raw)
        except json.JSONDecodeError:
            # The entry is already popped; put it back so it is not lost.
            await rds.rpush(source, raw)
            logger.warning(
                "Skipping malformed DLQ entry on %s; left in DLQ", queue_name
            )
            continue
        original = entry.get("original_payload", entry)
        await rds.rpush(target, json.dumps(original, default=str))
        count += 1

    if count:
        logger.info("Replayed %d jobs from DLQ back to %s", count, queue_name)
    return count
async def dlq_summary(rds: aioredis.Redis, queue_names: list[str]) -> dict[str, int]:
    """Report the dead-letter depth of each queue in *queue_names*.

    Queues are queried one after another, in the order given; the result
    maps each queue name to the length of its DLQ list.
    """
    return {name: await rds.llen(dlq_key(name)) for name in queue_names}
def _cached_document_id(value: Any) -> str:
    """Return a document id read from Redis as text.

    A Redis client that does not decode responses returns ``bytes``;
    calling ``str()`` on those yields ``"b'...'"`` instead of the id.
    """
    if isinstance(value, (bytes, bytearray)):
        return bytes(value).decode("utf-8")
    return str(value)


async def check_duplicate(
    pool: asyncpg.Pool,
    rds: aioredis.Redis,
    *,
    content_hash: str,
    url: str | None = None,
    canonical_url: str | None = None,
) -> DedupeResult:
    """Check whether a document is a duplicate across all source types.

    Checks in order of cost:
    1. Redis content_hash marker (fast path)
    2. Redis canonical_url marker (fast path)
    3. PostgreSQL documents.content_hash (durable)
    4. PostgreSQL documents.canonical_url (cross-source)

    Args:
        pool: PostgreSQL pool used for the durable lookups.
        rds: Redis client holding the dedupe markers.
        content_hash: Hash of the document content; skipped when empty.
        url: Raw document URL, normalized when *canonical_url* is absent.
        canonical_url: Already-normalized URL, preferred over *url*.

    Returns:
        A DedupeResult indicating whether the document already exists,
        with the existing document id and the kind of match on a hit.
    """
    # Resolve canonical URL if only raw URL provided
    resolved_canonical = canonical_url or (normalize_url(url) if url else None)

    # --- Redis fast path: content hash ---
    if content_hash:
        cached_id = await rds.get(_hash_dedupe_key(content_hash))
        if cached_id:
            logger.debug("Dedupe hit (redis content_hash) for %s", content_hash[:16])
            return DedupeResult(
                is_duplicate=True,
                existing_document_id=_cached_document_id(cached_id),
                match_type="content_hash",
            )

    # --- Redis fast path: canonical URL ---
    if resolved_canonical:
        cached_id = await rds.get(_url_dedupe_key(resolved_canonical))
        if cached_id:
            logger.debug(
                "Dedupe hit (redis canonical_url) for %s", resolved_canonical[:60]
            )
            return DedupeResult(
                is_duplicate=True,
                existing_document_id=_cached_document_id(cached_id),
                match_type="canonical_url",
            )

    # --- PostgreSQL fallback: content hash ---
    if content_hash:
        row = await pool.fetchrow(
            "SELECT id FROM documents WHERE content_hash = $1 LIMIT 1",
            content_hash,
        )
        if row:
            doc_id = str(row["id"])
            # Warm the Redis cache for future checks
            await _set_dedupe_markers(rds, content_hash, resolved_canonical, doc_id)
            logger.debug("Dedupe hit (pg content_hash) for %s", content_hash[:16])
            return DedupeResult(
                is_duplicate=True,
                existing_document_id=doc_id,
                match_type="content_hash",
            )

    # --- PostgreSQL fallback: canonical URL ---
    if resolved_canonical:
        row = await pool.fetchrow(
            "SELECT id FROM documents WHERE canonical_url = $1 LIMIT 1",
            resolved_canonical,
        )
        if row:
            doc_id = str(row["id"])
            await _set_dedupe_markers(rds, content_hash, resolved_canonical, doc_id)
            logger.debug(
                "Dedupe hit (pg canonical_url) for %s", resolved_canonical[:60]
            )
            return DedupeResult(
                is_duplicate=True,
                existing_document_id=doc_id,
                match_type="canonical_url",
            )

    return DedupeResult(is_duplicate=False)
async def dedupe_items(
    pool: asyncpg.Pool,
    rds: aioredis.Redis,
    items: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Split ingestion items into those not seen before and duplicates.

    Each item is looked up with ``check_duplicate`` using whichever of
    these fields it carries:
    - content_hash: hash of the raw content (empty string when absent)
    - url / link / canonical_url: the document URL

    Duplicate items are annotated in place with ``_dedupe_match_type``
    and ``_dedupe_existing_id`` before being returned.

    Returns (new_items, duplicate_items).
    """
    fresh: list[dict[str, Any]] = []
    duplicates: list[dict[str, Any]] = []

    for entry in items:
        verdict = await check_duplicate(
            pool,
            rds,
            content_hash=entry.get("content_hash", ""),
            url=entry.get("url") or entry.get("link"),
            canonical_url=entry.get("canonical_url"),
        )

        if not verdict.is_duplicate:
            fresh.append(entry)
            continue

        # Record why and against which document the item matched.
        entry["_dedupe_match_type"] = verdict.match_type
        entry["_dedupe_existing_id"] = verdict.existing_document_id
        duplicates.append(entry)

    return fresh, duplicates
class Span:
    """Context manager that times one operation and logs it on exit.

    Usage::

        with Span("process_document", ticker="AAPL") as span:
            # ... do work ...
            span.set_attribute("doc_count", 5)

    Entering the span makes its trace and span ids the current trace
    context; leaving it emits a ``span.end`` log record with the duration,
    status and attributes, then restores the previous context.
    """

    def __init__(self, operation: str, **attributes: Any) -> None:
        self.operation = operation
        self.attributes: dict[str, Any] = dict(attributes)

        # Identity: inherit the active trace, remember the enclosing span.
        self.parent_span_id = get_span_id()
        self.span_id = new_span_id()
        self.trace_id = get_trace_id() or new_trace_id()

        # Timing, filled in by __enter__ / __exit__.
        self.start_time: float = 0.0
        self.duration_ms: float = 0.0

        # Context-variable tokens needed to restore the parent context.
        self._token_trace: Any = None
        self._token_span: Any = None

        self._logger = logging.getLogger(get_service_name() or "tracing")

    def set_attribute(self, key: str, value: Any) -> None:
        """Attach or overwrite one attribute reported with the span."""
        self.attributes[key] = value

    def __enter__(self) -> "Span":
        self.start_time = time.monotonic()
        self._token_trace = _trace_id.set(self.trace_id)
        self._token_span = _span_id.set(self.span_id)
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        elapsed_seconds = time.monotonic() - self.start_time
        self.duration_ms = elapsed_seconds * 1000

        status = "ok"
        if exc_type:
            status = "error"

        # Logged before the context is restored so the record still
        # carries this span's ids.
        span_fields = {
            "span_operation": self.operation,
            "span_status": status,
            "span_duration_ms": round(self.duration_ms, 2),
            "span_parent_id": self.parent_span_id,
            "span_attributes": self.attributes,
        }
        self._logger.info("span.end", extra=span_fields)

        # Restore parent span context
        if self._token_span is not None:
            _span_id.reset(self._token_span)
        if self._token_trace is not None:
            _trace_id.reset(self._token_trace)
def setup_logging(
    service_name: str,
    level: str = "INFO",
    json_output: bool = True,
) -> None:
    """Configure structured logging for a service.

    Call this once at service startup (before any log calls). Existing
    root handlers are removed and replaced by a single stream handler.

    Args:
        service_name: Identifies this service in log output (e.g. "ingestion_worker").
        level: Log level string (DEBUG, INFO, WARNING, ERROR). Unknown
            names fall back to INFO.
        json_output: If True, emit JSON lines. If False, use a human-readable
            format that includes the current trace and span ids.
    """
    _service_name.set(service_name)

    # Only accept real numeric levels; other attributes of the logging
    # module (e.g. "BASIC_FORMAT") would otherwise reach setLevel().
    resolved_level = getattr(logging, level.upper(), None)
    if not isinstance(resolved_level, int):
        resolved_level = logging.INFO

    root = logging.getLogger()
    root.setLevel(resolved_level)

    # Remove existing handlers to avoid duplicate output
    root.handlers.clear()

    handler = logging.StreamHandler()
    if json_output:
        handler.setFormatter(JSONFormatter())
    else:

        def _attach_trace_context(record: logging.LogRecord) -> bool:
            # Static formatter defaults would print empty trace/span ids
            # forever; read the live context for every record instead.
            # Values passed explicitly through ``extra=`` are kept.
            fields = record.__dict__
            fields.setdefault("service", service_name)
            fields.setdefault("trace_id", get_trace_id())
            fields.setdefault("span_id", get_span_id())
            return True

        handler.addFilter(_attach_trace_context)
        handler.setFormatter(logging.Formatter(
            "%(asctime)s [%(levelname)s] %(name)s (%(service)s) "
            "trace=%(trace_id)s span=%(span_id)s — %(message)s",
        ))
    root.addHandler(handler)
async def persist_market_snapshot(
    pool: asyncpg.Pool,
    *,
    company_id: str | None,
    ticker: str,
    snapshot_type: str,
    data: dict[str, Any],
    source_provider: str,
    storage_ref: str,
    content_hash: str,
    captured_at: datetime | None = None,
) -> str:
    """Insert one market data snapshot row into ``market_snapshots``.

    The snapshot payload is stored as JSON; ``captured_at`` defaults to
    the current UTC time when not supplied.

    Returns the id of the inserted row as a string.
    """
    observed_at = captured_at or datetime.now(timezone.utc)
    payload_json = json.dumps(data)

    insert_sql = """INSERT INTO market_snapshots
           (company_id, ticker, snapshot_type, data, source_provider,
            captured_at, storage_ref, content_hash)
           VALUES ($1, $2, $3, $4::jsonb, $5, $6, $7, $8)
           RETURNING id"""
    snapshot_id = await pool.fetchval(
        insert_sql,
        company_id,
        ticker,
        snapshot_type,
        payload_json,
        source_provider,
        observed_at,
        storage_ref,
        content_hash,
    )

    logger.debug("Persisted market snapshot %s for %s", snapshot_id, ticker)
    return str(snapshot_id)
async def update_document_parse_results(
    pool: asyncpg.Pool,
    *,
    document_id: str,
    normalized_storage_ref: str | None,
    parser_output_ref: str | None,
    parse_quality_score: float,
    parse_confidence: str,
    status: str,
) -> None:
    """Record the outcome of the parsing stage on a document row.

    Writes the normalized-text location, the parser-output location, the
    quality score, the confidence label and the new status, and bumps
    ``updated_at`` to the database's current time.

    Requirements: 4.1, 4.3, 9.1
    """
    update_sql = """UPDATE documents SET
               normalized_storage_ref = $2,
               parser_output_ref = $3,
               parse_quality_score = $4,
               parse_confidence = $5,
               status = $6,
               updated_at = NOW()
           WHERE id = $1"""
    bind_values = (
        document_id,
        normalized_storage_ref,
        parser_output_ref,
        parse_quality_score,
        parse_confidence,
        status,
    )
    await pool.execute(update_sql, *bind_values)

    logger.debug(
        "Updated document %s parse results: quality=%.2f confidence=%s status=%s",
        document_id, parse_quality_score, parse_confidence, status,
    )
+ """ + mention_id = await pool.fetchval( + """INSERT INTO document_company_mentions + (document_id, company_id, ticker, mention_type, confidence) + VALUES ($1::uuid, $2::uuid, $3, $4, $5) + RETURNING id""", + document_id, + company_id, + ticker, + mention_type, + confidence, + ) + return str(mention_id) + + +async def persist_broker_event( + pool: asyncpg.Pool, + *, + ticker: str, + event_type: str, + data: dict[str, Any], + source_provider: str, + storage_ref: str, + content_hash: str, + captured_at: datetime | None = None, +) -> str: + """Persist a broker event snapshot to market_snapshots. + + Broker position/account snapshots are stored as market_snapshots + with snapshot_type prefixed by 'broker_' (e.g. broker_positions, + broker_account, broker_orders). + + Returns the snapshot row UUID. + """ + ts = captured_at or datetime.now(timezone.utc) + row_id = await pool.fetchval( + """INSERT INTO market_snapshots + (ticker, snapshot_type, data, source_provider, + captured_at, storage_ref, content_hash) + VALUES ($1, $2, $3::jsonb, $4, $5, $6, $7) + RETURNING id""", + ticker, + f"broker_{event_type}", + json.dumps(data), + source_provider, + ts, + storage_ref, + content_hash, + ) + logger.debug("Persisted broker event %s for %s", row_id, ticker) + return str(row_id) + + +def _resolve_document_type(source_type: str) -> str: + """Map source_type to a document_type value.""" + mapping = { + "news_api": "article", + "filings_api": "filing", + "web_scrape": "press_release", + } + return mapping.get(source_type, "article") + + +def _extract_publisher(item: dict[str, Any]) -> str: + """Extract publisher name from an adapter item dict.""" + if item.get("publisher"): + return str(item["publisher"]) + source = item.get("source") + if isinstance(source, dict): + return source.get("name", "") + if source: + return str(source) + return "" + + +def _parse_published_at(item: dict[str, Any]) -> datetime | None: + """Parse published_at from various adapter item formats.""" + raw = 
async def persist_ingestion_items(
    pool: asyncpg.Pool,
    *,
    source_type: str,
    ticker: str,
    company_id: str | None,
    items: list[dict[str, Any]],
    storage_ref: str,
    adapter_metadata: dict[str, Any],
    content_hash: str,
) -> tuple[int, list[str]]:
    """Dispatch ingestion items to the persistence path for *source_type*.

    - ``market_api`` items go to the market snapshot path.
    - ``broker`` items go to the broker snapshot path.
    - Every other source type is treated as a document source.

    Returns (new_item_count, list_of_new_ids).
    """
    if source_type in ("market_api", "broker"):
        # Both snapshot paths are tagged with the adapter's provider name.
        provider = adapter_metadata.get("provider", "unknown")

        if source_type == "market_api":
            return await _persist_market_items(
                pool,
                ticker=ticker,
                company_id=company_id,
                items=items,
                storage_ref=storage_ref,
                provider=provider,
                content_hash=content_hash,
            )

        return await _persist_broker_items(
            pool,
            ticker=ticker,
            items=items,
            storage_ref=storage_ref,
            provider=provider,
            endpoint=adapter_metadata.get("endpoint", "positions"),
            content_hash=content_hash,
        )

    # Document types: news_api, filings_api, web_scrape
    return await _persist_document_items(
        pool,
        source_type=source_type,
        ticker=ticker,
        company_id=company_id,
        items=items,
        storage_ref=storage_ref,
    )
persist_market_snapshot( + pool, + company_id=company_id, + ticker=ticker, + snapshot_type=snapshot_type, + data=item, + source_provider=provider, + storage_ref=storage_ref, + content_hash=item_hash, + ) + ids.append(row_id) + return len(ids), ids + + +def _infer_market_snapshot_type(item: dict[str, Any]) -> str: + """Infer snapshot_type from market data item fields.""" + # Polygon aggregate bars have 'o', 'h', 'l', 'c' fields + if all(k in item for k in ("o", "h", "l", "c")): + return "bar" + # Ticker details have 'market_cap' or 'sic_code' + if "market_cap" in item or "sic_code" in item: + return "ticker_details" + # Quote snapshots + if "ask" in item or "bid" in item: + return "quote" + return "snapshot" + + +async def _persist_broker_items( + pool: asyncpg.Pool, + *, + ticker: str, + items: list[dict[str, Any]], + storage_ref: str, + provider: str, + endpoint: str, + content_hash: str, +) -> tuple[int, list[str]]: + """Persist broker fetch items as market_snapshots with broker_ prefix.""" + ids: list[str] = [] + for item in items: + item_hash = content_hash_str(json.dumps(item, sort_keys=True)) + exists = await pool.fetchval( + "SELECT 1 FROM market_snapshots WHERE content_hash = $1", item_hash + ) + if exists: + continue + + row_id = await persist_broker_event( + pool, + ticker=ticker, + event_type=endpoint, + data=item, + source_provider=provider, + storage_ref=storage_ref, + content_hash=item_hash, + ) + ids.append(row_id) + return len(ids), ids + + +async def _persist_document_items( + pool: asyncpg.Pool, + *, + source_type: str, + ticker: str, + company_id: str | None, + items: list[dict[str, Any]], + storage_ref: str, +) -> tuple[int, list[str]]: + """Persist document items (news, filings, web scrape) to documents table.""" + doc_type = _resolve_document_type(source_type) + ids: list[str] = [] + + for item in items: + item_hash = item.get("content_hash") or content_hash_str( + json.dumps(item, sort_keys=True) + ) + title = item.get("title", 
def compute_next_retry_at(
    retry_count: int,
    now: datetime | None = None,
    base: int = RETRY_BACKOFF_BASE,
    cap: int = RETRY_BACKOFF_MAX,
) -> datetime:
    """Return when a failed source may next be retried.

    The delay doubles with each retry (``base * 2 ** retry_count``), with
    the exponent limited to 8 and the resulting delay limited to *cap*.

    Args:
        retry_count: Current retry count (before incrementing).
        now: Reference timestamp (defaults to UTC now).
        base: Base delay in seconds.
        cap: Maximum delay in seconds.

    Returns:
        Datetime of the next eligible retry.
    """
    reference = now or datetime.now(timezone.utc)
    # Clamp the exponent first so the power stays small for large counts.
    exponent = min(retry_count, 8)
    backoff_seconds = min(base * (2 ** exponent), cap)
    return reference + timedelta(seconds=backoff_seconds)
async def record_retrieval_failure(
    pool: asyncpg.Pool,
    run_id: str,
    source_id: str,
    error_message: str,
    retry_count: int | None = None,
    now: datetime | None = None,
) -> dict[str, Any]:
    """Record a source retrieval failure with retry policy state.

    Updates the ingestion_runs row identified by ``run_id`` with:
    - status: 'failed'
    - error_message: the failure reason
    - retry_count: the previous count plus one
    - next_retry_at: computed via exponential backoff from the previous count
    - completed_at: the reference timestamp

    If ``retry_count`` is not provided, the previous count is taken from the
    most recent *other* run of the same source: its ``retry_count`` when that
    run failed, otherwise 0. If ``retry_count`` is provided it is used as the
    previous count.

    Args:
        pool: asyncpg connection pool.
        run_id: Id of the ingestion run that failed.
        source_id: Id of the source the run belongs to.
        error_message: Human-readable failure reason.
        retry_count: Previous retry count, or None to look it up.
        now: Reference timestamp (defaults to UTC now).

    Returns:
        A dict with the recorded retry state for observability: ``run_id``,
        ``source_id``, ``retry_count``, ``next_retry_at`` (ISO string),
        ``exhausted`` and ``error_message``.

    Requirement 3.4
    """
    ts = now if now is not None else datetime.now(timezone.utc)

    if retry_count is None:
        # The row for run_id already exists (it is UPDATEd below) and is
        # normally the newest run of this source. It is not marked 'failed'
        # yet, so looking at "the most recent run" without excluding it would
        # always yield 0 and the backoff would never grow.
        prev = await pool.fetchrow(
            """SELECT status, retry_count
               FROM ingestion_runs
               WHERE source_id = $1::uuid
                 AND id <> $2
               ORDER BY started_at DESC
               LIMIT 1""",
            source_id,
            run_id,
        )
        if prev and prev["status"] == "failed":
            prev_count = prev["retry_count"] or 0
        else:
            prev_count = 0
    else:
        prev_count = retry_count

    retry_count = prev_count + 1

    # Backoff is derived from the count *before* this failure, so the first
    # failure waits exactly the base delay.
    next_retry = compute_next_retry_at(prev_count, now=ts)
    exhausted = retry_count >= RETRY_MAX_COUNT

    await pool.execute(
        """UPDATE ingestion_runs
           SET status = 'failed',
               error_message = $2,
               retry_count = $3,
               next_retry_at = $4,
               completed_at = $5
           WHERE id = $1""",
        run_id,
        error_message,
        retry_count,
        next_retry,
        ts,
    )

    state = {
        "run_id": run_id,
        "source_id": source_id,
        "retry_count": retry_count,
        "next_retry_at": next_retry.isoformat(),
        "exhausted": exhausted,
        "error_message": error_message,
    }

    if exhausted:
        logger.warning(
            "Source %s exhausted retries (%d/%d): %s",
            source_id, retry_count, RETRY_MAX_COUNT, error_message,
        )
    else:
        logger.info(
            "Source %s failed (retry %d/%d), next retry at %s: %s",
            source_id, retry_count, RETRY_MAX_COUNT,
            next_retry.isoformat(), error_message,
        )

    return state
async def reset_source_retry_state(
    pool: asyncpg.Pool,
    source_id: str,
) -> None:
    """Clear accumulated backoff on a source's latest ingestion run.

    Runs one UPDATE against ``ingestion_runs`` that sets ``retry_count`` to 0
    and ``next_retry_at`` to NULL on the row with the greatest ``started_at``
    for ``source_id``. Intended to be called after a successful ingestion.

    Args:
        pool: asyncpg connection pool.
        source_id: Id of the source whose latest run is reset.
    """
    # The sub-select picks exactly one row: the newest run of this source.
    reset_latest_run_sql = """UPDATE ingestion_runs
           SET retry_count = 0, next_retry_at = NULL
           WHERE id = (
               SELECT id FROM ingestion_runs
               WHERE source_id = $1::uuid
               ORDER BY started_at DESC
               LIMIT 1
           )"""
    await pool.execute(reset_latest_run_sql, source_id)
# ---------------------------------------------------------------------------
# Ingestion metrics
# ---------------------------------------------------------------------------

# Job outcomes, split by source type and final status.
INGESTION_JOBS_TOTAL = Counter(
    name="stonks_ingestion_jobs_total",
    documentation="Total ingestion jobs processed",
    labelnames=["source_type", "status"],
)

# Item-level counters; all three share the single ``source_type`` label.
INGESTION_ITEMS_FETCHED = Counter(
    name="stonks_ingestion_items_fetched_total",
    documentation="Total items fetched from external sources",
    labelnames=["source_type"],
)

INGESTION_ITEMS_NEW = Counter(
    name="stonks_ingestion_items_new_total",
    documentation="New (non-duplicate) items ingested",
    labelnames=["source_type"],
)

INGESTION_ITEMS_DEDUPED = Counter(
    name="stonks_ingestion_items_deduped_total",
    documentation="Items skipped due to deduplication",
    labelnames=["source_type"],
)

INGESTION_ERRORS = Counter(
    name="stonks_ingestion_errors_total",
    documentation="Ingestion errors by source type",
    labelnames=["source_type"],
)

# Fetch latency; bucket bounds are in seconds (0.1 s up to 60 s).
INGESTION_ADAPTER_DURATION = Histogram(
    name="stonks_ingestion_adapter_duration_seconds",
    documentation="Adapter fetch latency in seconds",
    labelnames=["source_type"],
    buckets=(0.1, 0.5, 1, 2, 5, 10, 30, 60),
)
# ---------------------------------------------------------------------------
# Parsing metrics
# ---------------------------------------------------------------------------

# Parse job outcomes by status.
PARSE_JOBS_TOTAL = Counter(
    name="stonks_parse_jobs_total",
    documentation="Total parse jobs processed",
    labelnames=["status"],
)

# Unlabelled score distribution; buckets step by 0.1 up to 1.0.
PARSE_QUALITY_SCORE = Histogram(
    name="stonks_parse_quality_score",
    documentation="Distribution of parser quality scores",
    buckets=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0),
)

PARSE_LOW_QUALITY_TOTAL = Counter(
    name="stonks_parse_low_quality_total",
    documentation="Documents flagged as low quality by the parser",
)

# Bucket bounds are in seconds (50 ms up to 10 s).
PARSE_DURATION = Histogram(
    name="stonks_parse_duration_seconds",
    documentation="Parse job duration in seconds",
    buckets=(0.05, 0.1, 0.25, 0.5, 1, 2, 5, 10),
)
# ---------------------------------------------------------------------------
# Aggregation metrics
# ---------------------------------------------------------------------------

# Both counters are keyed by the ``window`` label.
AGGREGATION_WINDOWS_COMPUTED = Counter(
    name="stonks_aggregation_windows_total",
    documentation="Trend windows computed",
    labelnames=["window"],
)

AGGREGATION_SIGNALS_PROCESSED = Counter(
    name="stonks_aggregation_signals_total",
    documentation="Signals processed during aggregation",
    labelnames=["window"],
)

# Unlabelled; buckets are finer near 0 and coarser towards 1.0.
AGGREGATION_CONTRADICTION_SCORE = Histogram(
    name="stonks_aggregation_contradiction_score",
    documentation="Distribution of contradiction scores in trend windows",
    buckets=(0.0, 0.05, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0),
)

# Bucket bounds are in seconds (50 ms up to 10 s).
AGGREGATION_DURATION = Histogram(
    name="stonks_aggregation_duration_seconds",
    documentation="Aggregation job duration in seconds",
    labelnames=["window"],
    buckets=(0.05, 0.1, 0.25, 0.5, 1, 2, 5, 10),
)
"stonks_lake_publish_errors_total", + "Lake publication errors", + ["table_name"], +) + +LAKE_PUBLISH_BYTES = Counter( + "stonks_lake_publish_bytes_total", + "Total bytes written to the lakehouse", + ["table_name"], +) + +# --------------------------------------------------------------------------- +# Trading / broker metrics +# --------------------------------------------------------------------------- + +ORDERS_SUBMITTED = Counter( + "stonks_orders_submitted_total", + "Orders submitted to broker", + ["side", "order_type", "mode"], +) + +ORDERS_REJECTED = Counter( + "stonks_orders_rejected_total", + "Orders rejected before broker submission", + ["reason_category"], +) + +ORDERS_FILLED = Counter( + "stonks_orders_filled_total", + "Orders filled by broker", + ["side"], +) + +ORDERS_DUPLICATES_PREVENTED = Counter( + "stonks_orders_duplicates_prevented_total", + "Duplicate orders prevented by idempotency checks", + ["detected_via"], +) + +RISK_EVALUATIONS_TOTAL = Counter( + "stonks_risk_evaluations_total", + "Risk evaluations performed", + ["result"], +) + +RISK_CHECK_FAILURES = Counter( + "stonks_risk_check_failures_total", + "Individual risk check failures", + ["check_name"], +) + +POSITIONS_SYNCED = Counter( + "stonks_positions_synced_total", + "Position sync operations completed", +) + +# --------------------------------------------------------------------------- +# Active gauges +# --------------------------------------------------------------------------- + +ACTIVE_JOBS = Gauge( + "stonks_active_jobs", + "Currently processing jobs by stage", + ["stage"], +) + +# --------------------------------------------------------------------------- +# Alerting metrics +# --------------------------------------------------------------------------- + +ALERTS_FIRED = Counter( + "stonks_alerts_fired_total", + "Total alerts fired by rule", + ["rule", "severity"], +) + +ALERTS_RESOLVED = Counter( + "stonks_alerts_resolved_total", + "Total alerts resolved by rule", + ["rule"], +) + 
# ---------------------------------------------------------------------------
# Dead-letter queue metrics
# ---------------------------------------------------------------------------

# All three metrics are keyed by the ``queue`` label.
DLQ_ITEMS_TOTAL = Counter(
    name="stonks_dlq_items_total",
    documentation="Jobs sent to dead-letter queues",
    labelnames=["queue"],
)

DLQ_REPLAYED_TOTAL = Counter(
    name="stonks_dlq_replayed_total",
    documentation="Jobs replayed from dead-letter queues",
    labelnames=["queue"],
)

# A gauge rather than a counter: depth can go down as well as up.
DLQ_DEPTH = Gauge(
    name="stonks_dlq_depth",
    documentation="Current dead-letter queue depth",
    labelnames=["queue"],
)
def default_retention_days(bucket: str, config: RetentionConfig) -> int:
    """Look up the configured retention period (in days) for a bucket.

    ``BUCKET_RETENTION_FIELDS`` maps a bucket name to the name of the
    attribute on ``config`` that holds its retention days.

    Args:
        bucket: Bucket name to look up.
        config: Retention configuration object the attribute is read from.

    Returns:
        The attribute's value, or 365 when the bucket has no mapping entry
        (or an empty one) or ``config`` lacks the mapped attribute.
    """
    attr_name = BUCKET_RETENTION_FIELDS.get(bucket)
    if not attr_name:
        # Unknown bucket: fall back to one year.
        return 365
    return getattr(config, attr_name, 365)
+ """ + rows = await pool.fetch( + """SELECT bucket_name, retention_days, archive_before_delete + FROM retention_policies + WHERE active = TRUE AND artifact_class = 'default'""" + ) + return { + row["bucket_name"]: RetentionPolicy( + bucket_name=row["bucket_name"], + retention_days=row["retention_days"], + archive_before_delete=row["archive_before_delete"], + ) + for row in rows + } + + +def merge_policies( + config_policies: list[RetentionPolicy], + db_policies: dict[str, RetentionPolicy], +) -> list[RetentionPolicy]: + """Merge config defaults with DB overrides. DB wins on conflict.""" + merged: list[RetentionPolicy] = [] + for policy in config_policies: + if policy.bucket_name in db_policies: + merged.append(db_policies[policy.bucket_name]) + else: + merged.append(policy) + return merged + + +def cutoff_date(retention_days: int, now: datetime | None = None) -> datetime: + """Calculate the cutoff datetime. Objects older than this are expired.""" + ref = now or datetime.now(timezone.utc) + return ref - timedelta(days=retention_days) + + +def list_expired_objects( + client: Minio, + bucket: str, + retention_days: int, + batch_size: int = 1000, + now: datetime | None = None, +) -> list[str]: + """List object names in a bucket that are older than the retention cutoff. + + Uses the object's last_modified timestamp from MinIO metadata. + Returns at most batch_size object names. + """ + cutoff = cutoff_date(retention_days, now) + expired: list[str] = [] + + try: + objects = client.list_objects(bucket, recursive=True) + for obj in objects: + if obj.last_modified and obj.last_modified < cutoff: + if obj.object_name: + expired.append(obj.object_name) + if len(expired) >= batch_size: + break + except Exception: + logger.exception("Error listing objects in bucket %s", bucket) + + return expired + + +def delete_expired_objects( + client: Minio, + bucket: str, + object_names: list[str], +) -> int: + """Delete a list of objects from a MinIO bucket. 
def cleanup_bucket(
    client: Minio,
    policy: RetentionPolicy,
    batch_size: int = 1000,
    now: datetime | None = None,
) -> CleanupResult:
    """Apply one retention policy to its bucket.

    Asks ``list_expired_objects`` for up to ``batch_size`` expired object
    names and hands them to ``delete_expired_objects``.

    Args:
        client: MinIO client instance.
        policy: Bucket name and retention period to enforce.
        batch_size: Upper bound on objects handled in this call.
        now: Reference timestamp forwarded to the expiry calculation.

    Returns:
        A CleanupResult for the bucket with ``objects_scanned`` and
        ``objects_deleted`` filled in.
    """
    bucket = policy.bucket_name
    days = policy.retention_days

    expired_names = list_expired_objects(
        client, bucket, days, batch_size=batch_size, now=now,
    )

    # NOTE(review): objects_scanned is the number of *expired* names returned,
    # not the number of objects inspected in the bucket.
    outcome = CleanupResult(bucket_name=bucket)
    outcome.objects_scanned = len(expired_names)

    if not expired_names:
        logger.debug("Bucket %s: no expired objects (retention=%dd)",
                     bucket, days)
        return outcome

    outcome.objects_deleted = delete_expired_objects(client, bucket, expired_names)
    logger.info(
        "Bucket %s: scanned=%d deleted=%d (retention=%dd)",
        bucket, outcome.objects_scanned,
        outcome.objects_deleted, days,
    )
    return outcome
async def run_retention_cleanup(
    minio_client: Minio,
    pool: asyncpg.Pool,
    config: RetentionConfig,
    now: datetime | None = None,
) -> list[CleanupResult]:
    """Run the full retention cleanup cycle.

    1. Resolve policies from config defaults + DB overrides
    2. Clean up expired MinIO objects per bucket
    3. Record each bucket run for observability
    4. Clean up expired PostgreSQL metadata

    The cycle is best-effort: a failure in one bucket, in recording a run, or
    in the DB cleanup is logged and does not stop the remaining work.

    Args:
        minio_client: MinIO client instance.
        pool: asyncpg connection pool.
        config: Retention configuration (defaults and ``batch_size``).
        now: Reference timestamp for expiry (defaults to UTC now downstream).

    Returns:
        Exactly one CleanupResult per policy, in policy order. A bucket whose
        cleanup raised is represented by an all-zero result.
    """
    # Resolve policies
    config_policies = resolve_policies(config)
    try:
        db_policies = await load_db_policies(pool)
    except Exception:
        # Keep the traceback: without it the reason for falling back is lost.
        logger.warning(
            "Could not load DB retention policies, using config defaults",
            exc_info=True,
        )
        db_policies = {}

    policies = merge_policies(config_policies, db_policies)
    results: list[CleanupResult] = []

    # Clean up MinIO objects per bucket
    for policy in policies:
        status = "completed"
        error_message: str | None = None
        try:
            result = cleanup_bucket(
                minio_client, policy,
                batch_size=config.batch_size, now=now,
            )
        except Exception:
            logger.exception("Retention cleanup failed for bucket %s", policy.bucket_name)
            result = CleanupResult(bucket_name=policy.bucket_name)
            status, error_message = "failed", "See logs"
        results.append(result)

        # Recording is kept outside the cleanup try-block on purpose: a
        # failure to write the audit row must neither relabel a successful
        # cleanup as failed (and add a second result for the bucket) nor
        # abort the remaining buckets.
        try:
            await record_retention_run(
                pool, policy.bucket_name, result,
                status=status, error_message=error_message,
            )
        except Exception:
            logger.exception(
                "Could not record retention run for bucket %s", policy.bucket_name,
            )

    if not policies:
        # Nothing to derive a DB cutoff from; min() of an empty sequence raises.
        return results

    # Clean up expired DB records using the shortest raw retention period
    min_retention = min(p.retention_days for p in policies)
    try:
        db_deleted = await cleanup_expired_db_records(pool, min_retention, now=now)
        if db_deleted > 0:
            logger.info("Total DB rows cleaned: %d", db_deleted)
    except Exception:
        logger.exception("DB retention cleanup failed")

    return results
class DisagreementDetail(BaseModel):
    """Represents an explicit disagreement between document signals.

    Rather than collapsing contradictory signals into a single score,
    this captures the nature of the disagreement so downstream consumers
    can inspect *why* signals conflict.

    Every field has a default, so an instance can be created without
    arguments and filled in incrementally.

    Requirements: 6.4
    """

    # Which aspect the signals disagree on.
    dimension: str = ""  # e.g. "sentiment", "catalyst", "impact_horizon"
    # Document ids on each side of the disagreement. Each instance gets its
    # own fresh list via default_factory.
    # NOTE(review): presumably ids of the documents carrying the positive /
    # negative signal — confirm against the aggregation code that fills them.
    positive_doc_ids: List[str] = Field(default_factory=list)
    negative_doc_ids: List[str] = Field(default_factory=list)
    # Weight attributed to each side; no range constraint is declared here.
    positive_weight: float = 0.0
    negative_weight: float = 0.0
    # Free-text description of the disagreement.
    description: str = ""
# Map source_type to the correct raw bucket
SOURCE_BUCKET_MAP: dict[str, str] = {
    "market_api": "stonks-raw-market",
    "news_api": "stonks-raw-news",
    "filings_api": "stonks-raw-filings",
    "web_scrape": "stonks-raw-news",
    "broker": "stonks-raw-market",
}


def bucket_for_source(source_type: str) -> str:
    """Resolve the raw-artifact bucket for a source type.

    Args:
        source_type: Key into ``SOURCE_BUCKET_MAP`` (e.g. ``"news_api"``).

    Returns:
        The mapped bucket name, or ``"stonks-raw-market"`` when the source
        type has no entry in the map.
    """
    try:
        return SOURCE_BUCKET_MAP[source_type]
    except KeyError:
        # Unknown source types fall back to the market bucket.
        return "stonks-raw-market"
def ensure_buckets(client: Minio, buckets: list[str] | None = None) -> list[str]:
    """Create any missing buckets.

    Args:
        client: MinIO client instance.
        buckets: Bucket names to ensure. ``None`` means every bucket in
            ``ALL_BUCKETS``; an empty list means "nothing to do".

    Returns:
        The names of the buckets that were created by this call, in the
        order they were processed. Buckets that already existed are omitted.

    Raises:
        S3Error: If checking or creating a bucket fails. The error is logged
            and re-raised; buckets after the failing one are not processed.
    """
    # Test for None explicitly: ``buckets or ALL_BUCKETS`` would turn an
    # explicitly empty list into "all buckets".
    target_buckets = ALL_BUCKETS if buckets is None else buckets
    created: list[str] = []
    for bucket in target_buckets:
        try:
            if not client.bucket_exists(bucket):
                client.make_bucket(bucket)
                created.append(bucket)
                logger.info("Created bucket: %s", bucket)
        except S3Error as e:
            logger.error("Failed to ensure bucket %s: %s", bucket, e)
            raise
    return created
def upload_raw_artifact(
    client: Minio,
    source_type: str,
    ticker: str,
    document_id: str,
    data: bytes,
    artifact_type: str = "raw_json",
    timestamp: datetime | None = None,
    metadata: Mapping[str, str] | None = None,
) -> str:
    """Store a raw payload using the standard bucket, path and content type.

    The bucket comes from ``bucket_for_source``, the content type and file
    extension from ``ARTIFACT_CONTENT_TYPES``, and the object path from
    ``build_artifact_path`` with the artifact name fixed to ``"raw"``.

    Args:
        client: MinIO client instance.
        source_type: One of market_api, news_api, filings_api, web_scrape, broker.
        ticker: Company ticker symbol.
        document_id: Unique document or run identifier.
        data: Raw bytes to upload.
        artifact_type: One of raw_json, raw_html, raw_text, raw_payload.
            Any other value is stored as ``application/octet-stream`` with a
            ``.bin`` extension.
        timestamp: Override timestamp for path generation (defaults to now UTC).
        metadata: Optional user metadata dict.

    Returns:
        s3:// URI pointing to the uploaded object.
    """
    fallback = ("application/octet-stream", "bin")
    mime_type, extension = ARTIFACT_CONTENT_TYPES.get(artifact_type, fallback)

    target_bucket = bucket_for_source(source_type)
    object_path = build_artifact_path(
        source_type=source_type,
        ticker=ticker,
        document_id=document_id,
        artifact_name="raw",
        ext=extension,
        timestamp=timestamp,
    )

    return upload_artifact(
        client,
        target_bucket,
        object_path,
        data,
        content_type=mime_type,
        metadata=metadata,
    )
+ """ + bucket = bucket_for_source("web_scrape") + path = build_artifact_path( + source_type="web_scrape", + ticker=ticker, + document_id=document_id, + artifact_name="raw", + ext="html", + timestamp=timestamp, + ) + return upload_artifact(client, bucket, path, html_bytes, content_type="text/html", metadata=metadata) + + +def upload_normalized_text( + client: Minio, + ticker: str, + document_id: str, + text_bytes: bytes, + timestamp: datetime | None = None, + metadata: Mapping[str, str] | None = None, +) -> str: + """Upload normalized (parsed) text to the stonks-normalized bucket. + + Stores under parsed/{ticker}/{yyyy}/{mm}/{dd}/{document_id}/normalized.txt + """ + ts = timestamp or datetime.now(timezone.utc) + path = ( + f"parsed/{ticker}/{ts.year}/{ts.month:02d}/{ts.day:02d}/" + f"{document_id}/normalized.txt" + ) + return upload_artifact( + client, "stonks-normalized", path, text_bytes, + content_type="text/plain", metadata=metadata, + ) + + +def upload_parser_output( + client: Minio, + ticker: str, + document_id: str, + output_bytes: bytes, + timestamp: datetime | None = None, + metadata: Mapping[str, str] | None = None, +) -> str: + """Upload structured parser output JSON to the stonks-normalized bucket. + + Stores under parsed/{ticker}/{yyyy}/{mm}/{dd}/{document_id}/parser_output.json + """ + ts = timestamp or datetime.now(timezone.utc) + path = ( + f"parsed/{ticker}/{ts.year}/{ts.month:02d}/{ts.day:02d}/" + f"{document_id}/parser_output.json" + ) + return upload_artifact( + client, "stonks-normalized", path, output_bytes, + content_type="application/json", metadata=metadata, + ) + + +def upload_extraction_prompt( + client: Minio, + ticker: str, + document_id: str, + prompt_data: bytes, + timestamp: datetime | None = None, + metadata: Mapping[str, str] | None = None, +) -> str: + """Upload the extraction prompt and schema to stonks-llm-prompts. 
+ + Stores under extraction/{ticker}/{yyyy}/{mm}/{dd}/{document_id}/prompt.json + """ + ts = timestamp or datetime.now(timezone.utc) + path = ( + f"extraction/{ticker}/{ts.year}/{ts.month:02d}/{ts.day:02d}/" + f"{document_id}/prompt.json" + ) + return upload_artifact( + client, "stonks-llm-prompts", path, prompt_data, + content_type="application/json", metadata=metadata, + ) + + +def upload_extraction_raw_output( + client: Minio, + ticker: str, + document_id: str, + output_data: bytes, + attempt_index: int = 0, + timestamp: datetime | None = None, + metadata: Mapping[str, str] | None = None, +) -> str: + """Upload a raw model output to stonks-llm-results. + + Stores under extraction/{ticker}/{yyyy}/{mm}/{dd}/{document_id}/raw_output_{attempt}.json + """ + ts = timestamp or datetime.now(timezone.utc) + path = ( + f"extraction/{ticker}/{ts.year}/{ts.month:02d}/{ts.day:02d}/" + f"{document_id}/raw_output_{attempt_index}.json" + ) + return upload_artifact( + client, "stonks-llm-results", path, output_data, + content_type="application/json", metadata=metadata, + ) + + +def upload_extraction_validation( + client: Minio, + ticker: str, + document_id: str, + validation_data: bytes, + timestamp: datetime | None = None, + metadata: Mapping[str, str] | None = None, +) -> str: + """Upload a validation report to stonks-llm-results. 
+ + Stores under extraction/{ticker}/{yyyy}/{mm}/{dd}/{document_id}/validation.json + """ + ts = timestamp or datetime.now(timezone.utc) + path = ( + f"extraction/{ticker}/{ts.year}/{ts.month:02d}/{ts.day:02d}/" + f"{document_id}/validation.json" + ) + return upload_artifact( + client, "stonks-llm-results", path, validation_data, + content_type="application/json", metadata=metadata, + ) + + +def upload_extraction_intelligence( + client: Minio, + ticker: str, + document_id: str, + intelligence_data: bytes, + timestamp: datetime | None = None, + metadata: Mapping[str, str] | None = None, +) -> str: + """Upload the final intelligence object to stonks-llm-results. + + Stores under extraction/{ticker}/{yyyy}/{mm}/{dd}/{document_id}/intelligence.json + """ + ts = timestamp or datetime.now(timezone.utc) + path = ( + f"extraction/{ticker}/{ts.year}/{ts.month:02d}/{ts.day:02d}/" + f"{document_id}/intelligence.json" + ) + return upload_artifact( + client, "stonks-llm-results", path, intelligence_data, + content_type="application/json", metadata=metadata, + ) + + +def download_artifact(client: Minio, bucket: str, path: str) -> bytes: + """Download an artifact from MinIO and return its bytes.""" + response = client.get_object(bucket, path) + try: + return response.read() + finally: + response.close() + response.release_conn() diff --git a/services/symbol_registry/app.py b/services/symbol_registry/app.py index 292a4b6..637b564 100644 --- a/services/symbol_registry/app.py +++ b/services/symbol_registry/app.py @@ -10,6 +10,7 @@ from pydantic import BaseModel, field_validator from services.shared.config import load_config from services.shared.db import get_pg_pool +from services.shared.logging import setup_logging config = load_config() pool: Optional[asyncpg.Pool] = None @@ -18,6 +19,7 @@ pool: Optional[asyncpg.Pool] = None @asynccontextmanager async def lifespan(app: FastAPI): global pool + setup_logging("symbol_registry", level=config.log_level, json_output=config.json_logs) 
pool = await get_pg_pool(config) yield await pool.close() diff --git a/services/symbol_registry/seed.py b/services/symbol_registry/seed.py index e5044e9..2d173ce 100644 --- a/services/symbol_registry/seed.py +++ b/services/symbol_registry/seed.py @@ -13,8 +13,8 @@ import asyncpg from services.shared.config import load_config from services.shared.db import get_pg_pool +from services.shared.logging import setup_logging -logging.basicConfig(level=logging.INFO) logger = logging.getLogger("seed") # --- Seed Companies --- @@ -173,6 +173,7 @@ async def seed(pool: asyncpg.Pool) -> None: async def main() -> None: config = load_config() + setup_logging("seed", level=config.log_level, json_output=config.json_logs) pool = await get_pg_pool(config) try: await seed(pool) diff --git a/tests/replay_fixtures/README.md b/tests/replay_fixtures/README.md new file mode 100644 index 0000000..29cd90f --- /dev/null +++ b/tests/replay_fixtures/README.md @@ -0,0 +1,16 @@ +# Replay Dataset for Deterministic Extraction Testing + +Archived document fixtures used to verify that the extraction pipeline +produces consistent, schema-valid results across code changes. + +Each fixture is a JSON file containing: +- `document_id`: stable identifier for the fixture +- `document_type`: article, filing, transcript, or press_release +- `document_text`: normalized text as it would arrive from the parser +- `known_tickers`: ticker hints passed to the extraction prompt +- `expected_extraction`: the expected extraction result (schema-valid) +- `metadata`: fixture provenance info (created_at, description, schema_version) + +The replay runner (`tests/test_replay_extraction.py`) loads these fixtures, +validates the expected outputs against the current extraction schema, and +optionally runs them through a live Ollama instance for end-to-end checks. 
diff --git a/tests/replay_fixtures/aapl_earnings_article.json b/tests/replay_fixtures/aapl_earnings_article.json new file mode 100644 index 0000000..a20b663 --- /dev/null +++ b/tests/replay_fixtures/aapl_earnings_article.json @@ -0,0 +1,45 @@ +{ + "document_id": "replay-001-aapl-earnings", + "document_type": "article", + "document_text": "Apple Inc. reported fiscal Q1 2026 results that exceeded Wall Street expectations. Revenue came in at $124.3 billion, up 9% year-over-year, driven by strong iPhone 17 demand and a 22% surge in Services revenue to $26.1 billion. CEO Tim Cook highlighted record-setting performance in emerging markets, particularly India and Southeast Asia.\n\nEarnings per share of $2.41 beat the consensus estimate of $2.35. Gross margin expanded to 46.9%, up from 45.9% a year ago, reflecting favorable product mix and supply chain efficiencies.\n\nAnalysts at Morgan Stanley raised their price target to $245 from $230, citing the Services growth trajectory. However, Greater China revenue declined 4% amid increased competition from Huawei, which Cook acknowledged as a headwind.\n\nApple also announced a $110 billion share buyback program, the largest in corporate history, and raised its quarterly dividend by 5% to $0.26 per share.", + "known_tickers": ["AAPL"], + "expected_extraction": { + "summary": "Apple reported Q1 2026 results beating expectations with $124.3B revenue up 9% YoY, driven by iPhone 17 demand and 22% Services growth, though China revenue declined 4%.", + "companies": [ + { + "ticker": "AAPL", + "company_name": "Apple Inc.", + "relevance": 0.95, + "sentiment": "positive", + "impact_score": 0.8, + "impact_horizon": "1d_30d", + "catalyst_type": "earnings", + "key_facts": [ + "Revenue $124.3 billion, up 9% year-over-year", + "EPS $2.41 beat consensus of $2.35", + "Services revenue surged 22% to $26.1 billion", + "Greater China revenue declined 4%", + "$110 billion share buyback announced" + ], + "risks": [ + "Greater China revenue 
declined 4% amid Huawei competition" + ], + "evidence_spans": [ + "Revenue came in at $124.3 billion, up 9% year-over-year", + "Earnings per share of $2.41 beat the consensus estimate of $2.35", + "Greater China revenue declined 4% amid increased competition from Huawei" + ] + } + ], + "macro_themes": [], + "novelty_score": 0.5, + "confidence": 0.9, + "extraction_warnings": [] + }, + "metadata": { + "created_at": "2026-04-11", + "description": "Synthetic Apple earnings article for replay testing", + "schema_version": "2.0.0", + "category": "earnings_beat" + } +} diff --git a/tests/replay_fixtures/low_quality_garbled.json b/tests/replay_fixtures/low_quality_garbled.json new file mode 100644 index 0000000..8ca8b02 --- /dev/null +++ b/tests/replay_fixtures/low_quality_garbled.json @@ -0,0 +1,20 @@ +{ + "document_id": "replay-004-low-quality", + "document_type": "article", + "document_text": "Error 403 Forbidden. Access denied. Please subscribe to continue reading. Cookie preferences updated. Share on Twitter. Share on Facebook.", + "known_tickers": ["AAPL"], + "expected_extraction": { + "summary": "", + "companies": [], + "macro_themes": [], + "novelty_score": 0.1, + "confidence": 0.1, + "extraction_warnings": ["insufficient_content"] + }, + "metadata": { + "created_at": "2026-04-11", + "description": "Garbled/paywall document that should produce empty extraction with low confidence (Req 4.3, 5.4)", + "schema_version": "2.0.0", + "category": "low_quality" + } +} diff --git a/tests/replay_fixtures/msft_press_release.json b/tests/replay_fixtures/msft_press_release.json new file mode 100644 index 0000000..fc5c9db --- /dev/null +++ b/tests/replay_fixtures/msft_press_release.json @@ -0,0 +1,44 @@ +{ + "document_id": "replay-005-msft-press-release", + "document_type": "press_release", + "document_text": "REDMOND, Wash. — April 8, 2026 — Microsoft Corp. today announced it has entered into a definitive agreement to acquire Nuance Communications, Inc. 
subsidiary DataSphere AI for approximately $4.2 billion in an all-cash transaction. The acquisition is expected to close in Q3 2026, subject to regulatory approval.\n\nDataSphere AI specializes in healthcare-specific large language models and clinical decision support systems deployed across 1,200 hospitals in the United States. The acquisition will strengthen Microsoft's Azure Health Cloud platform and expand its presence in the $280 billion global healthcare IT market.\n\nSatya Nadella, Chairman and CEO of Microsoft, said: \"DataSphere AI's clinical language models are the most advanced in the industry. This acquisition accelerates our mission to empower every healthcare organization with AI that improves patient outcomes.\"\n\nThe transaction is expected to be accretive to Microsoft's earnings per share within 18 months of closing. Microsoft plans to integrate DataSphere's technology into Azure AI services and the Microsoft Cloud for Healthcare platform.", + "known_tickers": ["MSFT"], + "expected_extraction": { + "summary": "Microsoft announced a $4.2 billion all-cash acquisition of DataSphere AI, a healthcare LLM company deployed in 1,200 U.S. hospitals, to strengthen Azure Health Cloud.", + "companies": [ + { + "ticker": "MSFT", + "company_name": "Microsoft Corp.", + "relevance": 0.95, + "sentiment": "positive", + "impact_score": 0.7, + "impact_horizon": "1d_30d", + "catalyst_type": "m_and_a", + "key_facts": [ + "Acquiring DataSphere AI for $4.2 billion in all-cash transaction", + "Expected to close Q3 2026 subject to regulatory approval", + "DataSphere deployed across 1,200 hospitals in the United States", + "Expected to be accretive to EPS within 18 months" + ], + "risks": [ + "Subject to regulatory approval" + ], + "evidence_spans": [ + "entered into a definitive agreement to acquire Nuance Communications, Inc. 
subsidiary DataSphere AI for approximately $4.2 billion", + "deployed across 1,200 hospitals in the United States", + "expected to be accretive to Microsoft's earnings per share within 18 months of closing" + ] + } + ], + "macro_themes": ["ai_capex"], + "novelty_score": 0.75, + "confidence": 0.9, + "extraction_warnings": [] + }, + "metadata": { + "created_at": "2026-04-11", + "description": "Synthetic Microsoft M&A press release for replay testing", + "schema_version": "2.0.0", + "category": "press_release_m_and_a" + } +} diff --git a/tests/replay_fixtures/multi_company_article.json b/tests/replay_fixtures/multi_company_article.json new file mode 100644 index 0000000..0c0aa7f --- /dev/null +++ b/tests/replay_fixtures/multi_company_article.json @@ -0,0 +1,97 @@ +{ + "document_id": "replay-003-multi-company", + "document_type": "article", + "document_text": "The semiconductor sector faced a turbulent week as new U.S. export restrictions targeting advanced AI chips sent shockwaves through the industry. NVIDIA Corporation saw its shares drop 7% on Monday after the Commerce Department announced expanded controls on shipments of H200 and B100 GPUs to several Middle Eastern countries.\n\nAdvanced Micro Devices was also affected, declining 4.2%, though analysts noted AMD's exposure to the restricted markets is smaller than NVIDIA's. Bernstein analyst Stacy Rasgon estimated NVIDIA could lose $4-5 billion in annual revenue from the new restrictions, while AMD's impact would be closer to $800 million.\n\nMeanwhile, Taiwan Semiconductor Manufacturing Company reported that its advanced packaging capacity for AI chips remains fully booked through 2027, suggesting underlying demand remains robust despite the regulatory headwinds. 
TSMC shares rose 1.3% on the news.\n\nIntel Corporation, which has been positioning its Gaudi 3 accelerator as a domestic alternative, saw a modest 2.1% gain as investors speculated the restrictions could redirect demand toward U.S.-manufactured alternatives.", + "known_tickers": ["NVDA", "AMD", "TSM", "INTC"], + "expected_extraction": { + "summary": "New U.S. export restrictions on advanced AI chips hit NVIDIA (-7%) and AMD (-4.2%), while TSMC reported full AI packaging capacity through 2027 and Intel gained on domestic alternative positioning.", + "companies": [ + { + "ticker": "NVDA", + "company_name": "NVIDIA Corporation", + "relevance": 0.9, + "sentiment": "negative", + "impact_score": 0.8, + "impact_horizon": "1d_30d", + "catalyst_type": "macro", + "key_facts": [ + "Shares dropped 7% on expanded export controls", + "H200 and B100 GPUs targeted by new restrictions", + "Estimated $4-5 billion annual revenue loss from restrictions" + ], + "risks": [ + "Expanded U.S. export controls on AI chip shipments to Middle Eastern countries" + ], + "evidence_spans": [ + "NVIDIA Corporation saw its shares drop 7% on Monday after the Commerce Department announced expanded controls", + "NVIDIA could lose $4-5 billion in annual revenue from the new restrictions" + ] + }, + { + "ticker": "AMD", + "company_name": "Advanced Micro Devices", + "relevance": 0.7, + "sentiment": "negative", + "impact_score": 0.55, + "impact_horizon": "1d_30d", + "catalyst_type": "macro", + "key_facts": [ + "Shares declined 4.2%", + "Estimated $800 million annual revenue impact" + ], + "risks": [ + "Exposure to restricted export markets" + ], + "evidence_spans": [ + "Advanced Micro Devices was also affected, declining 4.2%", + "AMD's impact would be closer to $800 million" + ] + }, + { + "ticker": "TSM", + "company_name": "Taiwan Semiconductor Manufacturing Company", + "relevance": 0.65, + "sentiment": "positive", + "impact_score": 0.5, + "impact_horizon": "1d_7d", + "catalyst_type": "product", + 
"key_facts": [ + "Advanced packaging capacity for AI chips fully booked through 2027", + "Shares rose 1.3%" + ], + "risks": [], + "evidence_spans": [ + "advanced packaging capacity for AI chips remains fully booked through 2027", + "TSMC shares rose 1.3% on the news" + ] + }, + { + "ticker": "INTC", + "company_name": "Intel Corporation", + "relevance": 0.5, + "sentiment": "positive", + "impact_score": 0.35, + "impact_horizon": "1d_7d", + "catalyst_type": "macro", + "key_facts": [ + "Gaudi 3 accelerator positioned as domestic alternative", + "Shares gained 2.1%" + ], + "risks": [], + "evidence_spans": [ + "Intel Corporation, which has been positioning its Gaudi 3 accelerator as a domestic alternative, saw a modest 2.1% gain" + ] + } + ], + "macro_themes": ["ai_capex"], + "novelty_score": 0.7, + "confidence": 0.85, + "extraction_warnings": [] + }, + "metadata": { + "created_at": "2026-04-11", + "description": "Synthetic multi-company semiconductor article for replay testing (Req 5.5)", + "schema_version": "2.0.0", + "category": "multi_company" + } +} diff --git a/tests/replay_fixtures/tsla_sec_filing.json b/tests/replay_fixtures/tsla_sec_filing.json new file mode 100644 index 0000000..5fa12ac --- /dev/null +++ b/tests/replay_fixtures/tsla_sec_filing.json @@ -0,0 +1,45 @@ +{ + "document_id": "replay-002-tsla-filing", + "document_type": "filing", + "document_text": "UNITED STATES SECURITIES AND EXCHANGE COMMISSION\nWashington, D.C. 20549\nFORM 8-K\n\nCURRENT REPORT\nPursuant to Section 13 or 15(d) of the Securities Exchange Act of 1934\n\nDate of Report: March 28, 2026\n\nTESLA, INC.\n(Exact name of registrant as specified in its charter)\n\nItem 2.02 Results of Operations and Financial Condition.\n\nOn March 28, 2026, Tesla, Inc. issued a press release announcing its financial results for the fiscal quarter ended March 31, 2026. Total revenue was $25.8 billion, compared to $23.3 billion in the prior year quarter. Automotive revenue was $20.1 billion. 
Energy generation and storage revenue increased 67% to $3.2 billion.\n\nGAAP net income was $2.1 billion, or $0.61 per diluted share. Non-GAAP net income was $2.5 billion, or $0.73 per diluted share.\n\nThe Company disclosed that vehicle deliveries totaled 478,000 units, below the consensus estimate of 495,000 units. Management attributed the shortfall to production line retooling for the refreshed Model Y at the Fremont and Shanghai factories.\n\nRisk Factors: The Company noted ongoing regulatory uncertainty in the European Union regarding autonomous driving software certification, which could delay Full Self-Driving rollout in key markets. Additionally, lithium carbonate prices have increased 18% quarter-over-quarter, pressuring battery cell costs.", + "known_tickers": ["TSLA"], + "expected_extraction": { + "summary": "Tesla 8-K filing reports Q1 2026 results with $25.8B revenue, but vehicle deliveries of 478K missed consensus of 495K due to Model Y retooling. Energy segment grew 67%.", + "companies": [ + { + "ticker": "TSLA", + "company_name": "Tesla, Inc.", + "relevance": 0.95, + "sentiment": "mixed", + "impact_score": 0.75, + "impact_horizon": "1d_30d", + "catalyst_type": "earnings", + "key_facts": [ + "Total revenue $25.8 billion vs $23.3 billion prior year", + "Vehicle deliveries 478,000 units, below consensus of 495,000", + "Energy generation and storage revenue increased 67% to $3.2 billion", + "GAAP net income $2.1 billion or $0.61 per diluted share" + ], + "risks": [ + "EU regulatory uncertainty regarding autonomous driving software certification", + "Lithium carbonate prices increased 18% quarter-over-quarter" + ], + "evidence_spans": [ + "Total revenue was $25.8 billion, compared to $23.3 billion in the prior year quarter", + "vehicle deliveries totaled 478,000 units, below the consensus estimate of 495,000 units", + "lithium carbonate prices have increased 18% quarter-over-quarter, pressuring battery cell costs" + ] + } + ], + "macro_themes": [], + 
"novelty_score": 0.45, + "confidence": 0.88, + "extraction_warnings": [] + }, + "metadata": { + "created_at": "2026-04-11", + "description": "Synthetic Tesla 8-K filing for replay testing", + "schema_version": "2.0.0", + "category": "sec_filing" + } +} diff --git a/tests/test_adapters.py b/tests/test_adapters.py new file mode 100644 index 0000000..1e09255 --- /dev/null +++ b/tests/test_adapters.py @@ -0,0 +1,100 @@ +"""Tests for adapter base interface and result types.""" +from datetime import datetime + +from services.adapters.base import AdapterResult, BaseAdapter + + +class TestAdapterResult: + def test_ok_when_items_and_no_error(self): + r = AdapterResult( + source_type="market_api", + ticker="AAPL", + items=[{"price": 150}], + raw_payload=b'{"price":150}', + content_hash="abc123", + fetched_at=datetime(2026, 4, 11), + ) + assert r.ok is True + assert r.item_count == 1 + + def test_not_ok_when_error(self): + r = AdapterResult( + source_type="market_api", + ticker="AAPL", + items=[], + raw_payload=b"", + content_hash="", + fetched_at=datetime(2026, 4, 11), + error="timeout", + ) + assert r.ok is False + + def test_not_ok_when_empty_items(self): + r = AdapterResult( + source_type="news_api", + ticker="MSFT", + items=[], + raw_payload=b"{}", + content_hash="def456", + fetched_at=datetime(2026, 4, 11), + ) + assert r.ok is False + + def test_http_metadata_defaults(self): + r = AdapterResult( + source_type="market_api", + ticker="AAPL", + items=[{"x": 1}], + raw_payload=b"x", + content_hash="h", + fetched_at=datetime(2026, 4, 11), + ) + assert r.http_status is None + assert r.response_time_ms is None + assert r.metadata == {} + + +class _StubAdapter(BaseAdapter): + async def fetch(self, ticker, config): + return AdapterResult( + source_type="market_api", + ticker=ticker, + items=[], + raw_payload=b"", + content_hash="", + fetched_at=datetime(2026, 4, 11), + ) + + def source_type(self): + return "market_api" + + +class _FilingsStub(BaseAdapter): + async def 
fetch(self, ticker, config): + return AdapterResult( + source_type="filings_api", + ticker=ticker, + items=[], + raw_payload=b"", + content_hash="", + fetched_at=datetime(2026, 4, 11), + ) + + def source_type(self): + return "filings_api" + + +class TestBaseAdapterHelpers: + def test_bucket_name_market(self): + adapter = _StubAdapter() + assert adapter.bucket_name() == "stonks-raw-market" + + def test_bucket_name_filings(self): + adapter = _FilingsStub() + assert adapter.bucket_name() == "stonks-raw-filings" + + def test_artifact_path_format(self): + adapter = _StubAdapter() + now = datetime(2026, 4, 11, 14, 30) + path = adapter.artifact_path("AAPL", "doc-123", now) + assert path == "market_api/AAPL/2026/04/11/doc-123/raw.json" diff --git a/tests/test_aggregation_scoring.py b/tests/test_aggregation_scoring.py new file mode 100644 index 0000000..7681205 --- /dev/null +++ b/tests/test_aggregation_scoring.py @@ -0,0 +1,248 @@ +"""Tests for aggregation scoring — recency decay, source credibility weighting, +and market context integration.""" +from datetime import datetime, timedelta, timezone + +from services.aggregation.scoring import ( + DEFAULT_CONFIG, + ScoringConfig, + WeightedSignal, + compute_signal_weight, + credibility_weight, + market_context_multiplier, + recency_weight, + sentiment_to_numeric, + weighted_sentiment_average, +) +from services.shared.schemas import MarketContext + + +# --------------------------------------------------------------------------- +# recency_weight +# --------------------------------------------------------------------------- + + +def test_recency_weight_at_zero_age(): + """A document published exactly at reference time gets weight 1.0.""" + now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + assert recency_weight(now, now, "7d") == 1.0 + + +def test_recency_weight_future_document(): + """A document published after reference time is clamped to 1.0.""" + now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + future = 
now + timedelta(hours=1) + assert recency_weight(future, now, "7d") == 1.0 + + +def test_recency_weight_at_one_half_life(): + """After exactly one half-life the weight should be ~0.5.""" + now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + half_life_7d = DEFAULT_CONFIG.half_life_hours["7d"] # 72 hours + published = now - timedelta(hours=half_life_7d) + w = recency_weight(published, now, "7d") + assert abs(w - 0.5) < 1e-9 + + +def test_recency_weight_very_old_clamps_to_min(): + """A very old document should not go below min_recency_weight.""" + now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + ancient = now - timedelta(days=365) + w = recency_weight(ancient, now, "7d") + assert w == DEFAULT_CONFIG.min_recency_weight + + +def test_recency_weight_different_windows(): + """Shorter windows decay faster than longer ones.""" + now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + published = now - timedelta(hours=24) + w_intraday = recency_weight(published, now, "intraday") + w_90d = recency_weight(published, now, "90d") + assert w_intraday < w_90d + + +def test_recency_weight_naive_datetimes(): + """Naive datetimes are treated as UTC.""" + now = datetime(2026, 4, 11, 12, 0, 0) + published = now - timedelta(hours=72) + w = recency_weight(published, now, "7d") + assert abs(w - 0.5) < 1e-9 + + +# --------------------------------------------------------------------------- +# credibility_weight +# --------------------------------------------------------------------------- + + +def test_credibility_weight_high(): + """High credibility source gets weight close to 1.0.""" + assert abs(credibility_weight(1.0) - 1.0) < 1e-9 + + +def test_credibility_weight_low_clamped(): + """Credibility below floor is clamped to floor.""" + w = credibility_weight(0.0) + assert abs(w - DEFAULT_CONFIG.credibility_floor) < 1e-9 + + +def test_credibility_weight_mid(): + """Mid-range credibility passes through with exponent=1.""" + assert abs(credibility_weight(0.5) - 0.5) 
< 1e-9 + + +def test_credibility_weight_custom_exponent(): + """Custom exponent penalises low credibility more.""" + cfg = ScoringConfig(credibility_exponent=2.0) + w = credibility_weight(0.5, config=cfg) + assert abs(w - 0.25) < 1e-9 + + +# --------------------------------------------------------------------------- +# compute_signal_weight +# --------------------------------------------------------------------------- + + +def test_signal_weight_gates_low_confidence(): + """Documents below confidence floor get zero combined weight.""" + now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + sw = compute_signal_weight( + published_at=now, + reference_time=now, + window="7d", + source_credibility=0.8, + extraction_confidence=0.1, # below default 0.2 floor + ) + assert sw.combined == 0.0 + assert sw.confidence_gate == 0.0 + + +def test_signal_weight_fresh_high_credibility(): + """Fresh doc with high credibility and default novelty gets a strong weight.""" + now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + sw = compute_signal_weight( + published_at=now, + reference_time=now, + window="7d", + source_credibility=1.0, + novelty_score=0.5, + extraction_confidence=0.8, + ) + # recency=1.0, credibility=1.0, bonus=0.125, gate=1.0 + expected = 1.0 * 1.0 * (1.0 + 0.125) + assert abs(sw.combined - expected) < 1e-9 + + +def test_signal_weight_novelty_bonus(): + """Higher novelty gives a proportionally higher combined weight.""" + now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + sw_low = compute_signal_weight(now, now, "7d", 0.8, novelty_score=0.0, extraction_confidence=0.8) + sw_high = compute_signal_weight(now, now, "7d", 0.8, novelty_score=1.0, extraction_confidence=0.8) + assert sw_high.combined > sw_low.combined + + +# --------------------------------------------------------------------------- +# sentiment helpers +# --------------------------------------------------------------------------- + + +def test_sentiment_to_numeric(): + assert 
sentiment_to_numeric("positive") == 1.0 + assert sentiment_to_numeric("negative") == -1.0 + assert sentiment_to_numeric("neutral") == 0.0 + assert sentiment_to_numeric("mixed") == 0.0 + assert sentiment_to_numeric("unknown") == 0.0 + + +def test_weighted_sentiment_average_empty(): + assert weighted_sentiment_average([]) == 0.0 + + +def test_weighted_sentiment_average_single(): + now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + sw = compute_signal_weight(now, now, "7d", 0.8, extraction_confidence=0.8) + signals = [WeightedSignal("doc1", sw, sentiment_value=1.0, impact_score=0.7)] + avg = weighted_sentiment_average(signals) + assert abs(avg - 1.0) < 1e-9 # single positive signal → 1.0 + + +def test_weighted_sentiment_average_opposing(): + """Equal-weight opposing signals should cancel to ~0.""" + now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + sw = compute_signal_weight(now, now, "7d", 0.8, extraction_confidence=0.8) + signals = [ + WeightedSignal("doc1", sw, sentiment_value=1.0, impact_score=0.5), + WeightedSignal("doc2", sw, sentiment_value=-1.0, impact_score=0.5), + ] + avg = weighted_sentiment_average(signals) + assert abs(avg) < 1e-9 + + +# --------------------------------------------------------------------------- +# market_context_multiplier +# --------------------------------------------------------------------------- + + +def test_market_context_multiplier_none(): + """No market context returns 1.0 (no adjustment).""" + assert market_context_multiplier(None) == 1.0 + + +def test_market_context_multiplier_no_data(): + """MarketContext with no bars returns 1.0.""" + ctx = MarketContext(ticker="AAPL", bars_available=0) + assert market_context_multiplier(ctx) == 1.0 + + +def test_market_context_multiplier_low_volatility(): + """Below-threshold volatility produces no boost.""" + ctx = MarketContext(ticker="AAPL", volatility=0.5, volume_change_pct=10.0, bars_available=5) + assert market_context_multiplier(ctx) == 1.0 + + +def 
test_market_context_multiplier_high_volatility(): + """Above-threshold volatility produces a boost > 1.0.""" + ctx = MarketContext(ticker="AAPL", volatility=3.0, volume_change_pct=10.0, bars_available=5) + m = market_context_multiplier(ctx) + assert m > 1.0 + assert m <= 1.0 + DEFAULT_CONFIG.volatility_recency_boost_max + DEFAULT_CONFIG.volume_surge_boost + + +def test_market_context_multiplier_volume_surge(): + """Volume surge above threshold adds a boost.""" + ctx = MarketContext(ticker="AAPL", volatility=0.5, volume_change_pct=80.0, bars_available=5) + m = market_context_multiplier(ctx) + assert abs(m - (1.0 + DEFAULT_CONFIG.volume_surge_boost)) < 1e-9 + + +def test_market_context_multiplier_both_triggers(): + """Both volatility and volume surge stack.""" + ctx = MarketContext(ticker="AAPL", volatility=3.0, volume_change_pct=80.0, bars_available=5) + m = market_context_multiplier(ctx) + # Should be > 1.0 + volume_surge_boost alone + assert m > 1.0 + DEFAULT_CONFIG.volume_surge_boost + + +# --------------------------------------------------------------------------- +# compute_signal_weight with market context +# --------------------------------------------------------------------------- + + +def test_signal_weight_with_market_context_boost(): + """Market context with high volatility should increase combined weight.""" + now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + ctx = MarketContext(ticker="AAPL", volatility=3.0, volume_change_pct=80.0, bars_available=10) + + sw_no_ctx = compute_signal_weight(now, now, "7d", 0.8, extraction_confidence=0.8) + sw_with_ctx = compute_signal_weight(now, now, "7d", 0.8, extraction_confidence=0.8, market_ctx=ctx) + + assert sw_with_ctx.combined > sw_no_ctx.combined + assert sw_with_ctx.market_ctx_multiplier > 1.0 + assert sw_no_ctx.market_ctx_multiplier == 1.0 + + +def test_signal_weight_market_context_gated_still_zero(): + """Low confidence docs stay at zero even with market context boost.""" + now = datetime(2026, 4, 
11, 12, 0, 0, tzinfo=timezone.utc) + ctx = MarketContext(ticker="AAPL", volatility=5.0, volume_change_pct=100.0, bars_available=10) + + sw = compute_signal_weight(now, now, "7d", 0.8, extraction_confidence=0.1, market_ctx=ctx) + assert sw.combined == 0.0 diff --git a/tests/test_aggregation_worker.py b/tests/test_aggregation_worker.py new file mode 100644 index 0000000..6112757 --- /dev/null +++ b/tests/test_aggregation_worker.py @@ -0,0 +1,318 @@ +"""Tests for aggregation worker — rolling window trend summary computation. + +Tests the pure logic functions (no DB required). The async DB functions +are covered by integration tests. +""" +from datetime import datetime, timedelta, timezone + +from services.aggregation.scoring import ( + ScoringConfig, + WeightedSignal, + compute_signal_weight, +) +from services.aggregation.worker import ( + AggregationConfig, + AssembledTrend, + ImpactRow, + assemble_trend_summary, + assemble_trend_with_evidence, + build_weighted_signals, + compute_contradiction_score, + compute_trend_confidence, + derive_trend_direction, + extract_catalysts_and_risks, + rank_evidence, +) +from services.shared.schemas import MarketContext, TrendDirection, TrendWindow + +NOW = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + + +def _make_impact( + doc_id: str = "doc-1", + sentiment: str = "positive", + impact_score: float = 0.7, + catalyst_type: str = "earnings", + confidence: float = 0.8, + source_credibility: float = 0.8, + novelty_score: float = 0.5, + published_at: datetime | None = None, + risks: list[str] | None = None, +) -> ImpactRow: + return ImpactRow( + document_id=doc_id, + confidence=confidence, + novelty_score=novelty_score, + source_credibility=source_credibility, + sentiment=sentiment, + impact_score=impact_score, + catalyst_type=catalyst_type, + key_facts=["some fact"], + risks=risks or [], + published_at=published_at or NOW - timedelta(hours=1), + ) + + +# 
--------------------------------------------------------------------------- +# build_weighted_signals +# --------------------------------------------------------------------------- + + +def test_build_weighted_signals_basic(): + impacts = [_make_impact("d1"), _make_impact("d2", sentiment="negative")] + signals = build_weighted_signals(impacts, NOW, "7d") + assert len(signals) == 2 + assert signals[0].document_id == "d1" + assert signals[0].sentiment_value == 1.0 + assert signals[1].sentiment_value == -1.0 + assert all(s.weight.combined > 0 for s in signals) + + +def test_build_weighted_signals_low_confidence_gated(): + impacts = [_make_impact("d1", confidence=0.1)] + signals = build_weighted_signals(impacts, NOW, "7d") + assert signals[0].weight.combined == 0.0 + + +# --------------------------------------------------------------------------- +# derive_trend_direction +# --------------------------------------------------------------------------- + + +def test_direction_bullish(): + assert derive_trend_direction(0.5) == TrendDirection.BULLISH + + +def test_direction_bearish(): + assert derive_trend_direction(-0.5) == TrendDirection.BEARISH + + +def test_direction_neutral(): + assert derive_trend_direction(0.05) == TrendDirection.NEUTRAL + + +def test_direction_mixed_high_contradiction(): + assert derive_trend_direction(0.1, contradiction_score=0.2) == TrendDirection.MIXED + + +def test_direction_bullish_despite_contradiction(): + """Strong sentiment overrides contradiction.""" + assert derive_trend_direction(0.5, contradiction_score=0.3) == TrendDirection.BULLISH + + +# --------------------------------------------------------------------------- +# compute_contradiction_score +# --------------------------------------------------------------------------- + + +def test_contradiction_no_signals(): + assert compute_contradiction_score([]) == 0.0 + + +def test_contradiction_all_positive(): + sw = compute_signal_weight(NOW, NOW, "7d", 0.8, extraction_confidence=0.8) + 
signals = [ + WeightedSignal("d1", sw, sentiment_value=1.0, impact_score=0.5), + WeightedSignal("d2", sw, sentiment_value=1.0, impact_score=0.5), + ] + assert compute_contradiction_score(signals) == 0.0 + + +def test_contradiction_equal_opposing(): + sw = compute_signal_weight(NOW, NOW, "7d", 0.8, extraction_confidence=0.8) + signals = [ + WeightedSignal("d1", sw, sentiment_value=1.0, impact_score=0.5), + WeightedSignal("d2", sw, sentiment_value=-1.0, impact_score=0.5), + ] + score = compute_contradiction_score(signals) + assert abs(score - 0.5) < 1e-4 + + +def test_contradiction_mostly_positive(): + sw = compute_signal_weight(NOW, NOW, "7d", 0.8, extraction_confidence=0.8) + signals = [ + WeightedSignal("d1", sw, sentiment_value=1.0, impact_score=0.8), + WeightedSignal("d2", sw, sentiment_value=1.0, impact_score=0.8), + WeightedSignal("d3", sw, sentiment_value=-1.0, impact_score=0.3), + ] + score = compute_contradiction_score(signals) + assert 0.0 < score < 0.5 + + +# --------------------------------------------------------------------------- +# rank_evidence +# --------------------------------------------------------------------------- + + +def test_rank_evidence_separates_supporting_opposing(): + sw = compute_signal_weight(NOW, NOW, "7d", 0.8, extraction_confidence=0.8) + signals = [ + WeightedSignal("pos1", sw, sentiment_value=1.0, impact_score=0.9), + WeightedSignal("pos2", sw, sentiment_value=1.0, impact_score=0.3), + WeightedSignal("neg1", sw, sentiment_value=-1.0, impact_score=0.7), + WeightedSignal("neutral1", sw, sentiment_value=0.0, impact_score=0.5), + ] + supporting, opposing = rank_evidence(signals) + assert supporting == ["pos1", "pos2"] + assert opposing == ["neg1"] + + +def test_rank_evidence_respects_max(): + sw = compute_signal_weight(NOW, NOW, "7d", 0.8, extraction_confidence=0.8) + signals = [ + WeightedSignal(f"d{i}", sw, sentiment_value=1.0, impact_score=0.5) + for i in range(20) + ] + supporting, opposing = rank_evidence(signals, max_refs=3) 
+ assert len(supporting) == 3 + assert len(opposing) == 0 + + +# --------------------------------------------------------------------------- +# extract_catalysts_and_risks +# --------------------------------------------------------------------------- + + +def test_extract_catalysts_and_risks(): + impacts = [ + _make_impact("d1", catalyst_type="earnings", risks=["regulatory risk"]), + _make_impact("d2", catalyst_type="earnings", risks=["supply chain"]), + _make_impact("d3", catalyst_type="product", risks=["regulatory risk"]), + ] + signals = build_weighted_signals(impacts, NOW, "7d") + catalysts, risks = extract_catalysts_and_risks(impacts, signals) + assert catalysts[0] == "earnings" # highest cumulative weight + assert "product" in catalysts + # Risks should be deduplicated + risk_lower = [r.lower() for r in risks] + assert risk_lower.count("regulatory risk") == 1 + + +# --------------------------------------------------------------------------- +# compute_trend_confidence +# --------------------------------------------------------------------------- + + +def test_confidence_no_signals(): + assert compute_trend_confidence([], 0.0) == 0.0 + + +def test_confidence_increases_with_more_signals(): + sw = compute_signal_weight(NOW, NOW, "7d", 0.8, extraction_confidence=0.8) + few = [WeightedSignal(f"d{i}", sw, 1.0, 0.5) for i in range(2)] + many = [WeightedSignal(f"d{i}", sw, 1.0, 0.5) for i in range(15)] + c_few = compute_trend_confidence(few, 0.0) + c_many = compute_trend_confidence(many, 0.0) + assert c_many > c_few + + +def test_confidence_penalized_by_contradiction(): + sw = compute_signal_weight(NOW, NOW, "7d", 0.8, extraction_confidence=0.8) + signals = [WeightedSignal(f"d{i}", sw, 1.0, 0.5) for i in range(5)] + c_low = compute_trend_confidence(signals, 0.0) + c_high = compute_trend_confidence(signals, 0.5) + assert c_high < c_low + + +# --------------------------------------------------------------------------- +# assemble_trend_summary +# 
--------------------------------------------------------------------------- + + +def test_assemble_trend_summary_bullish(): + impacts = [ + _make_impact("d1", sentiment="positive", impact_score=0.8), + _make_impact("d2", sentiment="positive", impact_score=0.6), + ] + signals = build_weighted_signals(impacts, NOW, "7d") + summary = assemble_trend_summary("AAPL", "7d", signals, impacts, reference_time=NOW) + + assert summary.entity_id == "AAPL" + assert summary.window == TrendWindow.SEVEN_DAY + assert summary.trend_direction == TrendDirection.BULLISH + assert summary.trend_strength > 0 + assert summary.confidence > 0 + assert len(summary.top_supporting_evidence) > 0 + assert summary.generated_at == NOW + + +def test_assemble_trend_summary_mixed(): + impacts = [ + _make_impact("d1", sentiment="positive", impact_score=0.5), + _make_impact("d2", sentiment="negative", impact_score=0.5), + ] + signals = build_weighted_signals(impacts, NOW, "7d") + summary = assemble_trend_summary("TSLA", "7d", signals, impacts, reference_time=NOW) + + # Equal opposing signals → neutral or mixed + assert summary.trend_direction in (TrendDirection.NEUTRAL, TrendDirection.MIXED) + assert summary.contradiction_score > 0 + + +def test_assemble_trend_summary_empty(): + summary = assemble_trend_summary("AAPL", "7d", [], [], reference_time=NOW) + assert summary.trend_direction == TrendDirection.NEUTRAL + assert summary.trend_strength == 0.0 + assert summary.confidence == 0.0 + + +def test_assemble_trend_summary_with_market_context(): + impacts = [_make_impact("d1")] + ctx = MarketContext(ticker="AAPL", volatility=3.0, bars_available=5) + signals = build_weighted_signals(impacts, NOW, "7d", market_ctx=ctx) + summary = assemble_trend_summary("AAPL", "7d", signals, impacts, market_ctx=ctx, reference_time=NOW) + assert summary.market_context is not None + assert summary.market_context.ticker == "AAPL" + + +# --------------------------------------------------------------------------- +# 
AggregationConfig +# --------------------------------------------------------------------------- + + +def test_aggregation_config_defaults(): + cfg = AggregationConfig() + assert len(cfg.effective_windows()) == len(TrendWindow) + assert isinstance(cfg.effective_scoring(), ScoringConfig) + + +def test_aggregation_config_custom_windows(): + cfg = AggregationConfig(windows=["7d", "30d"]) + assert cfg.effective_windows() == ["7d", "30d"] + + +# --------------------------------------------------------------------------- +# assemble_trend_with_evidence +# --------------------------------------------------------------------------- + + +def test_assemble_trend_with_evidence_returns_ranked_details(): + impacts = [ + _make_impact("d1", sentiment="positive", impact_score=0.8), + _make_impact("d2", sentiment="negative", impact_score=0.6), + _make_impact("d3", sentiment="positive", impact_score=0.5), + ] + signals = build_weighted_signals(impacts, NOW, "7d") + result = assemble_trend_with_evidence("AAPL", "7d", signals, impacts, reference_time=NOW) + + assert isinstance(result, AssembledTrend) + assert result.summary.entity_id == "AAPL" + # Supporting evidence should contain the positive docs + assert len(result.supporting_evidence) == 2 + assert all(e.sentiment_value > 0 for e in result.supporting_evidence) + # Opposing evidence should contain the negative doc + assert len(result.opposing_evidence) == 1 + assert result.opposing_evidence[0].document_id == "d2" + # Rank scores should be positive + assert all(e.rank_score > 0 for e in result.supporting_evidence) + assert all(e.rank_score > 0 for e in result.opposing_evidence) + # Summary evidence IDs should match + assert result.summary.top_supporting_evidence == [e.document_id for e in result.supporting_evidence] + assert result.summary.top_opposing_evidence == [e.document_id for e in result.opposing_evidence] + + +def test_assemble_trend_with_evidence_empty_signals(): + result = assemble_trend_with_evidence("AAPL", "7d", [], 
[], reference_time=NOW) + assert result.supporting_evidence == [] + assert result.opposing_evidence == [] + assert result.summary.trend_direction == TrendDirection.NEUTRAL diff --git a/tests/test_alerting.py b/tests/test_alerting.py new file mode 100644 index 0000000..81a0b96 --- /dev/null +++ b/tests/test_alerting.py @@ -0,0 +1,306 @@ +"""Tests for operational alerting rules. + +Requirements: 12.3 +""" +from __future__ import annotations + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, patch + +import pytest + +from services.shared.alerting import ( + Alert, + AlertState, + check_analytical_lag, + check_broker_issues, + check_schema_failure_spike, + check_source_failures, + evaluate_alerts, +) +from services.shared.config import AlertingConfig + + +@pytest.fixture +def config(): + return AlertingConfig( + source_failure_threshold=3, + source_failure_window_hours=6, + schema_failure_rate_threshold=0.3, + schema_failure_window_hours=1, + lake_lag_threshold_minutes=60, + broker_error_threshold=3, + broker_error_window_hours=1, + check_interval_seconds=120, + ) + + +@pytest.fixture +def state(): + return AlertState() + + +# --------------------------------------------------------------------------- +# AlertState unit tests +# --------------------------------------------------------------------------- + + +class TestAlertState: + def test_fire_new_alert_returns_true(self, state): + alert = Alert(rule="source_failures", severity="warning", summary="test", + details={"key": "src1"}) + assert state.fire(alert) is True + + def test_fire_existing_alert_returns_false(self, state): + alert = Alert(rule="source_failures", severity="warning", summary="test", + details={"key": "src1"}) + state.fire(alert) + assert state.fire(alert) is False + + def test_resolve_active_returns_true(self, state): + alert = Alert(rule="source_failures", severity="warning", summary="test", + details={"key": "src1"}) + state.fire(alert) + assert 
state.resolve("source_failures", "src1") is True + + def test_resolve_inactive_returns_false(self, state): + assert state.resolve("source_failures", "src1") is False + + def test_is_firing(self, state): + alert = Alert(rule="broker_issues", severity="critical", summary="test", + details={"key": "global"}) + assert state.is_firing("broker_issues", "global") is False + state.fire(alert) + assert state.is_firing("broker_issues", "global") is True + + def test_multiple_alerts_same_rule_different_keys(self, state): + a1 = Alert(rule="source_failures", severity="warning", summary="s1", + details={"key": "src1"}) + a2 = Alert(rule="source_failures", severity="warning", summary="s2", + details={"key": "src2"}) + assert state.fire(a1) is True + assert state.fire(a2) is True + assert state.is_firing("source_failures", "src1") is True + assert state.is_firing("source_failures", "src2") is True + state.resolve("source_failures", "src1") + assert state.is_firing("source_failures", "src1") is False + assert state.is_firing("source_failures", "src2") is True + + +# --------------------------------------------------------------------------- +# check_source_failures +# --------------------------------------------------------------------------- + + +class TestCheckSourceFailures: + @pytest.mark.asyncio + async def test_returns_alerts_for_failing_sources(self, config): + mock_pool = AsyncMock() + mock_pool.fetch.return_value = [ + { + "source_id": "uuid-1", + "consecutive_failures": 3, + "source_type": "news_api", + "source_name": "reuters", + "ticker": "AAPL", + }, + ] + + alerts = await check_source_failures(mock_pool, config) + assert len(alerts) == 1 + assert alerts[0].rule == "source_failures" + assert alerts[0].severity == "warning" + assert "AAPL" in alerts[0].summary + assert alerts[0].details["source_id"] == "uuid-1" + + @pytest.mark.asyncio + async def test_returns_empty_when_no_failures(self, config): + mock_pool = AsyncMock() + mock_pool.fetch.return_value = [] + + alerts 
= await check_source_failures(mock_pool, config) + assert alerts == [] + + +# --------------------------------------------------------------------------- +# check_schema_failure_spike +# --------------------------------------------------------------------------- + + +class TestCheckSchemaFailureSpike: + @pytest.mark.asyncio + async def test_fires_when_rate_exceeds_threshold(self, config): + mock_pool = AsyncMock() + mock_pool.fetchrow.return_value = {"total": 100, "failed": 40} + + alerts = await check_schema_failure_spike(mock_pool, config) + assert len(alerts) == 1 + assert alerts[0].rule == "schema_failure_spike" + assert alerts[0].details["failure_rate"] == 0.4 + + @pytest.mark.asyncio + async def test_critical_severity_above_50_percent(self, config): + mock_pool = AsyncMock() + mock_pool.fetchrow.return_value = {"total": 100, "failed": 60} + + alerts = await check_schema_failure_spike(mock_pool, config) + assert len(alerts) == 1 + assert alerts[0].severity == "critical" + + @pytest.mark.asyncio + async def test_no_alert_below_threshold(self, config): + mock_pool = AsyncMock() + mock_pool.fetchrow.return_value = {"total": 100, "failed": 10} + + alerts = await check_schema_failure_spike(mock_pool, config) + assert alerts == [] + + @pytest.mark.asyncio + async def test_no_alert_when_no_extractions(self, config): + mock_pool = AsyncMock() + mock_pool.fetchrow.return_value = {"total": 0, "failed": 0} + + alerts = await check_schema_failure_spike(mock_pool, config) + assert alerts == [] + + +# --------------------------------------------------------------------------- +# check_analytical_lag +# --------------------------------------------------------------------------- + + +class TestCheckAnalyticalLag: + @pytest.mark.asyncio + async def test_fires_for_stale_tables(self, config): + mock_pool = AsyncMock() + stale_time = datetime(2026, 4, 10, 10, 0, 0, tzinfo=timezone.utc) + mock_pool.fetch.return_value = [ + {"table_name": "market_bars", "last_publish": stale_time}, 
+ ] + + alerts = await check_analytical_lag(mock_pool, config) + assert len(alerts) == 1 + assert alerts[0].rule == "analytical_lag" + assert "market_bars" in alerts[0].summary + + @pytest.mark.asyncio + async def test_no_alert_when_recent(self, config): + mock_pool = AsyncMock() + mock_pool.fetch.return_value = [] + + alerts = await check_analytical_lag(mock_pool, config) + assert alerts == [] + + +# --------------------------------------------------------------------------- +# check_broker_issues +# --------------------------------------------------------------------------- + + +class TestCheckBrokerIssues: + @pytest.mark.asyncio + async def test_fires_on_consecutive_errors(self, config): + mock_pool = AsyncMock() + mock_pool.fetch.return_value = [{"error_count": 5}] + + alerts = await check_broker_issues(mock_pool, config) + assert len(alerts) == 1 + assert alerts[0].rule == "broker_issues" + assert alerts[0].severity == "critical" + + @pytest.mark.asyncio + async def test_no_alert_below_threshold(self, config): + mock_pool = AsyncMock() + mock_pool.fetch.return_value = [{"error_count": 1}] + + alerts = await check_broker_issues(mock_pool, config) + assert alerts == [] + + @pytest.mark.asyncio + async def test_no_alert_when_no_events(self, config): + mock_pool = AsyncMock() + mock_pool.fetch.return_value = [] + + alerts = await check_broker_issues(mock_pool, config) + assert alerts == [] + + +# --------------------------------------------------------------------------- +# evaluate_alerts integration +# --------------------------------------------------------------------------- + + +class TestEvaluateAlerts: + @pytest.mark.asyncio + async def test_newly_fired_alerts_returned(self, config, state): + mock_pool = AsyncMock() + + with patch("services.shared.alerting.check_source_failures") as mock_src, \ + patch("services.shared.alerting.check_schema_failure_spike") as mock_schema, \ + patch("services.shared.alerting.check_analytical_lag") as mock_lag, \ + 
patch("services.shared.alerting.check_broker_issues") as mock_broker: + + mock_src.return_value = [ + Alert(rule="source_failures", severity="warning", + summary="src fail", details={"key": "s1"}), + ] + mock_schema.return_value = [] + mock_lag.return_value = [] + mock_broker.return_value = [] + + fired = await evaluate_alerts(mock_pool, config, state) + assert len(fired) == 1 + assert fired[0].rule == "source_failures" + assert state.is_firing("source_failures", "s1") + + @pytest.mark.asyncio + async def test_repeated_alert_not_returned_again(self, config, state): + mock_pool = AsyncMock() + alert = Alert(rule="broker_issues", severity="critical", + summary="broker down", details={"key": "global"}) + + with patch("services.shared.alerting.check_source_failures", return_value=[]), \ + patch("services.shared.alerting.check_schema_failure_spike", return_value=[]), \ + patch("services.shared.alerting.check_analytical_lag", return_value=[]), \ + patch("services.shared.alerting.check_broker_issues", return_value=[alert]): + + fired1 = await evaluate_alerts(mock_pool, config, state) + assert len(fired1) == 1 + + fired2 = await evaluate_alerts(mock_pool, config, state) + assert len(fired2) == 0 + + @pytest.mark.asyncio + async def test_resolved_alert_clears_state(self, config, state): + mock_pool = AsyncMock() + alert = Alert(rule="broker_issues", severity="critical", + summary="broker down", details={"key": "global"}) + + with patch("services.shared.alerting.check_source_failures", return_value=[]), \ + patch("services.shared.alerting.check_schema_failure_spike", return_value=[]), \ + patch("services.shared.alerting.check_analytical_lag", return_value=[]), \ + patch("services.shared.alerting.check_broker_issues") as mock_broker: + + # Fire + mock_broker.return_value = [alert] + await evaluate_alerts(mock_pool, config, state) + assert state.is_firing("broker_issues", "global") + + # Resolve + mock_broker.return_value = [] + await evaluate_alerts(mock_pool, config, state) 
+ assert not state.is_firing("broker_issues", "global") + + @pytest.mark.asyncio + async def test_rule_exception_does_not_crash(self, config, state): + mock_pool = AsyncMock() + + with patch("services.shared.alerting.check_source_failures", + side_effect=Exception("db down")), \ + patch("services.shared.alerting.check_schema_failure_spike", return_value=[]), \ + patch("services.shared.alerting.check_analytical_lag", return_value=[]), \ + patch("services.shared.alerting.check_broker_issues", return_value=[]): + + # Should not raise + fired = await evaluate_alerts(mock_pool, config, state) + assert fired == [] diff --git a/tests/test_audit.py b/tests/test_audit.py new file mode 100644 index 0000000..6aab028 --- /dev/null +++ b/tests/test_audit.py @@ -0,0 +1,160 @@ +"""Tests for the execution audit trail module. + +Validates audit event construction, event type constants, and the +convenience helpers that record each stage of the execution pipeline. +""" +from services.shared.audit import ( + AUDIT_ORDER_CANCELLED, + AUDIT_ORDER_DUPLICATE, + AUDIT_ORDER_FILLED, + AUDIT_ORDER_REJECTED, + AUDIT_ORDER_SUBMITTED, + AUDIT_POSITION_CLOSED, + AUDIT_POSITION_OPENED, + AUDIT_POSITION_UPDATED, + AUDIT_RECOMMENDATION_GENERATED, + AUDIT_RECOMMENDATION_SUPPRESSED, + AUDIT_RISK_EVALUATED, + AUDIT_RISK_REJECTED, + AUDIT_TRADING_MODE_CHANGED, +) + + +# --------------------------------------------------------------------------- +# Event type constants +# --------------------------------------------------------------------------- + + +class TestAuditEventTypes: + """Verify event type constants are well-formed and distinct.""" + + def test_recommendation_events(self): + assert AUDIT_RECOMMENDATION_GENERATED == "recommendation.generated" + assert AUDIT_RECOMMENDATION_SUPPRESSED == "recommendation.suppressed" + + def test_risk_events(self): + assert AUDIT_RISK_EVALUATED == "risk.evaluated" + assert AUDIT_RISK_REJECTED == "risk.rejected" + + def test_order_events(self): + assert 
AUDIT_ORDER_SUBMITTED == "order.submitted" + assert AUDIT_ORDER_FILLED == "order.filled" + assert AUDIT_ORDER_REJECTED == "order.rejected" + assert AUDIT_ORDER_CANCELLED == "order.cancelled" + assert AUDIT_ORDER_DUPLICATE == "order.duplicate_prevented" + + def test_position_events(self): + assert AUDIT_POSITION_OPENED == "position.opened" + assert AUDIT_POSITION_CLOSED == "position.closed" + assert AUDIT_POSITION_UPDATED == "position.updated" + + def test_trading_mode_event(self): + assert AUDIT_TRADING_MODE_CHANGED == "trading.mode_changed" + + def test_all_event_types_unique(self): + all_types = [ + AUDIT_RECOMMENDATION_GENERATED, + AUDIT_RECOMMENDATION_SUPPRESSED, + AUDIT_RISK_EVALUATED, + AUDIT_RISK_REJECTED, + AUDIT_ORDER_SUBMITTED, + AUDIT_ORDER_FILLED, + AUDIT_ORDER_REJECTED, + AUDIT_ORDER_CANCELLED, + AUDIT_ORDER_DUPLICATE, + AUDIT_POSITION_OPENED, + AUDIT_POSITION_CLOSED, + AUDIT_POSITION_UPDATED, + AUDIT_TRADING_MODE_CHANGED, + ] + assert len(all_types) == len(set(all_types)) + + def test_event_types_follow_dot_notation(self): + """All event types should follow entity.action pattern.""" + all_types = [ + AUDIT_RECOMMENDATION_GENERATED, + AUDIT_RECOMMENDATION_SUPPRESSED, + AUDIT_RISK_EVALUATED, + AUDIT_RISK_REJECTED, + AUDIT_ORDER_SUBMITTED, + AUDIT_ORDER_FILLED, + AUDIT_ORDER_REJECTED, + AUDIT_ORDER_CANCELLED, + AUDIT_ORDER_DUPLICATE, + AUDIT_POSITION_OPENED, + AUDIT_POSITION_CLOSED, + AUDIT_POSITION_UPDATED, + AUDIT_TRADING_MODE_CHANGED, + ] + for t in all_types: + assert "." 
in t, f"Event type {t} should use dot notation" + parts = t.split(".") + assert len(parts) == 2, f"Event type {t} should have exactly one dot" + assert all(p for p in parts), f"Event type {t} has empty parts" + + +# --------------------------------------------------------------------------- +# Module imports and structure +# --------------------------------------------------------------------------- + + +class TestAuditModuleStructure: + """Verify the audit module exports the expected functions.""" + + def test_record_audit_event_exists(self): + from services.shared.audit import record_audit_event + assert callable(record_audit_event) + + def test_convenience_helpers_exist(self): + from services.shared.audit import ( + audit_recommendation_generated, + audit_risk_evaluated, + audit_order_submitted, + audit_order_filled, + audit_order_rejected, + audit_order_cancelled, + audit_duplicate_prevented, + audit_position_change, + audit_trading_mode_changed, + ) + for fn in [ + audit_recommendation_generated, + audit_risk_evaluated, + audit_order_submitted, + audit_order_filled, + audit_order_rejected, + audit_order_cancelled, + audit_duplicate_prevented, + audit_position_change, + audit_trading_mode_changed, + ]: + assert callable(fn) + + def test_query_helpers_exist(self): + from services.shared.audit import ( + get_order_audit_trail, + get_entity_audit_trail, + ) + assert callable(get_order_audit_trail) + assert callable(get_entity_audit_trail) + + +# --------------------------------------------------------------------------- +# Broker service audit integration +# --------------------------------------------------------------------------- + + +class TestBrokerServiceAuditImports: + """Verify the broker service uses audit functions from the audit module.""" + + def test_broker_service_has_audit_calls(self): + """The broker service module should reference audit functions.""" + import inspect + import services.adapters.broker_service as bs + + source = 
inspect.getsource(bs) + assert "audit_order_submitted" in source + assert "audit_order_filled" in source + assert "audit_order_rejected" in source + assert "audit_risk_evaluated" in source + assert "audit_duplicate_prevented" in source diff --git a/tests/test_broker_adapter.py b/tests/test_broker_adapter.py new file mode 100644 index 0000000..a2fce9e --- /dev/null +++ b/tests/test_broker_adapter.py @@ -0,0 +1,417 @@ +"""Tests for the broker API adapter interface and Alpaca implementation. + +Validates data structures, request building, response parsing, and fail-closed behavior. +""" +from services.adapters.broker_adapter import ( + AccountInfo, + AlpacaBrokerAdapter, + BrokerDataAdapter, + OrderEventType, + OrderRequest, + OrderResponse, + OrderSide, + OrderStatus, + OrderType, + PositionInfo, + TradingMode, +) + + +# --- Fake Alpaca responses --- + +ALPACA_ORDER_RESPONSE = { + "id": "order-abc-123", + "client_order_id": "client-001", + "status": "accepted", + "symbol": "AAPL", + "side": "buy", + "qty": "10", + "filled_qty": "0", + "filled_avg_price": None, + "type": "market", + "time_in_force": "day", + "created_at": "2026-04-11T14:00:00Z", +} + +ALPACA_FILLED_ORDER = { + "id": "order-def-456", + "status": "filled", + "symbol": "AAPL", + "side": "buy", + "qty": "10", + "filled_qty": "10", + "filled_avg_price": "172.50", + "type": "market", + "time_in_force": "day", +} + +ALPACA_REJECTED_ORDER = { + "id": "order-ghi-789", + "status": "rejected", + "symbol": "AAPL", + "side": "sell", + "qty": "100", + "filled_qty": "0", + "filled_avg_price": None, +} + +ALPACA_POSITION = { + "symbol": "AAPL", + "qty": "10", + "avg_entry_price": "172.50", + "current_price": "175.00", + "unrealized_pl": "25.00", + "market_value": "1750.00", + "side": "long", +} + +ALPACA_ACCOUNT = { + "id": "acct-001", + "buying_power": "50000.00", + "cash": "25000.00", + "portfolio_value": "75000.00", + "currency": "USD", +} + + +# --- Enum tests --- + + +class TestBrokerEnums: + def 
test_order_side_values(self): + assert OrderSide.BUY.value == "buy" + assert OrderSide.SELL.value == "sell" + + def test_order_type_values(self): + assert OrderType.MARKET.value == "market" + assert OrderType.LIMIT.value == "limit" + assert OrderType.STOP.value == "stop" + assert OrderType.STOP_LIMIT.value == "stop_limit" + + def test_order_status_values(self): + assert OrderStatus.PENDING.value == "pending" + assert OrderStatus.FILLED.value == "filled" + assert OrderStatus.REJECTED.value == "rejected" + + def test_trading_mode_values(self): + assert TradingMode.PAPER.value == "paper" + assert TradingMode.LIVE.value == "live" + + def test_order_event_type_values(self): + assert OrderEventType.SUBMITTED.value == "submitted" + assert OrderEventType.FILL.value == "fill" + assert OrderEventType.CANCELLED.value == "cancelled" + + +# --- OrderRequest tests --- + + +class TestOrderRequest: + def test_basic_market_order(self): + req = OrderRequest( + ticker="AAPL", + side=OrderSide.BUY, + quantity=10, + ) + assert req.ticker == "AAPL" + assert req.side == OrderSide.BUY + assert req.quantity == 10 + assert req.order_type == OrderType.MARKET + assert req.time_in_force == "day" + assert req.idempotency_key # auto-generated + + def test_limit_order(self): + req = OrderRequest( + ticker="MSFT", + side=OrderSide.SELL, + quantity=5, + order_type=OrderType.LIMIT, + limit_price=400.0, + ) + assert req.order_type == OrderType.LIMIT + assert req.limit_price == 400.0 + + def test_custom_idempotency_key(self): + req = OrderRequest( + ticker="AAPL", + side=OrderSide.BUY, + quantity=1, + idempotency_key="my-key-123", + ) + assert req.idempotency_key == "my-key-123" + + def test_to_dict(self): + req = OrderRequest( + ticker="AAPL", + side=OrderSide.BUY, + quantity=10, + order_type=OrderType.LIMIT, + limit_price=170.0, + idempotency_key="key-1", + ) + d = req.to_dict() + assert d["ticker"] == "AAPL" + assert d["side"] == "buy" + assert d["quantity"] == 10 + assert d["order_type"] == 
"limit" + assert d["limit_price"] == 170.0 + assert d["idempotency_key"] == "key-1" + + def test_to_dict_omits_none_prices(self): + req = OrderRequest(ticker="AAPL", side=OrderSide.BUY, quantity=1) + d = req.to_dict() + assert "limit_price" not in d + assert "stop_price" not in d + + +# --- OrderResponse tests --- + + +class TestOrderResponse: + def test_ok_when_accepted(self): + resp = OrderResponse( + broker_order_id="abc", + status=OrderStatus.ACCEPTED, + ticker="AAPL", + side=OrderSide.BUY, + quantity=10, + ) + assert resp.ok is True + + def test_not_ok_when_rejected(self): + resp = OrderResponse( + broker_order_id="abc", + status=OrderStatus.REJECTED, + ticker="AAPL", + side=OrderSide.BUY, + quantity=10, + error="insufficient funds", + ) + assert resp.ok is False + + def test_not_ok_when_error(self): + resp = OrderResponse( + broker_order_id="abc", + status=OrderStatus.SUBMITTED, + ticker="AAPL", + side=OrderSide.BUY, + quantity=10, + error="network failure", + ) + assert resp.ok is False + + def test_to_dict(self): + resp = OrderResponse( + broker_order_id="order-1", + status=OrderStatus.FILLED, + ticker="AAPL", + side=OrderSide.BUY, + quantity=10, + filled_quantity=10, + filled_avg_price=172.5, + ) + d = resp.to_dict() + assert d["broker_order_id"] == "order-1" + assert d["status"] == "filled" + assert d["filled_avg_price"] == 172.5 + + +# --- PositionInfo tests --- + + +class TestPositionInfo: + def test_basic_position(self): + pos = PositionInfo( + ticker="AAPL", + quantity=10, + avg_entry_price=172.5, + current_price=175.0, + unrealized_pnl=25.0, + market_value=1750.0, + ) + assert pos.ticker == "AAPL" + assert pos.side == "long" + + def test_to_dict(self): + pos = PositionInfo( + ticker="AAPL", + quantity=10, + avg_entry_price=172.5, + current_price=175.0, + unrealized_pnl=25.0, + market_value=1750.0, + side="short", + ) + d = pos.to_dict() + assert d["side"] == "short" + assert d["unrealized_pnl"] == 25.0 + + +# --- AccountInfo tests --- + + +class 
TestAccountInfo: + def test_basic_account(self): + acct = AccountInfo( + account_id="acct-1", + buying_power=50000, + cash=25000, + portfolio_value=75000, + ) + assert acct.mode == TradingMode.PAPER + assert acct.currency == "USD" + + def test_to_dict(self): + acct = AccountInfo( + account_id="acct-1", + buying_power=50000, + cash=25000, + portfolio_value=75000, + mode=TradingMode.LIVE, + ) + d = acct.to_dict() + assert d["mode"] == "live" + assert d["portfolio_value"] == 75000 + + +# --- AlpacaBrokerAdapter tests --- + + +class TestAlpacaSourceType: + def test_source_type(self): + adapter = AlpacaBrokerAdapter(api_key="k", api_secret="s") + assert adapter.source_type() == "broker" + + def test_inherits_broker_data_adapter(self): + assert issubclass(AlpacaBrokerAdapter, BrokerDataAdapter) + + def test_bucket_name(self): + adapter = AlpacaBrokerAdapter(api_key="k", api_secret="s") + assert adapter.bucket_name() == "stonks-raw-broker" + + def test_default_mode_is_paper(self): + adapter = AlpacaBrokerAdapter(api_key="k", api_secret="s") + assert adapter.mode == TradingMode.PAPER + + def test_paper_base_url(self): + adapter = AlpacaBrokerAdapter(api_key="k", api_secret="s", mode=TradingMode.PAPER) + assert "paper" in adapter.base_url + + def test_live_base_url(self): + adapter = AlpacaBrokerAdapter(api_key="k", api_secret="s", mode=TradingMode.LIVE) + assert "paper" not in adapter.base_url + + def test_custom_base_url(self): + adapter = AlpacaBrokerAdapter(api_key="k", api_secret="s", base_url="http://localhost:8080/") + assert adapter.base_url == "http://localhost:8080" + + +class TestAlpacaHeaders: + def test_headers_contain_api_keys(self): + adapter = AlpacaBrokerAdapter(api_key="my-key", api_secret="my-secret") + headers = adapter._headers() + assert headers["APCA-API-KEY-ID"] == "my-key" + assert headers["APCA-API-SECRET-KEY"] == "my-secret" + assert headers["Content-Type"] == "application/json" + + +class TestAlpacaBuildFetchUrl: + def setup_method(self): + 
self.adapter = AlpacaBrokerAdapter( + api_key="k", api_secret="s", base_url="https://paper-api.alpaca.markets" + ) + + def test_positions_url(self): + url = self.adapter._build_fetch_url("AAPL", "positions") + assert url == "https://paper-api.alpaca.markets/v2/positions/AAPL" + + def test_orders_url(self): + url = self.adapter._build_fetch_url("AAPL", "orders") + assert "v2/orders" in url + assert "symbols=AAPL" in url + + def test_account_url(self): + url = self.adapter._build_fetch_url("AAPL", "account") + assert url == "https://paper-api.alpaca.markets/v2/account" + + def test_default_is_positions(self): + url = self.adapter._build_fetch_url("AAPL", "unknown") + assert "/v2/positions/AAPL" in url + + +class TestAlpacaParseOrderResponse: + def setup_method(self): + self.adapter = AlpacaBrokerAdapter(api_key="k", api_secret="s") + + def test_parse_accepted_order(self): + resp = self.adapter._parse_order_response(ALPACA_ORDER_RESPONSE) + assert resp.broker_order_id == "order-abc-123" + assert resp.status == OrderStatus.ACCEPTED + assert resp.ticker == "AAPL" + assert resp.side == OrderSide.BUY + assert resp.quantity == 10 + assert resp.filled_quantity == 0 + assert resp.filled_avg_price is None + + def test_parse_filled_order(self): + resp = self.adapter._parse_order_response(ALPACA_FILLED_ORDER) + assert resp.status == OrderStatus.FILLED + assert resp.filled_quantity == 10 + assert resp.filled_avg_price == 172.5 + + def test_parse_rejected_order(self): + resp = self.adapter._parse_order_response(ALPACA_REJECTED_ORDER) + assert resp.status == OrderStatus.REJECTED + assert resp.ok is False + + def test_parse_unknown_status_defaults_to_pending(self): + data = {**ALPACA_ORDER_RESPONSE, "status": "some_new_status"} + resp = self.adapter._parse_order_response(data) + assert resp.status == OrderStatus.PENDING + + def test_parse_sell_side(self): + data = {**ALPACA_ORDER_RESPONSE, "side": "sell"} + resp = self.adapter._parse_order_response(data) + assert resp.side == 
OrderSide.SELL + + +class TestAlpacaParsePosition: + def setup_method(self): + self.adapter = AlpacaBrokerAdapter(api_key="k", api_secret="s") + + def test_parse_position(self): + pos = self.adapter._parse_position(ALPACA_POSITION) + assert pos.ticker == "AAPL" + assert pos.quantity == 10 + assert pos.avg_entry_price == 172.5 + assert pos.current_price == 175.0 + assert pos.unrealized_pnl == 25.0 + assert pos.market_value == 1750.0 + assert pos.side == "long" + + def test_parse_position_missing_fields(self): + pos = self.adapter._parse_position({"symbol": "TSLA"}) + assert pos.ticker == "TSLA" + assert pos.quantity == 0 + assert pos.avg_entry_price == 0 + + +class TestAlpacaErrorResult: + def test_error_result_fields(self): + adapter = AlpacaBrokerAdapter(api_key="k", api_secret="s") + result = adapter._error_result("AAPL", "rate limited", 150.0, http_status=429, raw=b"slow down") + assert not result.ok + assert result.error == "rate limited" + assert result.http_status == 429 + assert result.response_time_ms == 150.0 + assert result.raw_payload == b"slow down" + assert result.metadata["provider"] == "alpaca" + assert result.metadata["mode"] == "paper" + assert result.source_type == "broker" + + def test_error_result_defaults(self): + adapter = AlpacaBrokerAdapter(api_key="k", api_secret="s") + result = adapter._error_result("MSFT", "timeout", 200.0) + assert result.http_status is None + assert result.raw_payload == b"" + assert result.ticker == "MSFT" diff --git a/tests/test_broker_service.py b/tests/test_broker_service.py new file mode 100644 index 0000000..0d90f57 --- /dev/null +++ b/tests/test_broker_service.py @@ -0,0 +1,261 @@ +"""Tests for the broker service - sandbox integration wiring. + +Validates job parsing, risk evaluation integration, order building, +and the overall process_order_job flow using a mock Alpaca adapter. 
+""" +import pytest + +from services.adapters.broker_adapter import ( + AlpacaBrokerAdapter, + OrderRequest, + OrderResponse, + OrderSide, + OrderStatus, + OrderType, + TradingMode, +) +from services.adapters.broker_service import ( + build_order_request, + build_proposed_order, + generate_idempotency_key, +) +from services.risk.engine import ( + AccountRiskState, + PortfolioRiskConfig, + ProposedOrder, + TradingMode as RiskTradingMode, + evaluate_order, +) +from services.shared.redis_keys import QUEUE_BROKER + + +# --------------------------------------------------------------------------- +# build_order_request tests +# --------------------------------------------------------------------------- + + +class TestBuildOrderRequest: + def test_basic_buy_market(self): + job = { + "ticker": "AAPL", + "side": "buy", + "quantity": 10, + "order_type": "market", + "idempotency_key": "key-1", + } + req = build_order_request(job) + assert req.ticker == "AAPL" + assert req.side == OrderSide.BUY + assert req.quantity == 10 + assert req.order_type == OrderType.MARKET + assert req.idempotency_key == "key-1" + + def test_sell_limit_order(self): + job = { + "ticker": "MSFT", + "side": "sell", + "quantity": 5, + "order_type": "limit", + "limit_price": 400.0, + } + req = build_order_request(job) + assert req.side == OrderSide.SELL + assert req.order_type == OrderType.LIMIT + assert req.limit_price == 400.0 + + def test_stop_order(self): + job = { + "ticker": "TSLA", + "side": "sell", + "quantity": 3, + "order_type": "stop", + "stop_price": 200.0, + } + req = build_order_request(job) + assert req.order_type == OrderType.STOP + assert req.stop_price == 200.0 + + def test_defaults(self): + job = {"ticker": "GOOG"} + req = build_order_request(job) + assert req.side == OrderSide.BUY + assert req.quantity == 0 + assert req.order_type == OrderType.MARKET + assert req.time_in_force == "day" + assert req.idempotency_key # deterministic from job content + + def 
test_deterministic_key_without_explicit(self): + """Without an explicit key, the same job produces the same key.""" + job = {"ticker": "AAPL", "side": "buy", "quantity": 10} + req1 = build_order_request(job) + req2 = build_order_request(job) + assert req1.idempotency_key == req2.idempotency_key + + def test_custom_time_in_force(self): + job = {"ticker": "AAPL", "time_in_force": "gtc"} + req = build_order_request(job) + assert req.time_in_force == "gtc" + + +# --------------------------------------------------------------------------- +# build_proposed_order tests +# --------------------------------------------------------------------------- + + +class TestBuildProposedOrder: + def test_basic_proposed_order(self): + job = { + "ticker": "AAPL", + "side": "buy", + "quantity": 10, + "estimated_value": 1500.0, + "confidence": 0.85, + "sector": "technology", + "recommendation_id": "rec-123", + } + proposed = build_proposed_order(job) + assert proposed.ticker == "AAPL" + assert proposed.action == "buy" + assert proposed.quantity == 10 + assert proposed.estimated_value == 1500.0 + assert proposed.confidence == 0.85 + assert proposed.sector == "technology" + assert proposed.recommendation_id == "rec-123" + + def test_defaults(self): + job = {"ticker": "GOOG"} + proposed = build_proposed_order(job) + assert proposed.action == "buy" + assert proposed.quantity == 0 + assert proposed.estimated_value == 0 + assert proposed.sector == "" + assert proposed.recommendation_id is None + + +# --------------------------------------------------------------------------- +# Risk evaluation integration with broker service flow +# --------------------------------------------------------------------------- + + +class TestRiskEvaluationIntegration: + """Verify that risk evaluation correctly gates order submission.""" + + def test_order_passes_risk_in_paper_mode(self): + config = PortfolioRiskConfig(trading_mode=RiskTradingMode.PAPER) + state = AccountRiskState( + portfolio_value=100_000.0, + 
cash=50_000.0, + buying_power=50_000.0, + ) + proposed = ProposedOrder( + ticker="AAPL", + action="buy", + quantity=10, + estimated_value=1500.0, + sector="technology", + ) + result = evaluate_order(proposed, config, state) + assert result.eligible + assert result.allowed_mode == RiskTradingMode.PAPER + + def test_order_blocked_when_trading_disabled(self): + config = PortfolioRiskConfig(trading_mode=RiskTradingMode.DISABLED) + proposed = ProposedOrder(ticker="AAPL", quantity=10, estimated_value=1500.0) + result = evaluate_order(proposed, config) + assert not result.eligible + assert "disabled" in result.rejection_reasons[0].lower() + + def test_order_blocked_by_position_size(self): + config = PortfolioRiskConfig(trading_mode=RiskTradingMode.PAPER) + config.position_limits.max_position_value = 1000.0 + state = AccountRiskState(portfolio_value=100_000.0) + proposed = ProposedOrder( + ticker="AAPL", + quantity=100, + estimated_value=15_000.0, + ) + result = evaluate_order(proposed, config, state) + assert not result.eligible + + +# --------------------------------------------------------------------------- +# Alpaca adapter sandbox mode verification +# --------------------------------------------------------------------------- + + +class TestAlpacaSandboxMode: + def test_paper_mode_uses_sandbox_url(self): + adapter = AlpacaBrokerAdapter( + api_key="test-key", + api_secret="test-secret", + mode=TradingMode.PAPER, + ) + assert adapter.mode == TradingMode.PAPER + assert "paper" in adapter.base_url + + def test_custom_sandbox_url(self): + adapter = AlpacaBrokerAdapter( + api_key="test-key", + api_secret="test-secret", + mode=TradingMode.PAPER, + base_url="https://paper-api.alpaca.markets", + ) + assert adapter.base_url == "https://paper-api.alpaca.markets" + + def test_headers_set_correctly(self): + adapter = AlpacaBrokerAdapter( + api_key="pk-test", + api_secret="sk-test", + ) + headers = adapter._headers() + assert headers["APCA-API-KEY-ID"] == "pk-test" + assert 
headers["APCA-API-SECRET-KEY"] == "sk-test" + + +# --------------------------------------------------------------------------- +# Queue name constant +# --------------------------------------------------------------------------- + + +class TestQueueConstant: + def test_broker_queue_name(self): + assert QUEUE_BROKER == "broker_orders" + + +# --------------------------------------------------------------------------- +# Idempotency key generation tests +# --------------------------------------------------------------------------- + + +class TestGenerateIdempotencyKey: + def test_explicit_key_passthrough(self): + job = {"ticker": "AAPL", "idempotency_key": "my-explicit-key"} + assert generate_idempotency_key(job) == "my-explicit-key" + + def test_deterministic_without_explicit_key(self): + job = {"ticker": "AAPL", "side": "buy", "quantity": 10, "order_type": "market"} + key1 = generate_idempotency_key(job) + key2 = generate_idempotency_key(job) + assert key1 == key2 + assert len(key1) == 40 # sha256 truncated to 40 chars + + def test_different_jobs_produce_different_keys(self): + job_a = {"ticker": "AAPL", "side": "buy", "quantity": 10} + job_b = {"ticker": "AAPL", "side": "sell", "quantity": 10} + assert generate_idempotency_key(job_a) != generate_idempotency_key(job_b) + + def test_quantity_difference_changes_key(self): + job_a = {"ticker": "AAPL", "side": "buy", "quantity": 10} + job_b = {"ticker": "AAPL", "side": "buy", "quantity": 20} + assert generate_idempotency_key(job_a) != generate_idempotency_key(job_b) + + def test_recommendation_id_included(self): + job_a = {"ticker": "AAPL", "recommendation_id": "rec-1"} + job_b = {"ticker": "AAPL", "recommendation_id": "rec-2"} + assert generate_idempotency_key(job_a) != generate_idempotency_key(job_b) + + def test_minimal_job_still_produces_key(self): + job = {"ticker": "AAPL"} + key = generate_idempotency_key(job) + assert key + assert len(key) == 40 diff --git a/tests/test_config.py b/tests/test_config.py index 
a43d823..a0fce39 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,5 +1,5 @@ """Basic tests for shared config loader.""" -from services.shared.config import load_config, AppConfig +from services.shared.config import load_config, AppConfig, AlertingConfig def test_load_config_returns_app_config(): @@ -20,3 +20,13 @@ def test_redis_url_format(): def test_default_broker_mode(): config = load_config() assert config.broker.mode == "paper" + + +def test_alerting_config_defaults(): + config = load_config() + assert isinstance(config.alerting, AlertingConfig) + assert config.alerting.source_failure_threshold == 3 + assert config.alerting.schema_failure_rate_threshold == 0.3 + assert config.alerting.lake_lag_threshold_minutes == 60 + assert config.alerting.broker_error_threshold == 3 + assert config.alerting.check_interval_seconds == 120 diff --git a/tests/test_content.py b/tests/test_content.py new file mode 100644 index 0000000..a4e0996 --- /dev/null +++ b/tests/test_content.py @@ -0,0 +1,84 @@ +"""Tests for shared canonical URL normalization and content hashing. + +Validates normalize_url, content_hash, and content_hash_str from +services.shared.content. 
+ +Requirements: 3.2, 3.3 +""" +import hashlib + +from services.shared.content import content_hash, content_hash_str, normalize_url + + +class TestNormalizeUrl: + def test_lowercases_scheme_and_host(self): + assert normalize_url("HTTPS://Example.COM/path") == "https://example.com/path" + + def test_strips_trailing_slash(self): + assert normalize_url("https://example.com/path/") == "https://example.com/path" + + def test_strips_fragment(self): + result = normalize_url("https://example.com/path#section") + assert "#" not in result + assert result == "https://example.com/path" + + def test_preserves_query(self): + assert normalize_url("https://example.com/path?q=test") == "https://example.com/path?q=test" + + def test_sorts_query_params(self): + result = normalize_url("https://example.com/path?z=1&a=2") + assert result == "https://example.com/path?a=2&z=1" + + def test_preserves_non_standard_port(self): + result = normalize_url("https://example.com:8443/path") + assert ":8443" in result + + def test_strips_default_port_443(self): + result = normalize_url("https://example.com:443/path") + assert ":443" not in result + + def test_strips_default_port_80(self): + result = normalize_url("http://example.com:80/path") + assert ":80" not in result + + def test_root_path(self): + assert normalize_url("https://example.com") == "https://example.com/" + + def test_defaults_scheme_to_https(self): + result = normalize_url("//example.com/path") + assert result.startswith("https://") + + def test_deterministic_for_same_input(self): + url = "https://example.com/article?b=2&a=1#frag" + assert normalize_url(url) == normalize_url(url) + + +class TestContentHash: + def test_returns_sha256_hex(self): + data = b"hello world" + expected = hashlib.sha256(data).hexdigest() + assert content_hash(data) == expected + + def test_deterministic(self): + data = b"test content" + assert content_hash(data) == content_hash(data) + + def test_different_content_different_hash(self): + assert 
content_hash(b"aaa") != content_hash(b"bbb") + + def test_empty_bytes(self): + result = content_hash(b"") + assert len(result) == 64 # SHA-256 hex length + + +class TestContentHashStr: + def test_matches_manual_sha256(self): + text = "hello world" + expected = hashlib.sha256(text.encode("utf-8")).hexdigest() + assert content_hash_str(text) == expected + + def test_deterministic(self): + assert content_hash_str("test") == content_hash_str("test") + + def test_different_text_different_hash(self): + assert content_hash_str("aaa") != content_hash_str("bbb") diff --git a/tests/test_contradiction.py b/tests/test_contradiction.py new file mode 100644 index 0000000..3bca8f6 --- /dev/null +++ b/tests/test_contradiction.py @@ -0,0 +1,165 @@ +"""Tests for contradiction detection and disagreement representation. + +Requirements: 6.4, 6.5 +""" +from datetime import datetime, timezone + +from services.aggregation.contradiction import ( + CatalystEntry, + ContradictionResult, + detect_contradictions, +) +from services.aggregation.scoring import WeightedSignal, compute_signal_weight + +NOW = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + + +def _sw(doc_id: str, sentiment: float, impact: float = 0.5) -> WeightedSignal: + """Helper to build a WeightedSignal with default scoring.""" + w = compute_signal_weight(NOW, NOW, "7d", 0.8, extraction_confidence=0.8) + return WeightedSignal(doc_id, w, sentiment_value=sentiment, impact_score=impact) + + +# --------------------------------------------------------------------------- +# Overall score (backward compat with compute_contradiction_score) +# --------------------------------------------------------------------------- + + +def test_no_signals_returns_zero(): + result = detect_contradictions([]) + assert result.score == 0.0 + assert result.details == [] + + +def test_all_positive_no_contradiction(): + signals = [_sw("d1", 1.0), _sw("d2", 1.0)] + result = detect_contradictions(signals) + assert result.score == 0.0 + assert 
len(result.details) == 0 + + +def test_equal_opposing_gives_half(): + signals = [_sw("d1", 1.0, 0.5), _sw("d2", -1.0, 0.5)] + result = detect_contradictions(signals) + assert abs(result.score - 0.5) < 1e-4 + + +def test_neutral_signals_ignored(): + signals = [_sw("d1", 0.0), _sw("d2", 0.0)] + result = detect_contradictions(signals) + assert result.score == 0.0 + assert result.details == [] + + +# --------------------------------------------------------------------------- +# Sentiment disagreement detail +# --------------------------------------------------------------------------- + + +def test_sentiment_disagreement_detail_present(): + signals = [_sw("d1", 1.0, 0.6), _sw("d2", -1.0, 0.4)] + result = detect_contradictions(signals) + sentiments = [d for d in result.details if d.dimension == "sentiment"] + assert len(sentiments) == 1 + detail = sentiments[0] + assert detail.positive_doc_ids == ["d1"] + assert detail.negative_doc_ids == ["d2"] + assert detail.positive_weight > 0 + assert detail.negative_weight > 0 + assert "positive" in detail.description.lower() or "sentiment" in detail.description.lower() + + +def test_no_sentiment_detail_when_all_agree(): + signals = [_sw("d1", 1.0), _sw("d2", 1.0)] + result = detect_contradictions(signals) + sentiments = [d for d in result.details if d.dimension == "sentiment"] + assert len(sentiments) == 0 + + +# --------------------------------------------------------------------------- +# Catalyst disagreement detail +# --------------------------------------------------------------------------- + + +def test_catalyst_disagreement_detected(): + signals = [_sw("d1", 1.0, 0.7), _sw("d2", -1.0, 0.5)] + entries = [ + CatalystEntry("d1", "earnings"), + CatalystEntry("d2", "earnings"), + ] + result = detect_contradictions(signals, entries) + catalyst_details = [d for d in result.details if d.dimension.startswith("catalyst:")] + assert len(catalyst_details) == 1 + assert catalyst_details[0].dimension == "catalyst:earnings" + assert 
catalyst_details[0].positive_doc_ids == ["d1"] + assert catalyst_details[0].negative_doc_ids == ["d2"] + + +def test_no_catalyst_disagreement_when_same_sentiment(): + signals = [_sw("d1", 1.0), _sw("d2", 1.0)] + entries = [ + CatalystEntry("d1", "earnings"), + CatalystEntry("d2", "earnings"), + ] + result = detect_contradictions(signals, entries) + catalyst_details = [d for d in result.details if d.dimension.startswith("catalyst:")] + assert len(catalyst_details) == 0 + + +def test_catalyst_disagreement_across_types(): + """Different catalyst types with internal disagreement each get a detail.""" + signals = [ + _sw("d1", 1.0, 0.5), + _sw("d2", -1.0, 0.5), + _sw("d3", 1.0, 0.5), + _sw("d4", -1.0, 0.5), + ] + entries = [ + CatalystEntry("d1", "earnings"), + CatalystEntry("d2", "earnings"), + CatalystEntry("d3", "product"), + CatalystEntry("d4", "product"), + ] + result = detect_contradictions(signals, entries) + catalyst_details = [d for d in result.details if d.dimension.startswith("catalyst:")] + dims = {d.dimension for d in catalyst_details} + assert "catalyst:earnings" in dims + assert "catalyst:product" in dims + + +# --------------------------------------------------------------------------- +# Integration with assemble_trend_summary +# --------------------------------------------------------------------------- + + +def test_trend_summary_includes_disagreement_details(): + """assemble_trend_summary should populate disagreement_details.""" + from datetime import timedelta + + from services.aggregation.worker import ( + ImpactRow, + assemble_trend_summary, + build_weighted_signals, + ) + + impacts = [ + ImpactRow( + document_id="d1", confidence=0.8, novelty_score=0.5, + source_credibility=0.8, sentiment="positive", impact_score=0.7, + catalyst_type="earnings", key_facts=[], risks=[], + published_at=NOW - timedelta(hours=1), + ), + ImpactRow( + document_id="d2", confidence=0.8, novelty_score=0.5, + source_credibility=0.8, sentiment="negative", impact_score=0.7, + 
catalyst_type="earnings", key_facts=[], risks=[], + published_at=NOW - timedelta(hours=2), + ), + ] + signals = build_weighted_signals(impacts, NOW, "7d") + summary = assemble_trend_summary("AAPL", "7d", signals, impacts, reference_time=NOW) + + assert summary.contradiction_score > 0 + assert len(summary.disagreement_details) > 0 + dims = {d.dimension for d in summary.disagreement_details} + assert "sentiment" in dims diff --git a/tests/test_dead_letter.py b/tests/test_dead_letter.py new file mode 100644 index 0000000..46323a9 --- /dev/null +++ b/tests/test_dead_letter.py @@ -0,0 +1,208 @@ +"""Tests for dead-letter queue support and replay tooling.""" +from __future__ import annotations + +import json + +import pytest + +from services.shared.dead_letter import ( + DEFAULT_MAX_ATTEMPTS, + dlq_length, + dlq_summary, + peek_dlq, + purge_dlq, + replay_all, + replay_one, + send_to_dlq, + wrap_dlq_entry, +) +from services.shared.redis_keys import dlq_key, queue_key + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +class FakeRedis: + """Minimal async Redis fake backed by plain dicts.""" + + def __init__(self): + self._data: dict[str, list[str]] = {} + + async def rpush(self, key: str, value: str) -> int: + self._data.setdefault(key, []).append(value) + return len(self._data[key]) + + async def lpop(self, key: str) -> str | None: + lst = self._data.get(key, []) + if not lst: + return None + return lst.pop(0) + + async def llen(self, key: str) -> int: + return len(self._data.get(key, [])) + + async def lrange(self, key: str, start: int, end: int) -> list[str]: + lst = self._data.get(key, []) + return lst[start:end + 1] + + async def delete(self, key: str) -> int: + if key in self._data: + del self._data[key] + return 1 + return 0 + + +@pytest.fixture +def rds(): + return FakeRedis() + + +SAMPLE_JOB = {"ticker": "AAPL", "source_type": "news_api", 
"source_id": "src-1"} + + +# --------------------------------------------------------------------------- +# wrap_dlq_entry +# --------------------------------------------------------------------------- + +def test_wrap_dlq_entry_structure(): + entry = wrap_dlq_entry(SAMPLE_JOB, "ingestion", "timeout", attempt=2, worker="ingestion_worker") + assert entry["original_payload"] == SAMPLE_JOB + assert entry["queue"] == "ingestion" + assert entry["error"] == "timeout" + assert entry["attempt"] == 2 + assert entry["worker"] == "ingestion_worker" + assert "dead_lettered_at" in entry + + +# --------------------------------------------------------------------------- +# send_to_dlq / dlq_length +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_send_to_dlq_and_length(rds): + await send_to_dlq(rds, "parsing", SAMPLE_JOB, error="parse failure", attempt=3) + length = await dlq_length(rds, "parsing") + assert length == 1 + + # Verify the stored entry + raw = rds._data[dlq_key("parsing")][0] + entry = json.loads(raw) + assert entry["original_payload"] == SAMPLE_JOB + assert entry["error"] == "parse failure" + assert entry["attempt"] == 3 + + +# --------------------------------------------------------------------------- +# peek_dlq +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_peek_dlq(rds): + for i in range(5): + await send_to_dlq(rds, "extraction", {"doc": i}, error=f"err-{i}") + + items = await peek_dlq(rds, "extraction", start=0, count=3) + assert len(items) == 3 + assert items[0]["original_payload"]["doc"] == 0 + assert items[2]["original_payload"]["doc"] == 2 + + # DLQ should still have all 5 items (peek doesn't remove) + assert await dlq_length(rds, "extraction") == 5 + + +# --------------------------------------------------------------------------- +# replay_one +# 
--------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_replay_one(rds): + await send_to_dlq(rds, "ingestion", SAMPLE_JOB, error="timeout") + await send_to_dlq(rds, "ingestion", {"ticker": "MSFT"}, error="timeout") + + entry = await replay_one(rds, "ingestion") + assert entry is not None + assert entry["original_payload"] == SAMPLE_JOB + + # Original payload should now be in the source queue + source_queue = queue_key("ingestion") + raw = await rds.lpop(source_queue) + assert raw is not None + assert json.loads(raw) == SAMPLE_JOB + + # DLQ should have 1 remaining + assert await dlq_length(rds, "ingestion") == 1 + + +@pytest.mark.asyncio +async def test_replay_one_empty(rds): + result = await replay_one(rds, "ingestion") + assert result is None + + +# --------------------------------------------------------------------------- +# replay_all +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_replay_all(rds): + for i in range(4): + await send_to_dlq(rds, "aggregation", {"idx": i}, error="fail") + + count = await replay_all(rds, "aggregation") + assert count == 4 + + # DLQ should be empty + assert await dlq_length(rds, "aggregation") == 0 + + # Source queue should have 4 items + source_queue = queue_key("aggregation") + assert await rds.llen(source_queue) == 4 + + +@pytest.mark.asyncio +async def test_replay_all_empty(rds): + count = await replay_all(rds, "aggregation") + assert count == 0 + + +# --------------------------------------------------------------------------- +# purge_dlq +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_purge_dlq(rds): + for i in range(3): + await send_to_dlq(rds, "parsing", {"idx": i}, error="fail") + + removed = await purge_dlq(rds, "parsing") + assert removed == 3 + assert await dlq_length(rds, "parsing") == 0 + + +@pytest.mark.asyncio +async 
def test_purge_dlq_empty(rds): + removed = await purge_dlq(rds, "parsing") + assert removed == 0 + + +# --------------------------------------------------------------------------- +# dlq_summary +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_dlq_summary(rds): + await send_to_dlq(rds, "ingestion", {"a": 1}, error="e") + await send_to_dlq(rds, "ingestion", {"b": 2}, error="e") + await send_to_dlq(rds, "parsing", {"c": 3}, error="e") + + summary = await dlq_summary(rds, ["ingestion", "parsing", "extraction"]) + assert summary == {"ingestion": 2, "parsing": 1, "extraction": 0} + + +# --------------------------------------------------------------------------- +# DEFAULT_MAX_ATTEMPTS constant +# --------------------------------------------------------------------------- + +def test_default_max_attempts(): + assert DEFAULT_MAX_ATTEMPTS == 3 diff --git a/tests/test_dedupe.py b/tests/test_dedupe.py new file mode 100644 index 0000000..1f2fc2c --- /dev/null +++ b/tests/test_dedupe.py @@ -0,0 +1,187 @@ +"""Tests for cross-source deduplication logic. + +Validates the pure functions and key-building helpers in services.shared.dedupe. +Async functions that require Redis/PostgreSQL are tested with lightweight fakes. 
+ +Requirements: 3.2, 3.3 +""" +from __future__ import annotations + +from unittest.mock import AsyncMock + +import pytest + +from services.shared.dedupe import ( + DedupeResult, + _hash_dedupe_key, + _url_dedupe_key, + check_duplicate, + dedupe_items, + mark_as_seen, +) +from services.shared.redis_keys import DEDUPE_PREFIX + + +class TestDedupeKeyBuilders: + def test_hash_dedupe_key_format(self): + key = _hash_dedupe_key("abc123") + assert key == f"{DEDUPE_PREFIX}:abc123" + + def test_url_dedupe_key_is_hashed(self): + key = _url_dedupe_key("https://example.com/article") + assert key.startswith(f"{DEDUPE_PREFIX}:url:") + # Should be deterministic + assert key == _url_dedupe_key("https://example.com/article") + + def test_url_dedupe_key_differs_for_different_urls(self): + k1 = _url_dedupe_key("https://a.com/1") + k2 = _url_dedupe_key("https://b.com/2") + assert k1 != k2 + + +class TestDedupeResult: + def test_not_duplicate(self): + r = DedupeResult(is_duplicate=False) + assert not r.is_duplicate + assert r.existing_document_id is None + assert r.match_type is None + + def test_duplicate_with_details(self): + r = DedupeResult( + is_duplicate=True, + existing_document_id="doc-123", + match_type="canonical_url", + ) + assert r.is_duplicate + assert r.existing_document_id == "doc-123" + + +class FakeRedis: + """Minimal async Redis fake for dedupe tests.""" + + def __init__(self, data: dict[str, str] | None = None): + self._data: dict[str, str] = data or {} + + async def get(self, key: str) -> str | None: + return self._data.get(key) + + async def set(self, key: str, value: str, ex: int | None = None) -> None: + self._data[key] = value + + +class FakePool: + """Minimal async PG pool fake that returns None for all queries.""" + + def __init__(self, rows: dict[str, dict | None] | None = None): + self._rows = rows or {} + + async def fetchrow(self, query: str, *args) -> dict | None: + # Match on the first arg (content_hash or canonical_url) + if args: + return 
self._rows.get(str(args[0])) + return None + + +@pytest.mark.asyncio +async def test_check_duplicate_no_match(): + rds = FakeRedis() + pool = FakePool() + result = await check_duplicate( + pool, rds, content_hash="newhash", url="https://example.com/new" + ) + assert not result.is_duplicate + + +@pytest.mark.asyncio +async def test_check_duplicate_redis_hash_hit(): + hash_key = _hash_dedupe_key("existinghash") + rds = FakeRedis({hash_key: "doc-abc"}) + pool = FakePool() + result = await check_duplicate(pool, rds, content_hash="existinghash") + assert result.is_duplicate + assert result.existing_document_id == "doc-abc" + assert result.match_type == "content_hash" + + +@pytest.mark.asyncio +async def test_check_duplicate_redis_url_hit(): + canonical = "https://example.com/article" + url_key = _url_dedupe_key(canonical) + rds = FakeRedis({url_key: "doc-xyz"}) + pool = FakePool() + result = await check_duplicate( + pool, rds, content_hash="newhash", canonical_url=canonical + ) + assert result.is_duplicate + assert result.existing_document_id == "doc-xyz" + assert result.match_type == "canonical_url" + + +@pytest.mark.asyncio +async def test_check_duplicate_pg_hash_fallback(): + rds = FakeRedis() + pool = FakePool({"pghash": {"id": "doc-pg1"}}) + result = await check_duplicate(pool, rds, content_hash="pghash") + assert result.is_duplicate + assert result.existing_document_id == "doc-pg1" + assert result.match_type == "content_hash" + # Should have warmed Redis cache + assert rds._data.get(_hash_dedupe_key("pghash")) == "doc-pg1" + + +@pytest.mark.asyncio +async def test_check_duplicate_pg_url_fallback(): + canonical = "https://example.com/filing" + rds = FakeRedis() + pool = FakePool({canonical: {"id": "doc-pg2"}}) + result = await check_duplicate( + pool, rds, content_hash="nomatch", canonical_url=canonical + ) + assert result.is_duplicate + assert result.existing_document_id == "doc-pg2" + assert result.match_type == "canonical_url" + + +@pytest.mark.asyncio +async 
def test_dedupe_items_partitions_correctly(): + """dedupe_items should split items into new and duplicate groups.""" + existing_hash = "existinghash" + hash_key = _hash_dedupe_key(existing_hash) + rds = FakeRedis({hash_key: "doc-old"}) + pool = FakePool() + + items = [ + {"title": "New Article", "content_hash": "newhash", "url": "https://a.com/1"}, + {"title": "Dup Article", "content_hash": existing_hash, "url": "https://b.com/2"}, + {"title": "Another New", "content_hash": "anothernew", "url": "https://c.com/3"}, + ] + + new, dups = await dedupe_items(pool, rds, items) + assert len(new) == 2 + assert len(dups) == 1 + assert dups[0]["title"] == "Dup Article" + assert dups[0]["_dedupe_existing_id"] == "doc-old" + + +@pytest.mark.asyncio +async def test_mark_as_seen_sets_redis_keys(): + rds = FakeRedis() + await mark_as_seen( + rds, + content_hash="hash123", + canonical_url="https://example.com/page", + document_id="doc-new", + ) + assert rds._data[_hash_dedupe_key("hash123")] == "doc-new" + assert rds._data[_url_dedupe_key("https://example.com/page")] == "doc-new" + + +@pytest.mark.asyncio +async def test_mark_as_seen_handles_none_url(): + rds = FakeRedis() + await mark_as_seen( + rds, content_hash="hash456", canonical_url=None, document_id="doc-x" + ) + assert rds._data[_hash_dedupe_key("hash456")] == "doc-x" + # No URL key should be set + assert len(rds._data) == 1 diff --git a/tests/test_evidence_ranking.py b/tests/test_evidence_ranking.py new file mode 100644 index 0000000..1fa59ee --- /dev/null +++ b/tests/test_evidence_ranking.py @@ -0,0 +1,136 @@ +"""Tests for evidence ranking — composite scoring for supporting/opposing docs. 
+ +Requirements: 6.5 +""" +from datetime import datetime, timedelta, timezone + +from services.aggregation.evidence import ( + EvidenceRankConfig, + compute_evidence_rank, + rank_evidence, + rank_evidence_detailed, +) +from services.aggregation.scoring import WeightedSignal, compute_signal_weight + +NOW = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + + +def _sw( + doc_id: str = "doc-1", + sentiment: float = 1.0, + impact: float = 0.7, + credibility: float = 0.8, + confidence: float = 0.8, + age_hours: float = 1.0, +) -> WeightedSignal: + published = NOW - timedelta(hours=age_hours) + weight = compute_signal_weight( + published_at=published, + reference_time=NOW, + window="7d", + source_credibility=credibility, + extraction_confidence=confidence, + ) + return WeightedSignal( + document_id=doc_id, + weight=weight, + sentiment_value=sentiment, + impact_score=impact, + ) + + +# --------------------------------------------------------------------------- +# compute_evidence_rank +# --------------------------------------------------------------------------- + + +def test_rank_score_positive(): + sig = _sw("d1", sentiment=1.0, impact=0.9, credibility=1.0) + ranked = compute_evidence_rank(sig) + assert ranked.rank_score > 0 + assert ranked.document_id == "d1" + assert ranked.sentiment_value == 1.0 + + +def test_higher_impact_ranks_higher(): + low = _sw("low", impact=0.3) + high = _sw("high", impact=0.9) + r_low = compute_evidence_rank(low) + r_high = compute_evidence_rank(high) + assert r_high.rank_score > r_low.rank_score + + +def test_fresher_doc_ranks_higher(): + old = _sw("old", age_hours=100.0) + fresh = _sw("fresh", age_hours=1.0) + r_old = compute_evidence_rank(old) + r_fresh = compute_evidence_rank(fresh) + assert r_fresh.rank_score > r_old.rank_score + + +def test_higher_credibility_ranks_higher(): + low_cred = _sw("low", credibility=0.2) + high_cred = _sw("high", credibility=1.0) + r_low = compute_evidence_rank(low_cred) + r_high = 
compute_evidence_rank(high_cred) + assert r_high.rank_score > r_low.rank_score + + +# --------------------------------------------------------------------------- +# rank_evidence +# --------------------------------------------------------------------------- + + +def test_rank_evidence_separates_sides(): + signals = [ + _sw("pos1", sentiment=1.0, impact=0.9), + _sw("pos2", sentiment=1.0, impact=0.3), + _sw("neg1", sentiment=-1.0, impact=0.7), + _sw("neutral", sentiment=0.0, impact=0.5), + ] + supporting, opposing = rank_evidence(signals) + assert "pos1" in supporting + assert "pos2" in supporting + assert "neg1" in opposing + assert "neutral" not in supporting and "neutral" not in opposing + + +def test_rank_evidence_ordered_by_composite(): + signals = [ + _sw("weak", sentiment=1.0, impact=0.2, credibility=0.3), + _sw("strong", sentiment=1.0, impact=0.9, credibility=1.0), + ] + supporting, _ = rank_evidence(signals) + assert supporting[0] == "strong" + + +def test_rank_evidence_respects_max_refs(): + signals = [_sw(f"d{i}", sentiment=1.0) for i in range(20)] + cfg = EvidenceRankConfig(max_refs=3) + supporting, opposing = rank_evidence(signals, config=cfg) + assert len(supporting) == 3 + assert len(opposing) == 0 + + +def test_rank_evidence_empty(): + supporting, opposing = rank_evidence([]) + assert supporting == [] + assert opposing == [] + + +# --------------------------------------------------------------------------- +# rank_evidence_detailed +# --------------------------------------------------------------------------- + + +def test_detailed_returns_ranked_evidence_objects(): + signals = [ + _sw("pos1", sentiment=1.0, impact=0.9), + _sw("neg1", sentiment=-1.0, impact=0.7), + ] + sup, opp = rank_evidence_detailed(signals) + assert len(sup) == 1 + assert sup[0].document_id == "pos1" + assert sup[0].rank_score > 0 + assert len(opp) == 1 + assert opp[0].document_id == "neg1" diff --git a/tests/test_extractor_metrics.py b/tests/test_extractor_metrics.py new file 
mode 100644 index 0000000..1dec0e5 --- /dev/null +++ b/tests/test_extractor_metrics.py @@ -0,0 +1,168 @@ +"""Tests for extraction model performance metrics collection. + +Validates that collect_metrics correctly computes metrics from +ExtractionResponse objects for both successful and failed extractions. + +Requirements: 5.2, 5.4, 12.1, 12.2 +""" +from __future__ import annotations + +from services.extractor.client import ExtractionAttempt, ExtractionResponse +from services.extractor.metrics import collect_metrics +from services.extractor.schemas import ExtractionResult, ValidationReport + + +def _make_valid_result() -> ExtractionResult: + return ExtractionResult.model_validate({ + "summary": "Apple beat earnings expectations.", + "companies": [ + { + "ticker": "AAPL", + "company_name": "Apple Inc.", + "relevance": 0.95, + "sentiment": "positive", + "impact_score": 0.7, + "impact_horizon": "1d_30d", + "catalyst_type": "earnings", + "key_facts": ["Revenue up 12%"], + "risks": [], + "evidence_spans": ["Apple beat expectations"], + } + ], + "macro_themes": ["ai_capex"], + "novelty_score": 0.6, + "confidence": 0.85, + "extraction_warnings": [], + }) + + +def _make_success_response() -> ExtractionResponse: + result = _make_valid_result() + validation = ValidationReport(valid=True, errors=[], warnings=["low_novelty"], parsed=result) + attempt = ExtractionAttempt( + raw_output=result.model_dump_json(), + validation=validation, + error=None, + duration_ms=500, + model="test-model", + ) + return ExtractionResponse( + success=True, + result=result, + attempts=[attempt], + prompt_metadata={"prompt_version": "document-intel-v1", "schema_version": "2.0.0"}, + model="test-model", + total_duration_ms=500, + ) + + +def _make_failed_response_with_retries() -> ExtractionResponse: + attempt1 = ExtractionAttempt( + raw_output="bad json", + validation=None, + error="invalid_json", + duration_ms=200, + model="test-model", + ) + attempt2 = ExtractionAttempt( + raw_output="still bad 
output here", + validation=ValidationReport( + valid=False, + errors=["schema_fail", "missing_companies"], + warnings=["truncated"], + ), + error="schema_fail; missing_companies", + duration_ms=300, + model="test-model", + ) + return ExtractionResponse( + success=False, + result=None, + attempts=[attempt1, attempt2], + prompt_metadata={"prompt_version": "document-intel-v1", "schema_version": "2.0.0"}, + model="test-model", + total_duration_ms=500, + ) + + +def test_collect_metrics_success(): + """Successful extraction produces correct metrics.""" + resp = _make_success_response() + m = collect_metrics( + resp, + document_id="doc-1", + ticker="AAPL", + document_text_length=4000, + ) + + assert m.document_id == "doc-1" + assert m.ticker == "AAPL" + assert m.model_name == "test-model" + assert m.prompt_version == "document-intel-v1" + assert m.schema_version == "2.0.0" + assert m.success is True + assert m.attempt_count == 1 + assert m.total_duration_ms == 500 + assert m.first_attempt_duration_ms == 500 + assert m.final_attempt_duration_ms == 500 + assert m.confidence == 0.85 + assert m.validation_status == "valid" + assert m.validation_error_count == 0 + assert m.validation_warning_count == 1 + assert m.retry_count == 0 + assert m.input_token_estimate == 1000 # 4000 / 4 + assert m.output_token_estimate > 0 + assert m.company_count == 1 + + +def test_collect_metrics_failed_with_retries(): + """Failed extraction with retries produces correct metrics.""" + resp = _make_failed_response_with_retries() + m = collect_metrics( + resp, + document_id="doc-2", + ticker="MSFT", + document_text_length=2000, + ) + + assert m.success is False + assert m.attempt_count == 2 + assert m.retry_count == 1 + assert m.first_attempt_duration_ms == 200 + assert m.final_attempt_duration_ms == 300 + assert m.total_duration_ms == 500 + assert m.validation_status == "failed" + assert m.validation_error_count == 2 + assert m.validation_warning_count == 1 + assert "schema_fail" in 
m.validation_errors + assert m.confidence == 0.0 + assert m.company_count == 0 + assert m.input_token_estimate == 500 # 2000 / 4 + + +def test_collect_metrics_empty_attempts(): + """Response with no attempts produces safe defaults.""" + resp = ExtractionResponse( + success=False, + result=None, + attempts=[], + prompt_metadata={}, + model="test-model", + total_duration_ms=0, + ) + m = collect_metrics(resp, document_id="doc-3") + + assert m.attempt_count == 0 + assert m.retry_count == 0 + assert m.first_attempt_duration_ms == 0 + assert m.final_attempt_duration_ms == 0 + assert m.validation_status == "unknown" + assert m.confidence == 0.0 + + +def test_collect_metrics_no_document_text_length(): + """Zero document text length produces zero token estimate.""" + resp = _make_success_response() + m = collect_metrics(resp, document_text_length=0) + + assert m.input_token_estimate == 0 diff --git a/tests/test_extractor_prompts.py b/tests/test_extractor_prompts.py new file mode 100644 index 0000000..c260478 --- /dev/null +++ b/tests/test_extractor_prompts.py @@ -0,0 +1,120 @@ +"""Tests for extraction prompt templates.""" +import json + +from services.extractor.prompts import ( + EXTRACTION_JSON_SCHEMA, + PROMPT_VERSION, + SCHEMA_VERSION, + SYSTEM_PROMPT, + build_extraction_prompt, + get_json_schema, + get_prompt_metadata, +) +from services.shared.schemas import CatalystType, DocumentType, Sentiment + + +def test_build_extraction_prompt_basic(): + """Prompt contains system and user keys with document text embedded.""" + result = build_extraction_prompt( + document_text="Apple reported record Q4 earnings.", + document_type=DocumentType.ARTICLE, + ) + assert "system" in result + assert "user" in result + assert "Apple reported record Q4 earnings." 
in result["user"] + assert "DOCUMENT TEXT" in result["user"] + + +def test_system_prompt_has_anti_hallucination_rules(): + """System prompt includes key anti-hallucination instructions.""" + assert "NEVER fabricate" in SYSTEM_PROMPT + assert "NEVER infer" in SYSTEM_PROMPT + assert "verbatim quotes" in SYSTEM_PROMPT + assert "ONLY extract information explicitly stated" in SYSTEM_PROMPT + assert "insufficient_content" in SYSTEM_PROMPT + + +def test_build_prompt_includes_json_schema(): + """User prompt embeds the full JSON schema for structured output.""" + result = build_extraction_prompt(document_text="test", document_type=DocumentType.ARTICLE) + # Schema should be serialized into the user prompt + assert '"summary"' in result["user"] + assert '"companies"' in result["user"] + assert '"evidence_spans"' in result["user"] + + +def test_build_prompt_with_known_tickers(): + """Known tickers are included as hints but with a warning not to force-include them.""" + result = build_extraction_prompt( + document_text="Some text", + document_type=DocumentType.ARTICLE, + known_tickers=["AAPL", "MSFT"], + ) + assert "AAPL" in result["user"] + assert "MSFT" in result["user"] + assert "Do NOT include a ticker just because" in result["user"] + + +def test_build_prompt_without_tickers(): + """When no tickers are provided, no ticker hint appears.""" + result = build_extraction_prompt(document_text="Some text", document_type=DocumentType.ARTICLE) + assert "may be referenced" not in result["user"] + + +def test_build_prompt_document_type_guidance(): + """Each document type gets specific guidance in the prompt.""" + for dtype in DocumentType: + result = build_extraction_prompt(document_text="text", document_type=dtype) + assert "Document type:" in result["user"] + + +def test_build_prompt_filing_guidance(): + """Filing documents get SEC-specific guidance.""" + result = build_extraction_prompt(document_text="text", document_type=DocumentType.FILING) + assert "regulatory filing" in 
result["user"] + + +def test_build_prompt_transcript_guidance(): + """Transcript documents get earnings-call-specific guidance.""" + result = build_extraction_prompt(document_text="text", document_type=DocumentType.TRANSCRIPT) + assert "forward-looking" in result["user"] + + +def test_build_prompt_with_document_id(): + """Document ID is included in the prompt when provided.""" + result = build_extraction_prompt( + document_text="text", + document_type=DocumentType.ARTICLE, + document_id="abc-123", + ) + assert "abc-123" in result["user"] + + +def test_get_prompt_metadata(): + """Metadata returns current prompt and schema versions.""" + meta = get_prompt_metadata() + assert meta["prompt_version"] == PROMPT_VERSION + assert meta["schema_version"] == SCHEMA_VERSION + + +def test_get_json_schema_is_valid(): + """JSON schema has required top-level structure.""" + schema = get_json_schema() + assert schema["type"] == "object" + assert "summary" in schema["required"] + assert "companies" in schema["required"] + assert "confidence" in schema["required"] + + +def test_json_schema_enum_values_match_pydantic(): + """Schema enum values match the Pydantic enum definitions.""" + company_props = EXTRACTION_JSON_SCHEMA["properties"]["companies"]["items"]["properties"] + assert set(company_props["sentiment"]["enum"]) == {s.value for s in Sentiment} + assert set(company_props["catalyst_type"]["enum"]) == {c.value for c in CatalystType} + + +def test_json_schema_is_serializable(): + """Schema can be serialized to JSON without errors.""" + serialized = json.dumps(EXTRACTION_JSON_SCHEMA) + parsed = json.loads(serialized) + assert parsed["type"] == "object" diff --git a/tests/test_extractor_schemas.py b/tests/test_extractor_schemas.py new file mode 100644 index 0000000..371430e --- /dev/null +++ b/tests/test_extractor_schemas.py @@ -0,0 +1,317 @@ +"""Tests for extractor JSON schema definitions and validation.""" +import json + +from services.extractor.schemas import ( + SCHEMA_VERSION, 
+ VALID_IMPACT_HORIZONS, + ExtractionResult, + generate_json_schema, + get_schema_version, + validate_extraction, +) +from services.shared.schemas import CatalystType, Sentiment + + +def test_generate_json_schema_top_level_structure(): + """Generated schema is a valid JSON Schema object with required fields.""" + schema = generate_json_schema() + assert schema["type"] == "object" + assert "summary" in schema["required"] + assert "companies" in schema["required"] + assert "confidence" in schema["required"] + assert "extraction_warnings" in schema["required"] + + +def test_generate_json_schema_no_refs(): + """Generated schema has no $ref or $defs — fully inlined.""" + schema = generate_json_schema() + serialized = json.dumps(schema) + assert "$ref" not in serialized + assert "$defs" not in serialized + + +def test_generate_json_schema_serializable(): + """Schema round-trips through JSON serialization.""" + schema = generate_json_schema() + text = json.dumps(schema) + parsed = json.loads(text) + assert parsed["type"] == "object" + + +def test_generate_json_schema_company_properties(): + """Company items include all required extraction fields.""" + schema = generate_json_schema() + company_schema = schema["properties"]["companies"]["items"] + required = company_schema["required"] + assert "ticker" in required + assert "sentiment" in required + assert "catalyst_type" in required + assert "evidence_spans" in required + + +def test_generate_json_schema_enum_values(): + """Enum values in the schema match the Pydantic enum definitions.""" + schema = generate_json_schema() + company_props = schema["properties"]["companies"]["items"]["properties"] + sentiment_vals = _extract_enum_values(company_props["sentiment"]) + catalyst_vals = _extract_enum_values(company_props["catalyst_type"]) + assert set(sentiment_vals) == {s.value for s in Sentiment} + assert set(catalyst_vals) == {c.value for c in CatalystType} + + +def test_get_schema_version(): + assert get_schema_version() == 
SCHEMA_VERSION + + +# --- Validation tests --- + + +def _valid_extraction() -> dict: + """Minimal valid extraction result.""" + return { + "summary": "Apple beat earnings expectations.", + "companies": [ + { + "ticker": "AAPL", + "company_name": "Apple Inc.", + "relevance": 0.95, + "sentiment": "positive", + "impact_score": 0.7, + "impact_horizon": "1d_30d", + "catalyst_type": "earnings", + "key_facts": ["Revenue up 12%"], + "risks": [], + "evidence_spans": ["Apple beat expectations"], + } + ], + "macro_themes": ["ai_capex"], + "novelty_score": 0.6, + "confidence": 0.85, + "extraction_warnings": [], + } + + +def test_validate_extraction_valid_dict(): + report = validate_extraction(_valid_extraction()) + assert report.valid + assert report.parsed is not None + assert report.parsed.companies[0].ticker == "AAPL" + + +def test_validate_extraction_valid_json_string(): + report = validate_extraction(json.dumps(_valid_extraction())) + assert report.valid + assert report.parsed is not None + + +def test_validate_extraction_invalid_json(): + report = validate_extraction("{bad json") + assert not report.valid + assert any("Invalid JSON" in e for e in report.errors) + + +def test_validate_extraction_not_object(): + report = validate_extraction("[1, 2, 3]") + assert not report.valid + assert any("object" in e.lower() for e in report.errors) + + +def test_validate_extraction_missing_required_field(): + data = _valid_extraction() + del data["summary"] + report = validate_extraction(data) + assert not report.valid + + +def test_validate_extraction_invalid_enum(): + data = _valid_extraction() + data["companies"][0]["sentiment"] = "super_bullish" + report = validate_extraction(data) + assert not report.valid + + +def test_validate_extraction_out_of_range(): + data = _valid_extraction() + data["confidence"] = 1.5 + report = validate_extraction(data) + assert not report.valid + + +def test_validate_semantic_empty_summary_warning(): + data = _valid_extraction() + data["summary"] = "" 
+ report = validate_extraction(data) + assert report.valid + assert "empty_summary" in report.warnings + + +def test_validate_semantic_low_confidence_with_companies(): + data = _valid_extraction() + data["confidence"] = 0.2 + report = validate_extraction(data) + assert report.valid + assert "low_confidence_with_companies" in report.warnings + + +def test_validate_semantic_no_evidence_spans(): + data = _valid_extraction() + data["companies"][0]["evidence_spans"] = [] + report = validate_extraction(data) + assert report.valid + assert any("no_evidence_spans" in w for w in report.warnings) + + +def test_validate_semantic_high_impact_no_facts(): + data = _valid_extraction() + data["companies"][0]["key_facts"] = [] + data["companies"][0]["impact_score"] = 0.8 + report = validate_extraction(data) + assert report.valid + assert any("high_impact_no_facts" in w for w in report.warnings) + + +def test_extraction_result_model_roundtrip(): + """ExtractionResult can be created and serialized back to dict.""" + result = ExtractionResult( + summary="Test", + companies=[], + macro_themes=[], + novelty_score=0.5, + confidence=0.5, + extraction_warnings=[], + ) + data = result.model_dump() + assert data["summary"] == "Test" + reparsed = ExtractionResult.model_validate(data) + assert reparsed.summary == "Test" + + +def test_all_top_level_fields_required(): + """All top-level fields appear in the schema's required list.""" + schema = generate_json_schema() + required = set(schema["required"]) + expected = {"summary", "companies", "macro_themes", "novelty_score", "confidence", "extraction_warnings"} + assert expected == required + + +def test_all_company_fields_required(): + """All company item fields appear in the schema's required list.""" + schema = generate_json_schema() + company_required = set(schema["properties"]["companies"]["items"]["required"]) + expected = { + "ticker", "company_name", "relevance", "sentiment", + "impact_score", "impact_horizon", "catalyst_type", + 
"key_facts", "risks", "evidence_spans", + } + assert expected == company_required + + +# --- Semantic validation: error-level checks --- + + +def test_validate_semantic_missing_ticker_is_error(): + """A company with an empty ticker produces a semantic error, not just a warning.""" + data = _valid_extraction() + data["companies"][0]["ticker"] = "" + report = validate_extraction(data) + assert not report.valid + assert any("company_missing_ticker" in e for e in report.errors) + + +def test_validate_semantic_invalid_impact_horizon_is_error(): + """An unrecognized impact_horizon produces a semantic error.""" + data = _valid_extraction() + data["companies"][0]["impact_horizon"] = "forever" + report = validate_extraction(data) + assert not report.valid + assert any("invalid_impact_horizon" in e for e in report.errors) + + +def test_validate_semantic_all_valid_horizons_accepted(): + """Every value in VALID_IMPACT_HORIZONS passes validation.""" + for horizon in VALID_IMPACT_HORIZONS: + data = _valid_extraction() + data["companies"][0]["impact_horizon"] = horizon + report = validate_extraction(data) + assert report.valid, f"Horizon {horizon!r} should be valid" + + +def test_validate_semantic_duplicate_ticker_is_error(): + """Two company entries with the same ticker produce a semantic error.""" + data = _valid_extraction() + second = dict(data["companies"][0]) + data["companies"].append(second) + report = validate_extraction(data) + assert not report.valid + assert any("duplicate_ticker_AAPL" in e for e in report.errors) + + +# --- Semantic validation: warning-level checks --- + + +def test_validate_semantic_invalid_ticker_format_warning(): + """A lowercase or overly long ticker produces a warning.""" + data = _valid_extraction() + data["companies"][0]["ticker"] = "aapl" + report = validate_extraction(data) + assert report.valid # warning, not error + assert any("invalid_ticker_format" in w for w in report.warnings) + + +def test_validate_semantic_evidence_span_too_short(): 
+ data = _valid_extraction() + data["companies"][0]["evidence_spans"] = ["short"] + report = validate_extraction(data) + assert report.valid + assert any("evidence_span_too_short" in w for w in report.warnings) + + +def test_validate_semantic_evidence_span_too_long(): + data = _valid_extraction() + data["companies"][0]["evidence_spans"] = ["x" * 501] + report = validate_extraction(data) + assert report.valid + assert any("evidence_span_too_long" in w for w in report.warnings) + + +def test_validate_semantic_strong_sentiment_low_impact(): + data = _valid_extraction() + data["companies"][0]["sentiment"] = "positive" + data["companies"][0]["impact_score"] = 0.05 + report = validate_extraction(data) + assert report.valid + assert any("strong_sentiment_low_impact" in w for w in report.warnings) + + +# --- Evidence grounding --- + + +def test_validate_evidence_grounding_found(): + """Evidence spans present in document_text produce no grounding warnings.""" + data = _valid_extraction() + doc_text = "Apple beat expectations with record revenue." + report = validate_extraction(data, document_text=doc_text) + assert report.valid + assert not any("evidence_span_not_found" in w for w in report.warnings) + + +def test_validate_evidence_grounding_not_found(): + """Evidence spans NOT in document_text produce a grounding warning.""" + data = _valid_extraction() + doc_text = "Completely unrelated document about weather." 
+ report = validate_extraction(data, document_text=doc_text) + assert report.valid + assert any("evidence_span_not_found" in w for w in report.warnings) + + +# --- Helpers --- + + +def _extract_enum_values(prop: dict) -> list: + """Extract enum values from a JSON schema property, handling anyOf patterns.""" + if "enum" in prop: + return prop["enum"] + for option in prop.get("anyOf", []): + if "enum" in option: + return option["enum"] + return [] diff --git a/tests/test_extractor_worker.py b/tests/test_extractor_worker.py new file mode 100644 index 0000000..1b165b7 --- /dev/null +++ b/tests/test_extractor_worker.py @@ -0,0 +1,200 @@ +"""Tests for the extraction worker persistence logic. + +Validates that persist_extraction correctly uploads artifacts to MinIO +and persists intelligence/impact records to PostgreSQL. + +Requirements: 5.1, 5.2, 5.3, 5.4, 5.5, 9.1, 9.2 +""" +from __future__ import annotations + +import json +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from services.extractor.client import ExtractionAttempt, ExtractionResponse +from services.extractor.schemas import ExtractionResult, ValidationReport +from services.extractor.worker import persist_extraction + + +def _make_valid_result() -> ExtractionResult: + """Build a minimal valid ExtractionResult.""" + return ExtractionResult.model_validate({ + "summary": "Apple beat earnings expectations.", + "companies": [ + { + "ticker": "AAPL", + "company_name": "Apple Inc.", + "relevance": 0.95, + "sentiment": "positive", + "impact_score": 0.7, + "impact_horizon": "1d_30d", + "catalyst_type": "earnings", + "key_facts": ["Revenue up 12%"], + "risks": [], + "evidence_spans": ["Apple beat expectations"], + } + ], + "macro_themes": ["ai_capex"], + "novelty_score": 0.6, + "confidence": 0.85, + "extraction_warnings": [], + }) + + +def _make_success_response() -> ExtractionResponse: + """Build a successful ExtractionResponse with one attempt.""" + 
result = _make_valid_result() + validation = ValidationReport(valid=True, errors=[], warnings=[], parsed=result) + attempt = ExtractionAttempt( + raw_output=result.model_dump_json(), + validation=validation, + error=None, + duration_ms=500, + model="test-model", + ) + return ExtractionResponse( + success=True, + result=result, + attempts=[attempt], + prompt_metadata={"prompt_version": "document-intel-v1", "schema_version": "2.0.0"}, + model="test-model", + total_duration_ms=500, + ) + + +def _make_failed_response() -> ExtractionResponse: + """Build a failed ExtractionResponse with two attempts.""" + attempt1 = ExtractionAttempt( + raw_output="bad json", + validation=None, + error="invalid_json", + duration_ms=200, + model="test-model", + ) + attempt2 = ExtractionAttempt( + raw_output="still bad", + validation=ValidationReport(valid=False, errors=["schema_fail"], warnings=[]), + error="schema_fail", + duration_ms=300, + model="test-model", + ) + return ExtractionResponse( + success=False, + result=None, + attempts=[attempt1, attempt2], + prompt_metadata={"prompt_version": "document-intel-v1", "schema_version": "2.0.0"}, + model="test-model", + total_duration_ms=500, + ) + + +def _mock_pool(intel_id: str = "intel-uuid-1", impact_id: str = "impact-uuid-1") -> AsyncMock: + """Create a mock asyncpg pool that returns predictable UUIDs.""" + pool = AsyncMock() + # Side effects: intelligence insert, impact insert, metrics insert + pool.fetchval = AsyncMock(side_effect=[intel_id, impact_id, "metrics-uuid-1"]) + pool.execute = AsyncMock() + return pool + + +def _mock_minio() -> MagicMock: + """Create a mock MinIO client.""" + client = MagicMock() + return client + + +@pytest.mark.asyncio +async def test_persist_successful_extraction(): + """Successful extraction persists all artifacts and intelligence records.""" + pool = _mock_pool() + minio = _mock_minio() + response = _make_success_response() + ts = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + + result = await 
persist_extraction( + pool=pool, + minio_client=minio, + document_id="doc-123", + ticker="AAPL", + extraction_response=response, + company_id_map={"AAPL": "company-uuid-1"}, + source_credibility=0.8, + timestamp=ts, + ) + + assert result.success + assert result.intelligence_id == "intel-uuid-1" + assert result.impact_ids == ["impact-uuid-1"] + assert result.prompt_ref is not None + assert "stonks-llm-prompts" in result.prompt_ref + assert result.raw_output_ref is not None + assert "stonks-llm-results" in result.raw_output_ref + assert result.validation_ref is not None + assert result.intelligence_ref is not None + + # MinIO should have 4 uploads: prompt, raw output, validation, intelligence + assert minio.put_object.call_count == 4 + + # PostgreSQL: 1 intelligence insert + 1 impact insert + 1 metrics insert + 1 status update + assert pool.fetchval.call_count == 3 + assert pool.execute.call_count == 1 + + +@pytest.mark.asyncio +async def test_persist_failed_extraction(): + """Failed extraction still persists attempt data and marks document as failed.""" + pool = AsyncMock() + pool.fetchval = AsyncMock(return_value="intel-uuid-fail") + pool.execute = AsyncMock() + minio = _mock_minio() + response = _make_failed_response() + ts = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + + result = await persist_extraction( + pool=pool, + minio_client=minio, + document_id="doc-456", + ticker="AAPL", + extraction_response=response, + timestamp=ts, + ) + + assert not result.success + assert result.intelligence_id == "intel-uuid-fail" + assert result.intelligence_ref is None # no final intelligence on failure + assert result.prompt_ref is not None + assert result.raw_output_ref is not None + assert result.validation_ref is not None + + # MinIO: 3 uploads (prompt, raw output, validation — no intelligence) + assert minio.put_object.call_count == 3 + + # PostgreSQL: 1 intelligence insert + 1 metrics insert + 1 status update + assert pool.fetchval.call_count == 2 + assert 
pool.execute.call_count == 1 + + +@pytest.mark.asyncio +async def test_persist_skips_impact_without_company_id(): + """Impact records are skipped when company_id_map doesn't have the ticker.""" + pool = AsyncMock() + pool.fetchval = AsyncMock(return_value="intel-uuid-2") + pool.execute = AsyncMock() + minio = _mock_minio() + response = _make_success_response() + + result = await persist_extraction( + pool=pool, + minio_client=minio, + document_id="doc-789", + ticker="AAPL", + extraction_response=response, + company_id_map={}, # no mapping for AAPL + ) + + assert result.success + assert result.impact_ids == [] + # 1 fetchval for intelligence + 1 for metrics, no impact insert + assert pool.fetchval.call_count == 2 diff --git a/tests/test_fail_closed_broker.py b/tests/test_fail_closed_broker.py new file mode 100644 index 0000000..8ca593b --- /dev/null +++ b/tests/test_fail_closed_broker.py @@ -0,0 +1,332 @@ +"""Validate fail-closed behavior for broker outages and ambiguous order states. + +Tests that the system rejects orders rather than risking duplicates or +ambiguous execution when the broker API is unavailable, returns errors, +times out, or returns unexpected/ambiguous responses. 
+ +Requirements: 8.4, 8.5, N5 +Design: Section 10 - Reliability and Safety +""" +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, patch + +import httpx +import pytest + +from services.adapters.broker_adapter import ( + AlpacaBrokerAdapter, + OrderRequest, + OrderResponse, + OrderSide, + OrderStatus, + OrderType, + TradingMode, +) +from services.risk.engine import ( + AccountRiskState, + DailyLossLimits, + PortfolioRiskConfig, + PositionLimits, + ProposedOrder, + RiskCheckResult, + TradingMode as RiskTradingMode, + evaluate_order, +) + +NOW = datetime(2026, 4, 11, 14, 0, 0, tzinfo=timezone.utc) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_adapter(base_url: str = "https://paper-api.alpaca.markets") -> AlpacaBrokerAdapter: + return AlpacaBrokerAdapter( + api_key="test-key", + api_secret="test-secret", + mode=TradingMode.PAPER, + base_url=base_url, + ) + + +def _make_buy_order(ticker: str = "AAPL", qty: float = 10) -> OrderRequest: + return OrderRequest( + ticker=ticker, + side=OrderSide.BUY, + quantity=qty, + order_type=OrderType.MARKET, + idempotency_key=f"test-{ticker}-{qty}", + ) + + +# --------------------------------------------------------------------------- +# 1. 
# Broker network outage — submit_order fails closed
# ---------------------------------------------------------------------------


def _http_status_error(status_code: int, reason: str) -> httpx.HTTPStatusError:
    """Build the HTTPStatusError used to simulate a broker error reply to a POST."""
    request = httpx.Request("POST", "http://test")
    response = httpx.Response(status_code, text=reason, request=request)
    return httpx.HTTPStatusError(str(status_code), response=response, request=response.request)


async def _submit_with_post_failure(exc: Exception) -> OrderResponse:
    """Submit a default buy order while ``httpx.AsyncClient.post`` raises *exc*."""
    adapter = _make_adapter()
    order = _make_buy_order()
    with patch("httpx.AsyncClient.post", side_effect=exc):
        return await adapter.submit_order(order)


def _assert_fail_closed(resp: OrderResponse) -> None:
    """Assert *resp* is an explicit, labelled fail-closed rejection."""
    assert resp.status == OrderStatus.REJECTED
    assert resp.ok is False
    # Check for None first so a missing error message fails as an assertion
    # instead of raising TypeError from the ``in`` operator.
    assert resp.error is not None
    assert "fail-closed" in resp.error


@pytest.mark.asyncio
class TestSubmitOrderFailsClosed:
    """submit_order must return REJECTED on any network/transport error."""

    async def test_connection_error_returns_rejected(self):
        """A refused connection is reported as a fail-closed rejection."""
        resp = await _submit_with_post_failure(httpx.ConnectError("connection refused"))

        _assert_fail_closed(resp)

    async def test_timeout_returns_rejected(self):
        """A read timeout is reported as a fail-closed rejection."""
        resp = await _submit_with_post_failure(httpx.ReadTimeout("read timed out"))

        _assert_fail_closed(resp)

    async def test_dns_error_returns_rejected(self):
        """A DNS failure (also a ConnectError) is held to the same contract."""
        resp = await _submit_with_post_failure(httpx.ConnectError("DNS resolution failed"))

        # Same exception class as the connection-refused case, so the full
        # contract (including ``ok is False``) is asserted here as well.
        _assert_fail_closed(resp)

    async def test_http_500_returns_rejected(self):
        """Broker internal server error should result in rejection."""
        resp = await _submit_with_post_failure(_http_status_error(500, "Internal Server Error"))

        assert resp.status == OrderStatus.REJECTED
        assert resp.ok is False
        assert resp.broker_order_id == ""

    async def test_http_503_returns_rejected(self):
        """Broker service unavailable should result in rejection."""
        resp = await _submit_with_post_failure(_http_status_error(503, "Service Unavailable"))

        assert resp.status == OrderStatus.REJECTED
        assert resp.ok is False
        # Kept in step with the 500 case: a rejected order carries no broker ID.
        assert resp.broker_order_id == ""

    async def test_rejected_order_has_empty_broker_id(self):
        """Fail-closed responses must not carry a broker order ID that could be confused with a real order."""
        resp = await _submit_with_post_failure(Exception("unexpected"))

        assert resp.broker_order_id == ""
        assert resp.filled_quantity == 0
        assert resp.filled_avg_price is None


# ---------------------------------------------------------------------------
# 2. Ambiguous order states — get_order_status fails closed
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
class TestGetOrderStatusFailsClosed:
    """get_order_status must return REJECTED on errors, not an ambiguous state."""

    async def test_network_error_returns_rejected(self):
        """A refused connection during a status lookup yields REJECTED with an error."""
        adapter = _make_adapter()

        with patch("httpx.AsyncClient.get", side_effect=httpx.ConnectError("refused")):
            resp = await adapter.get_order_status("order-123")

        assert resp.status == OrderStatus.REJECTED
        assert resp.error is not None

    async def test_timeout_returns_rejected(self):
        """A read timeout during a status lookup yields REJECTED with an error."""
        adapter = _make_adapter()

        with patch("httpx.AsyncClient.get", side_effect=httpx.ReadTimeout("timeout")):
            resp = await adapter.get_order_status("order-123")

        assert resp.status == OrderStatus.REJECTED
        assert resp.error is not None


# ---------------------------------------------------------------------------
# 3.
# Cancel order fails closed
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
class TestCancelOrderFailsClosed:
    """cancel_order must return REJECTED on errors rather than leaving order in unknown state."""

    async def test_network_error_returns_rejected(self):
        """A refused connection on DELETE yields REJECTED with an error set."""
        broker = _make_adapter()
        failure = httpx.ConnectError("refused")

        with patch.object(httpx.AsyncClient, "delete", side_effect=failure):
            outcome = await broker.cancel_order("order-456")

        assert outcome.status == OrderStatus.REJECTED
        assert outcome.error is not None

    async def test_timeout_returns_rejected(self):
        """A read timeout on DELETE yields REJECTED."""
        broker = _make_adapter()
        failure = httpx.ReadTimeout("timeout")

        with patch.object(httpx.AsyncClient, "delete", side_effect=failure):
            outcome = await broker.cancel_order("order-456")

        assert outcome.status == OrderStatus.REJECTED


# ---------------------------------------------------------------------------
# 4. Position and account queries degrade safely
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
class TestPositionAccountDegradation:
    """Position/account queries must return safe defaults on broker outage."""

    async def test_get_positions_returns_empty_on_outage(self):
        """With GET failing to connect, get_positions() returns an empty list."""
        broker = _make_adapter()
        failure = httpx.ConnectError("refused")

        with patch.object(httpx.AsyncClient, "get", side_effect=failure):
            held = await broker.get_positions()

        assert held == []

    async def test_get_account_returns_zeroed_on_outage(self):
        """With GET failing to connect, get_account() returns zeroed balances and no ID."""
        broker = _make_adapter()
        failure = httpx.ConnectError("refused")

        with patch.object(httpx.AsyncClient, "get", side_effect=failure):
            account = await broker.get_account()

        observed = (
            account.buying_power,
            account.cash,
            account.portfolio_value,
            account.account_id,
        )
        assert observed == (0, 0, 0, "")


# ---------------------------------------------------------------------------
# 5.
Risk engine fails closed with degraded state +# --------------------------------------------------------------------------- + + +class TestRiskEngineFailClosed: + """Risk engine must reject orders when account state is missing or degraded.""" + + def test_zero_portfolio_value_blocks_buy(self): + """If broker is down and portfolio_value is 0, position pct → 1.0 → fail.""" + config = PortfolioRiskConfig() + state = AccountRiskState(portfolio_value=0.0, cash=0.0) + order = ProposedOrder( + ticker="AAPL", sector="Technology", + estimated_value=1000, quantity=10, + ) + result = evaluate_order(order, config, state) + assert not result.passed + pct_check = next(c for c in result.checks if c.check_name == "max_position_pct") + assert pct_check.result == RiskCheckResult.FAIL + assert pct_check.actual == 1.0 + + def test_disabled_mode_blocks_all_orders(self): + config = PortfolioRiskConfig(trading_mode=RiskTradingMode.DISABLED) + state = AccountRiskState(portfolio_value=100_000.0, cash=50_000.0) + order = ProposedOrder( + ticker="AAPL", sector="Technology", + estimated_value=1000, quantity=10, + ) + result = evaluate_order(order, config, state) + assert not result.passed + assert any("disabled" in r.lower() for r in result.rejection_reasons) + + def test_degraded_state_with_zero_buying_power(self): + """When broker returns zeroed account, position value check should still block large orders.""" + config = PortfolioRiskConfig( + position_limits=PositionLimits(max_position_value=5_000.0), + ) + state = AccountRiskState( + portfolio_value=0.0, cash=0.0, buying_power=0.0, + ) + order = ProposedOrder( + ticker="AAPL", sector="Technology", + estimated_value=10_000.0, quantity=50, + ) + result = evaluate_order(order, config, state) + assert not result.passed + + def test_multiple_risk_failures_all_captured_on_degraded_state(self): + """Degraded state should trigger multiple failures, all recorded for audit.""" + config = PortfolioRiskConfig( + 
position_limits=PositionLimits(max_position_value=500), + daily_loss=DailyLossLimits(max_daily_loss_value=0), + ) + state = AccountRiskState(portfolio_value=0.0, daily_pnl=-1.0) + order = ProposedOrder( + ticker="AAPL", sector="Technology", + estimated_value=1000, quantity=10, + ) + result = evaluate_order(order, config, state) + assert not result.passed + assert len(result.rejection_reasons) >= 2 + # Full decision trace is preserved + assert len(result.checks) > 0 + assert result.config_snapshot is not None + assert result.state_snapshot is not None + + +# --------------------------------------------------------------------------- +# 6. Fetch (ingestion path) fails closed +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestFetchFailsClosed: + """Broker fetch() for ingestion must return error result, not raise.""" + + async def test_fetch_connection_error_returns_error_result(self): + adapter = _make_adapter() + + with patch("httpx.AsyncClient.get", side_effect=httpx.ConnectError("refused")): + result = await adapter.fetch("AAPL", {"endpoint": "positions"}) + + assert not result.ok + assert result.error is not None + assert result.items == [] + + async def test_fetch_timeout_returns_error_result(self): + adapter = _make_adapter() + + with patch("httpx.AsyncClient.get", side_effect=httpx.ReadTimeout("timeout")): + result = await adapter.fetch("AAPL", {"endpoint": "orders"}) + + assert not result.ok + assert result.error is not None + + async def test_fetch_http_429_returns_error_result(self): + adapter = _make_adapter() + + mock_resp = httpx.Response(429, text="Rate limited", request=httpx.Request("GET", "http://test")) + with patch("httpx.AsyncClient.get", side_effect=httpx.HTTPStatusError("429", response=mock_resp, request=mock_resp.request)): + result = await adapter.fetch("AAPL", {"endpoint": "positions"}) + + assert not result.ok + assert result.http_status == 429 diff --git 
a/tests/test_filings_adapter.py b/tests/test_filings_adapter.py new file mode 100644 index 0000000..df185fa --- /dev/null +++ b/tests/test_filings_adapter.py @@ -0,0 +1,177 @@ +"""Tests for the SEC EDGAR filings adapter. + +Validates request building, response parsing, and error handling. +""" +from services.adapters.filings_adapter import FilingsDataAdapter, SECEdgarAdapter + + +# --- Fake EDGAR EFTS responses --- + +EFTS_RESPONSE = { + "hits": { + "total": {"value": 3, "relation": "eq"}, + "hits": [ + { + "_id": "0001234567-26-000001", + "_source": { + "file_date": "2026-04-01", + "form_type": "8-K", + "entity_name": "Apple Inc.", + "file_num": "001-36743", + "period_of_report": "2026-03-31", + }, + }, + { + "_id": "0001234567-26-000002", + "_source": { + "file_date": "2026-03-15", + "form_type": "10-Q", + "entity_name": "Apple Inc.", + "file_num": "001-36743", + "period_of_report": "2026-03-15", + }, + }, + { + "_id": "0001234567-26-000003", + "_source": { + "file_date": "2026-01-30", + "form_type": "10-K", + "entity_name": "Apple Inc.", + "file_num": "001-36743", + "period_of_report": "2025-12-31", + }, + }, + ], + } +} + +EMPTY_EFTS_RESPONSE = { + "hits": { + "total": {"value": 0, "relation": "eq"}, + "hits": [], + } +} + + +class TestSECEdgarSourceType: + def test_source_type(self): + adapter = SECEdgarAdapter() + assert adapter.source_type() == "filings_api" + + def test_inherits_filings_data_adapter(self): + assert issubclass(SECEdgarAdapter, FilingsDataAdapter) + + def test_bucket_name(self): + adapter = SECEdgarAdapter() + assert adapter.bucket_name() == "stonks-raw-filings" + + +class TestSECEdgarBuildRequest: + def setup_method(self): + self.adapter = SECEdgarAdapter( + base_url="https://efts.sec.gov", + user_agent="TestAgent/1.0", + ) + + def test_default_request(self): + url, params, headers = self.adapter._build_request("AAPL", {}) + assert url == "https://efts.sec.gov/LATEST/search-index" + assert params["q"] == '"AAPL"' + assert params["forms"] == 
"8-K,10-Q,10-K" + assert headers["User-Agent"] == "TestAgent/1.0" + + def test_custom_forms(self): + _, params, _ = self.adapter._build_request("AAPL", {"forms": "8-K"}) + assert params["forms"] == "8-K" + + def test_date_range(self): + config = {"start_date": "2026-01-01", "end_date": "2026-04-10"} + _, params, _ = self.adapter._build_request("AAPL", config) + assert params["dateRange"] == "custom" + assert params["startdt"] == "2026-01-01" + assert params["enddt"] == "2026-04-10" + + def test_cik_filter(self): + _, params, _ = self.adapter._build_request("AAPL", {"cik": "0000320193"}) + assert "cik:0000320193" in params["q"] + assert '"AAPL"' in params["q"] + + def test_custom_query_override(self): + _, params, _ = self.adapter._build_request("AAPL", {"query": "apple AND revenue"}) + assert params["q"] == "apple AND revenue" + + def test_trailing_slash_stripped(self): + adapter = SECEdgarAdapter(base_url="https://efts.sec.gov/") + url, _, _ = adapter._build_request("AAPL", {}) + assert "//LATEST" not in url + + def test_no_date_params_when_absent(self): + _, params, _ = self.adapter._build_request("AAPL", {}) + assert "dateRange" not in params + assert "startdt" not in params + assert "enddt" not in params + + +class TestSECEdgarExtractItems: + def setup_method(self): + self.adapter = SECEdgarAdapter() + + def test_extract_filings(self): + items = self.adapter._extract_items(EFTS_RESPONSE) + assert len(items) == 3 + assert items[0]["_id"] == "0001234567-26-000001" + assert items[0]["_source"]["form_type"] == "8-K" + + def test_extract_empty_results(self): + items = self.adapter._extract_items(EMPTY_EFTS_RESPONSE) + assert items == [] + + def test_extract_missing_hits_key(self): + items = self.adapter._extract_items({"status": "OK"}) + assert items == [] + + def test_extract_non_dict_hits(self): + items = self.adapter._extract_items({"hits": "unexpected"}) + assert items == [] + + def test_extract_non_list_inner_hits(self): + items = 
self.adapter._extract_items({"hits": {"hits": "bad"}}) + assert items == [] + + +class TestSECEdgarTotalHits: + def setup_method(self): + self.adapter = SECEdgarAdapter() + + def test_total_hits_dict(self): + assert self.adapter._total_hits(EFTS_RESPONSE) == 3 + + def test_total_hits_int(self): + data = {"hits": {"total": 5, "hits": []}} + assert self.adapter._total_hits(data) == 5 + + def test_total_hits_missing(self): + assert self.adapter._total_hits({}) == 0 + + def test_total_hits_non_dict_hits(self): + assert self.adapter._total_hits({"hits": "bad"}) == 0 + + +class TestSECEdgarErrorResult: + def test_error_result_fields(self): + adapter = SECEdgarAdapter() + result = adapter._error_result("AAPL", "rate limited", 150.0, http_status=429, raw=b"slow down") + assert not result.ok + assert result.error == "rate limited" + assert result.http_status == 429 + assert result.response_time_ms == 150.0 + assert result.raw_payload == b"slow down" + assert result.metadata["provider"] == "sec_edgar" + assert result.source_type == "filings_api" + + def test_error_result_defaults(self): + adapter = SECEdgarAdapter() + result = adapter._error_result("MSFT", "timeout", 200.0) + assert result.http_status is None + assert result.raw_payload == b"" + assert result.ticker == "MSFT" diff --git a/tests/test_html_parser.py b/tests/test_html_parser.py new file mode 100644 index 0000000..01c86ca --- /dev/null +++ b/tests/test_html_parser.py @@ -0,0 +1,582 @@ +"""Tests for the HTML-to-text parsing pipeline. + +Validates body extraction, metadata extraction, boilerplate removal, +quality scoring, link extraction, document type inference, and company +mention detection. 
+ +Requirements: 4.1, 4.2, 4.3 +""" +from services.parser.html_parser import ( + CompanyMention, + ParsedDocument, + QualitySignals, + _block_score, + _collapse_whitespace, + _detect_repeated_blocks, + _link_density, + _remove_short_orphan_lines, + _text_density, + detect_company_mentions, + extract_body_text, + extract_metadata, + extract_outbound_links, + infer_document_type, + parse_html, + score_parse_quality, + score_quality, +) + +RICH_HTML = """<!DOCTYPE html> +<html lang="en"> +<head> + <title>Apple Q2 Earnings Beat Expectations + + + + + + + + + + +
+

Apple Q2 Earnings Beat Expectations

+

Apple Inc. reported quarterly revenue of $95 billion, exceeding analyst estimates. + The company saw strong growth in its services division and iPhone sales across all + major markets worldwide. Revenue from the App Store and iCloud subscriptions + continued to climb, contributing significantly to the overall results.

+

CEO Tim Cook highlighted the company's commitment to innovation and expanding + its ecosystem. The services segment alone generated over $20 billion in revenue, + marking a new quarterly record for the division.

+ External analysis + Related article +
+
Copyright 2026 TechNews. All rights reserved. Privacy policy applies.
+ + + +""" + +MINIMAL_HTML = "

Short.

" + +BOILERPLATE_HTML = """ + +
+

The actual article content is here with enough words to pass quality checks. + This paragraph discusses important market developments and financial results + that are relevant to investors and analysts tracking the technology sector.

+
+ + +
Copyright © 2026. All rights reserved. Terms of service apply.
+""" + + +class TestExtractBodyText: + def test_extracts_article_content(self): + text = extract_body_text(RICH_HTML) + assert "Apple Inc. reported quarterly revenue" in text + assert "strong growth" in text + + def test_strips_nav_footer_sidebar(self): + text = extract_body_text(RICH_HTML) + assert "Navigation links here" not in text + assert "Sidebar content" not in text + + def test_strips_boilerplate_text(self): + text = extract_body_text(BOILERPLATE_HTML) + assert "Subscribe to our newsletter" not in text + assert "Copyright ©" not in text + + def test_finds_article_body_class(self): + text = extract_body_text(BOILERPLATE_HTML) + assert "actual article content" in text + + def test_minimal_html_returns_text(self): + text = extract_body_text(MINIMAL_HTML) + assert "Short." in text + + def test_strips_script_and_style(self): + html = "

Real content here

" + text = extract_body_text(html) + assert "alert" not in text + assert "color:red" not in text + assert "Real content here" in text + + def test_empty_html(self): + text = extract_body_text("") + assert text == "" + + +class TestExtractMetadata: + def test_extracts_title(self): + meta = extract_metadata(RICH_HTML, "https://technews.example.com/article") + assert meta["title"] == "Apple Q2 Earnings Beat" + + def test_extracts_author(self): + meta = extract_metadata(RICH_HTML, "https://technews.example.com/article") + assert meta["author"] == "Jane Reporter" + + def test_extracts_publisher(self): + meta = extract_metadata(RICH_HTML, "https://technews.example.com/article") + assert meta["publisher"] == "TechNews" + + def test_extracts_published_at(self): + meta = extract_metadata(RICH_HTML, "https://technews.example.com/article") + assert meta["published_at"] == "2026-04-10T14:00:00Z" + + def test_extracts_canonical_url(self): + meta = extract_metadata(RICH_HTML, "https://technews.example.com/article") + assert meta["canonical_url"] == "https://technews.example.com/apple-q2-earnings" + + def test_extracts_language(self): + meta = extract_metadata(RICH_HTML, "https://technews.example.com/article") + assert meta["language"] == "en" + + def test_extracts_keywords(self): + meta = extract_metadata(RICH_HTML, "https://technews.example.com/article") + assert meta["tags"] is not None + assert "apple" in str(meta["tags"]) + + def test_fallback_publisher_from_hostname(self): + meta = extract_metadata(MINIMAL_HTML, "https://example.com/page") + assert meta["publisher"] == "example.com" + + def test_no_url_publisher_empty(self): + meta = extract_metadata(MINIMAL_HTML, "") + assert meta["publisher"] == "" + + +class TestExtractOutboundLinks: + def test_finds_external_links(self): + links = extract_outbound_links(RICH_HTML, "https://technews.example.com/article") + assert "https://other-site.com/analysis" in links + + def test_excludes_same_host_links(self): + links = 
extract_outbound_links(RICH_HTML, "https://technews.example.com/article") + assert all("technews.example.com" not in link for link in links) + + def test_deduplicates_links(self): + html = '12' + links = extract_outbound_links(html, "https://example.com") + assert links.count("https://ext.com/a") == 1 + + def test_ignores_fragment_and_javascript(self): + html = 'topjs' + links = extract_outbound_links(html, "https://example.com") + assert links == [] + + +class TestScoreQuality: + def test_very_short_text_low(self): + score, conf = score_quality("hello world") + assert score < 0.5 + # With default body_found=True, very short text lands in medium + assert conf in ("low", "medium") + + def test_medium_text(self): + words = [f"word{i}" for i in range(100)] + text = " ".join(words) + "." + score, conf = score_quality(text) + # 100 diverse words with sentence structure scores well + assert conf in ("medium", "high") + + def test_long_diverse_text_high(self): + words = [f"word{i}" for i in range(300)] + text = ". ".join(" ".join(words[i:i+10]) for i in range(0, 300, 10)) + "." 
+ score, conf = score_quality(text) + assert conf == "high" + assert score >= 0.65 + + def test_empty_text_low(self): + score, conf = score_quality("") + assert conf == "low" + assert score < 0.35 + + +class TestScoreParseQuality: + """Tests for the multi-signal quality scoring function.""" + + def test_returns_four_tuple(self): + score, conf, signals, warnings = score_parse_quality("hello world") + assert isinstance(score, float) + assert conf in ("low", "medium", "high") + assert isinstance(signals, QualitySignals) + assert isinstance(warnings, list) + + def test_empty_text_is_low(self): + score, conf, signals, warnings = score_parse_quality("") + assert conf == "low" + assert "very_short_text" in warnings + + def test_short_text_warns(self): + text = " ".join(["word"] * 30) + _score, _conf, _signals, warnings = score_parse_quality(text) + assert "short_text" in warnings + + def test_body_not_found_warns(self): + text = " ".join([f"word{i}" for i in range(100)]) + "." + _score, _conf, signals, warnings = score_parse_quality(text, body_found=False) + assert "no_article_body_found" in warnings + assert signals.body_found_signal < 0.5 + + def test_metadata_boosts_score(self): + text = ". ".join(" ".join(f"word{i}" for i in range(j, j+10)) for j in range(0, 200, 10)) + "." + score_no_meta, _, _, _ = score_parse_quality(text) + score_with_meta, _, _, _ = score_parse_quality( + text, has_title=True, has_author=True, has_publisher=True, has_published_at=True, + ) + assert score_with_meta > score_no_meta + + def test_signals_as_dict(self): + _, _, signals, _ = score_parse_quality("hello world") + d = signals.as_dict() + assert "word_count" in d + assert "diversity" in d + assert "body_found" in d + + def test_well_structured_article_scores_high(self): + paragraphs = [] + for i in range(5): + sentences = ". 
".join(f"Sentence {j} of paragraph {i} with diverse vocabulary" for j in range(4)) + paragraphs.append(sentences + ".") + text = "\n\n".join(paragraphs) + score, conf, signals, warnings = score_parse_quality( + text, body_found=True, has_title=True, has_author=True, + has_publisher=True, has_published_at=True, + ) + assert conf == "high" + assert score >= 0.7 + assert signals.paragraph_signal == 1.0 + assert signals.body_found_signal == 1.0 + + +class TestInferDocumentType: + def test_filing_from_url(self): + assert infer_document_type("", "https://sec.gov/filing/10-k") == "filing" + + def test_transcript_from_url(self): + assert infer_document_type("", "https://example.com/earnings-call-transcript") == "transcript" + + def test_press_release_from_url(self): + assert infer_document_type("", "https://example.com/press-release/q2") == "press_release" + + def test_default_article(self): + assert infer_document_type("", "https://example.com/news/story") == "article" + + +class TestDetectCompanyMentions: + def test_detects_ticker(self): + aliases = [{"company_id": "1", "alias": "AAPL", "alias_type": "ticker", "ticker": "AAPL"}] + mentions = detect_company_mentions("Shares of AAPL rose 5% today", aliases) + assert len(mentions) == 1 + assert mentions[0]["ticker"] == "AAPL" + assert mentions[0]["confidence"] == 0.9 # ticker confidence + + def test_detects_company_name(self): + aliases = [{"company_id": "1", "alias": "Apple Inc.", "alias_type": "legal_name", "ticker": "AAPL"}] + mentions = detect_company_mentions("Apple Inc. 
reported strong earnings", aliases) + assert len(mentions) == 1 + assert mentions[0]["confidence"] == 0.85 # legal_name confidence + + def test_no_false_positive_short_ticker(self): + aliases = [{"company_id": "1", "alias": "A", "alias_type": "ticker", "ticker": "A"}] + mentions = detect_company_mentions("This is a sentence about nothing", aliases) + assert len(mentions) == 0 + + def test_short_ticker_case_sensitive(self): + aliases = [{"company_id": "1", "alias": "AI", "alias_type": "ticker", "ticker": "AI"}] + # "AI" as a word should match case-sensitively + mentions = detect_company_mentions("The AI revolution is here", aliases) + assert len(mentions) == 1 + # Lowercase "ai" should not match + mentions2 = detect_company_mentions("the ai revolution is here", aliases) + assert len(mentions2) == 0 + + def test_deduplicates_by_company(self): + aliases = [ + {"company_id": "1", "alias": "AAPL", "alias_type": "ticker", "ticker": "AAPL"}, + {"company_id": "1", "alias": "Apple Inc.", "alias_type": "legal_name", "ticker": "AAPL"}, + ] + mentions = detect_company_mentions("AAPL Apple Inc. reported earnings", aliases) + assert len(mentions) == 1 + # Should keep the higher confidence (ticker=0.9 > legal_name=0.85) + assert mentions[0]["confidence"] == 0.9 + + def test_match_count_accumulated(self): + aliases = [{"company_id": "1", "alias": "AAPL", "alias_type": "ticker", "ticker": "AAPL"}] + mentions = detect_company_mentions("AAPL beat estimates. 
AAPL shares rose.", aliases) + assert len(mentions) == 1 + assert mentions[0]["match_count"] == 2 + + def test_multiple_companies(self): + aliases = [ + {"company_id": "1", "alias": "AAPL", "alias_type": "ticker", "ticker": "AAPL"}, + {"company_id": "2", "alias": "MSFT", "alias_type": "ticker", "ticker": "MSFT"}, + ] + mentions = detect_company_mentions("AAPL and MSFT both reported earnings", aliases) + assert len(mentions) == 2 + tickers = {m["ticker"] for m in mentions} + assert tickers == {"AAPL", "MSFT"} + + def test_brand_alias(self): + aliases = [{"company_id": "1", "alias": "iPhone", "alias_type": "brand", "ticker": "AAPL"}] + mentions = detect_company_mentions("The new iPhone sales exceeded expectations", aliases) + assert len(mentions) == 1 + assert mentions[0]["confidence"] == 0.6 # brand confidence + + def test_empty_text(self): + aliases = [{"company_id": "1", "alias": "AAPL", "alias_type": "ticker", "ticker": "AAPL"}] + assert detect_company_mentions("", aliases) == [] + + def test_empty_aliases(self): + assert detect_company_mentions("Some text about stocks", []) == [] + + def test_case_insensitive_name_match(self): + aliases = [{"company_id": "1", "alias": "Apple Inc.", "alias_type": "legal_name", "ticker": "AAPL"}] + mentions = detect_company_mentions("APPLE INC. reported earnings", aliases) + assert len(mentions) == 1 + + +class TestParseHtml: + def test_returns_parsed_document(self): + result = parse_html(RICH_HTML, "https://technews.example.com/article") + assert isinstance(result, ParsedDocument) + + def test_body_text_populated(self): + result = parse_html(RICH_HTML, "https://technews.example.com/article") + assert "Apple Inc." 
in result.body_text + assert result.word_count > 0 + + def test_metadata_populated(self): + result = parse_html(RICH_HTML, "https://technews.example.com/article") + assert result.title == "Apple Q2 Earnings Beat" + assert result.author == "Jane Reporter" + assert result.publisher == "TechNews" + + def test_quality_scoring(self): + result = parse_html(RICH_HTML, "https://technews.example.com/article") + assert result.quality_score > 0 + assert result.confidence in ("low", "medium", "high") + + def test_quality_signals_populated(self): + result = parse_html(RICH_HTML, "https://technews.example.com/article") + assert isinstance(result.quality_signals, QualitySignals) + assert result.quality_signals.body_found_signal == 1.0 + assert result.quality_signals.metadata_signal > 0 + + def test_low_quality_flag_on_minimal(self): + result = parse_html(MINIMAL_HTML, "") + assert result.low_quality_flag is True + assert result.confidence == "low" + + def test_rich_html_not_low_quality(self): + result = parse_html(RICH_HTML, "https://technews.example.com/article") + assert result.low_quality_flag is False + + def test_quality_warnings_list(self): + result = parse_html(MINIMAL_HTML, "") + assert isinstance(result.quality_warnings, list) + + def test_tags_extracted(self): + result = parse_html(RICH_HTML, "https://technews.example.com/article") + assert "apple" in result.tags + + def test_document_type_inferred(self): + result = parse_html(RICH_HTML, "https://technews.example.com/article") + assert result.document_type == "article" + + def test_outbound_links(self): + result = parse_html(RICH_HTML, "https://technews.example.com/article") + assert any("other-site.com" in link for link in result.outbound_links) + + def test_mentioned_companies_with_aliases(self): + aliases = [ + {"company_id": "1", "alias": "AAPL", "alias_type": "ticker", "ticker": "AAPL"}, + {"company_id": "1", "alias": "Apple Inc.", "alias_type": "legal_name", "ticker": "AAPL"}, + ] + result = parse_html(RICH_HTML, 
"https://technews.example.com/article", aliases=aliases) + assert len(result.mentioned_companies) == 1 + assert result.mentioned_companies[0].ticker == "AAPL" + assert isinstance(result.mentioned_companies[0], CompanyMention) + + def test_no_mentions_without_aliases(self): + result = parse_html(RICH_HTML, "https://technews.example.com/article") + assert result.mentioned_companies == [] + +# --- HTML fixtures for boilerplate reduction tests --- + +NO_SEMANTIC_HTML = """ + +
+

The Federal Reserve announced a 25 basis point rate cut on Wednesday, + surprising markets that had expected rates to remain unchanged. Bond yields + fell sharply across the curve, with the 10-year Treasury dropping to 3.8 percent. + Equity markets rallied on the news, with the S&P 500 gaining 1.2 percent by close.

+

Analysts noted that the decision reflects growing concerns about slowing economic + growth and weakening labor market data. Several Fed governors had signaled openness + to easing in recent speeches, but the timing caught many off guard.

+

Market participants are now pricing in additional cuts at the next two meetings, + with futures indicating a 70 percent probability of another reduction in September.

+
+ +""" + +HEAVY_BOILERPLATE_HTML = """ + + + +
+

Tesla reported record deliveries in Q1 2026, shipping over 500,000 vehicles + globally. The company attributed the strong performance to expanded production + capacity at its Berlin and Austin gigafactories, as well as growing demand for + the refreshed Model Y across European and Asian markets.

+

Revenue for the quarter came in at $28 billion, beating consensus estimates + by roughly 4 percent. Automotive gross margins improved to 19.5 percent, + reversing a trend of compression seen throughout 2025.

+
+ + +
Sponsored content here
+
Copyright © 2026 FinanceDaily. All rights reserved. Terms of service. Privacy policy.
+""" + +REPEATED_BLOCKS_HTML = """ +
+

Apple announced a new partnership with Samsung to develop next-generation + display technology for future iPhone models. The collaboration is expected to + yield OLED panels with improved brightness and energy efficiency.

+

This is a developing story. Check back for updates as more information becomes available.

+

Industry analysts view the partnership as a strategic move to secure supply + chain advantages ahead of the 2027 product cycle. Display costs represent a + significant portion of iPhone bill of materials.

+

This is a developing story. Check back for updates as more information becomes available.

+
+""" + + +class TestTextDensityScoring: + """Tests for text-density-based block scoring heuristics.""" + + def test_content_rich_div_has_high_density(self): + from bs4 import BeautifulSoup + html = "

This is a substantial paragraph with real content about markets.

" + soup = BeautifulSoup(html, "html.parser") + tag = soup.find("div") + assert _text_density(tag) > _MIN_TEXT_DENSITY + + def test_link_heavy_div_has_high_link_density(self): + from bs4 import BeautifulSoup + html = '
Link one Link two Link three
' + soup = BeautifulSoup(html, "html.parser") + tag = soup.find("div") + assert _link_density(tag) > 0.8 + + def test_article_div_has_low_link_density(self): + from bs4 import BeautifulSoup + html = "

A long paragraph of article text that discusses important financial results and market movements in detail.

" + soup = BeautifulSoup(html, "html.parser") + tag = soup.find("div") + assert _link_density(tag) < 0.1 + + def test_block_score_prefers_content_over_nav(self): + from bs4 import BeautifulSoup + content_html = "
" + "

Substantial article paragraph with real content about markets and earnings.

" * 3 + "
" + nav_html = '
LinkLinkLinkLink
' + soup_c = BeautifulSoup(content_html, "html.parser") + soup_n = BeautifulSoup(nav_html, "html.parser") + assert _block_score(soup_c.find("div")) > _block_score(soup_n.find("div")) + + +class TestBoilerplateReduction: + """Tests for enhanced boilerplate reduction pipeline.""" + + def test_strips_cookie_banner(self): + text = extract_body_text(HEAVY_BOILERPLATE_HTML) + assert "cookie" not in text.lower() + + def test_strips_signup_form(self): + text = extract_body_text(HEAVY_BOILERPLATE_HTML) + assert "Sign up for free" not in text + + def test_strips_social_share(self): + text = extract_body_text(HEAVY_BOILERPLATE_HTML) + assert "Share this article" not in text + + def test_strips_ad_container(self): + text = extract_body_text(HEAVY_BOILERPLATE_HTML) + assert "Sponsored content" not in text + + def test_strips_related_posts(self): + text = extract_body_text(HEAVY_BOILERPLATE_HTML) + assert "You may also like" not in text + + def test_preserves_article_content(self): + text = extract_body_text(HEAVY_BOILERPLATE_HTML) + assert "Tesla reported record deliveries" in text + assert "Revenue for the quarter" in text + + def test_strips_copyright_footer(self): + text = extract_body_text(HEAVY_BOILERPLATE_HTML) + assert "Copyright ©" not in text + + +class TestBodyExtractionFallback: + """Tests for text-density fallback when no semantic selector matches.""" + + def test_finds_content_without_article_tag(self): + text = extract_body_text(NO_SEMANTIC_HTML) + assert "Federal Reserve announced" in text + assert "25 basis point rate cut" in text + + def test_prefers_content_over_nav_links(self): + text = extract_body_text(NO_SEMANTIC_HTML) + # The nav-like link list should not dominate the output + assert "Story 1" not in text or "Federal Reserve" in text + + +class TestRepeatedBlockDetection: + """Tests for repeated/template text detection.""" + + def test_collapses_repeated_template_text(self): + text = extract_body_text(REPEATED_BLOCKS_HTML) + count = text.count("This is a 
developing story") + assert count <= 1 + + def test_preserves_unique_content(self): + text = extract_body_text(REPEATED_BLOCKS_HTML) + assert "Apple announced a new partnership" in text + assert "Industry analysts view" in text + + +class TestOrphanLineRemoval: + """Tests for short orphan line removal.""" + + def test_removes_short_fragments(self): + text = _remove_short_orphan_lines("OK\nThis is a real sentence about markets.\nHi") + assert "OK" not in text + assert "Hi" not in text + assert "real sentence" in text + + def test_keeps_short_lines_with_punctuation(self): + text = _remove_short_orphan_lines("Breaking news.\nDetails follow in the article.") + assert "Breaking news." in text + + +class TestCollapseWhitespace: + """Tests for whitespace collapsing.""" + + def test_collapses_multiple_blank_lines(self): + text = _collapse_whitespace("Line one.\n\n\n\nLine two.") + assert "\n\n\n" not in text + assert "Line one." in text + assert "Line two." in text + + def test_strips_leading_trailing(self): + text = _collapse_whitespace("\n\n Hello world. \n\n") + assert text == "Hello world." 
+ + +# Import the constant for use in density tests +from services.parser.html_parser import _MIN_TEXT_DENSITY diff --git a/tests/test_iceberg.py b/tests/test_iceberg.py new file mode 100644 index 0000000..a662416 --- /dev/null +++ b/tests/test_iceberg.py @@ -0,0 +1,161 @@ +"""Tests for Iceberg table creation and metadata management.""" +from datetime import date + +import pyarrow as pa + +from services.lake_publisher.iceberg import ( + ICEBERG_CATALOG, + ICEBERG_SCHEMA, + TABLE_SCHEMAS, + IcebergManager, + IcebergTableDef, + _arrow_type_to_trino, + get_all_table_defs, + get_table_def, +) +from services.lake_publisher.partitions import TABLE_PARTITIONS, PartitionSpec + + +# --------------------------------------------------------------------------- +# _arrow_type_to_trino +# --------------------------------------------------------------------------- + + +def test_arrow_to_trino_string(): + assert _arrow_type_to_trino(pa.string()) == "VARCHAR" + + +def test_arrow_to_trino_float64(): + assert _arrow_type_to_trino(pa.float64()) == "DOUBLE" + + +def test_arrow_to_trino_int64(): + assert _arrow_type_to_trino(pa.int64()) == "BIGINT" + + +def test_arrow_to_trino_int32(): + assert _arrow_type_to_trino(pa.int32()) == "INTEGER" + + +def test_arrow_to_trino_bool(): + assert _arrow_type_to_trino(pa.bool_()) == "BOOLEAN" + + +def test_arrow_to_trino_date32(): + assert _arrow_type_to_trino(pa.date32()) == "DATE" + + +def test_arrow_to_trino_timestamp_utc(): + assert _arrow_type_to_trino(pa.timestamp("us", tz="UTC")) == "TIMESTAMP(6) WITH TIME ZONE" + + +def test_arrow_to_trino_timestamp_no_tz(): + assert _arrow_type_to_trino(pa.timestamp("us")) == "TIMESTAMP(6)" + + +# --------------------------------------------------------------------------- +# TABLE_SCHEMAS registry +# --------------------------------------------------------------------------- + + +def test_table_schemas_covers_all_partitions(): + """Every table in TABLE_PARTITIONS should have a corresponding PyArrow 
schema.""" + for table_name in TABLE_PARTITIONS: + assert table_name in TABLE_SCHEMAS, f"Missing schema for {table_name}" + + +# --------------------------------------------------------------------------- +# IcebergTableDef +# --------------------------------------------------------------------------- + + +def test_table_def_qualified_name(): + td = get_table_def("trade_signals") + assert td.qualified_name == f"{ICEBERG_CATALOG}.{ICEBERG_SCHEMA}.trade_signals" + + +def test_table_def_location(): + td = get_table_def("trade_signals") + assert td.location == "s3a://stonks-lakehouse/warehouse/trade_signals/" + + +def test_table_def_column_defs(): + td = get_table_def("trade_signals") + cols = td.column_defs_sql() + col_names = [c.strip().split()[0] for c in cols] + assert "signal_id" in col_names + assert "ticker" in col_names + assert "dt" in col_names + + +def test_table_def_partition_keys_dt_only(): + td = get_table_def("trade_signals") + part = td.partition_keys_sql() + assert "'dt'" in part + assert "model_version" not in part + + +def test_table_def_partition_keys_with_extra(): + td = get_table_def("document_extractions") + part = td.partition_keys_sql() + assert "'dt'" in part + assert "'model_version'" in part + + +def test_create_table_sql_structure(): + td = get_table_def("market_bars") + sql = td.create_table_sql() + assert "CREATE TABLE IF NOT EXISTS" in sql + assert f"{ICEBERG_CATALOG}.{ICEBERG_SCHEMA}.market_bars" in sql + assert "format = 'PARQUET'" in sql + assert "s3a://stonks-lakehouse/warehouse/market_bars/" in sql + assert "partitioning" in sql + + +def test_create_table_sql_columns_match_schema(): + td = get_table_def("market_bars") + sql = td.create_table_sql() + # All columns from the PyArrow schema should appear + for i in range(len(td.schema)): + col_name = td.schema.field(i).name + assert col_name in sql, f"Column {col_name} missing from DDL" + + +# --------------------------------------------------------------------------- +# 
get_all_table_defs / get_table_def +# --------------------------------------------------------------------------- + + +def test_get_all_table_defs_count(): + defs = get_all_table_defs() + assert len(defs) == len(TABLE_PARTITIONS) + + +def test_get_table_def_unknown_raises(): + try: + get_table_def("nonexistent_table") + assert False, "Should have raised ValueError" + except ValueError: + pass + + +def test_get_all_table_defs_all_generate_valid_sql(): + """Every table def should produce syntactically reasonable DDL.""" + for td in get_all_table_defs(): + sql = td.create_table_sql() + assert "CREATE TABLE IF NOT EXISTS" in sql + assert td.table_name in sql + assert "PARQUET" in sql + + +# --------------------------------------------------------------------------- +# IcebergManager (unit tests with no Trino connection) +# --------------------------------------------------------------------------- + + +def test_iceberg_manager_defaults(): + mgr = IcebergManager() + assert mgr.host == "localhost" + assert mgr.port == 8080 + assert mgr.catalog == ICEBERG_CATALOG + assert mgr.schema == ICEBERG_SCHEMA diff --git a/tests/test_integration_ingest_to_recommendation.py b/tests/test_integration_ingest_to_recommendation.py new file mode 100644 index 0000000..32c8f32 --- /dev/null +++ b/tests/test_integration_ingest_to_recommendation.py @@ -0,0 +1,648 @@ +"""Integration tests for the full ingest-to-recommendation flow. + +Exercises the pipeline end-to-end through all stages: + Ingestion → Parsing → Extraction → Aggregation → Recommendation + +Each stage uses the real logic functions from the service modules. +External infrastructure (PostgreSQL, MinIO, Redis, Ollama) is replaced +with lightweight fakes that preserve the data contracts between stages. 
+ +Requirements: 3.1-3.4, 4.1-4.3, 5.1-5.5, 6.1-6.5, 7.1-7.4 +""" +from __future__ import annotations + +import json +import uuid +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from services.aggregation.worker import ( + ImpactRow, + assemble_trend_with_evidence, + build_weighted_signals, +) +from services.extractor.client import ExtractionAttempt, ExtractionResponse +from services.extractor.schemas import ExtractionResult, ValidationReport, validate_extraction +from services.extractor.worker import persist_extraction +from services.parser.html_parser import ParsedDocument, detect_company_mentions, parse_html +from services.parser.worker import build_parser_output_json +from services.recommendation.eligibility import EligibilityConfig, evaluate_eligibility +from services.recommendation.suppression import ( + DataQualityContext, + SuppressionConfig, + evaluate_suppression, +) +from services.recommendation.worker import ( + build_recommendation, + build_thesis, + classify_risk, +) +from services.shared.schemas import ( + ActionType, + RecommendationMode, + TrendDirection, + TrendWindow, +) + +NOW = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + +# --------------------------------------------------------------------------- +# Shared test fixtures +# --------------------------------------------------------------------------- + +SAMPLE_HTML = """ + +Apple Reports Record Q2 Earnings + + + + + +
+

Apple Reports Record Q2 Earnings

+

Apple Inc. (AAPL) reported record quarterly revenue of $120 billion, +beating analyst expectations by 8%. CEO Tim Cook cited strong iPhone and +services growth as key drivers.

+

The company also announced a $100 billion share buyback program, +signaling confidence in future cash flows. Analysts at Goldman Sachs +raised their price target to $250.

+

However, regulatory scrutiny in the EU remains a risk factor, +with potential fines related to the Digital Markets Act.

+
+
Copyright 2026
+ + +""" + +SAMPLE_EXTRACTION_JSON = { + "summary": "Apple reported record Q2 revenue of $120B, beating expectations by 8%. " + "Announced $100B buyback. EU regulatory risk remains.", + "companies": [ + { + "ticker": "AAPL", + "company_name": "Apple Inc.", + "relevance": 0.95, + "sentiment": "positive", + "impact_score": 0.75, + "impact_horizon": "1d_30d", + "catalyst_type": "earnings", + "key_facts": [ + "Record quarterly revenue of $120 billion", + "$100 billion share buyback announced", + "Goldman Sachs raised price target to $250", + ], + "risks": ["EU regulatory scrutiny under Digital Markets Act"], + "evidence_spans": [ + "Apple Inc. (AAPL) reported record quarterly revenue of $120 billion", + "beating analyst expectations by 8%", + "announced a $100 billion share buyback program", + ], + } + ], + "macro_themes": ["consumer_tech", "buybacks"], + "novelty_score": 0.7, + "confidence": 0.88, + "extraction_warnings": [], +} + +COMPANY_ALIASES = [ + {"company_id": "comp-1", "alias": "AAPL", "alias_type": "ticker", "ticker": "AAPL"}, + {"company_id": "comp-1", "alias": "Apple Inc.", "alias_type": "legal_name", "ticker": "AAPL"}, +] + + +# --------------------------------------------------------------------------- +# Stage 1: Parsing +# --------------------------------------------------------------------------- + + +class TestParsingStage: + """Verify the HTML parsing pipeline produces structured output.""" + + def test_parse_html_extracts_body_text(self): + parsed = parse_html(SAMPLE_HTML, "https://example.com/apple-earnings") + assert parsed.body_text is not None + assert "record quarterly revenue" in parsed.body_text.lower() + # Boilerplate should be stripped + assert "Site Navigation" not in parsed.body_text + assert "Copyright" not in parsed.body_text + + def test_parse_html_extracts_metadata(self): + parsed = parse_html(SAMPLE_HTML, "https://example.com/apple-earnings") + assert parsed.title == "Apple Reports Record Q2 Earnings" + assert parsed.quality_score > 
0.0 + assert parsed.confidence != "low" + + def test_detect_company_mentions_finds_aapl(self): + parsed = parse_html(SAMPLE_HTML, "https://example.com/apple-earnings") + mentions = detect_company_mentions(parsed.body_text, COMPANY_ALIASES) + tickers_found = {m["ticker"] for m in mentions} + assert "AAPL" in tickers_found + + def test_parser_output_json_structure(self): + parsed = parse_html(SAMPLE_HTML, "https://example.com/apple-earnings") + mentions = detect_company_mentions(parsed.body_text, COMPANY_ALIASES) + output = build_parser_output_json(parsed, mentions) + assert "quality_score" in output + assert "mentioned_companies" in output + assert isinstance(output["mentioned_companies"], list) + assert output["title"] == "Apple Reports Record Q2 Earnings" + + +# --------------------------------------------------------------------------- +# Stage 2: Extraction validation +# --------------------------------------------------------------------------- + + +class TestExtractionStage: + """Verify extraction schema validation and result construction.""" + + def test_validate_extraction_accepts_valid_json(self): + report = validate_extraction(SAMPLE_EXTRACTION_JSON) + assert report.valid + assert report.parsed is not None + assert report.parsed.companies[0].ticker == "AAPL" + + def test_validate_extraction_rejects_invalid_json(self): + report = validate_extraction("not json at all") + assert not report.valid + assert len(report.errors) > 0 + + def test_validate_extraction_rejects_bad_schema(self): + bad = {"summary": "test"} # missing required fields + report = validate_extraction(bad) + assert not report.valid + + def test_extraction_result_matches_intelligence_schema(self): + result = ExtractionResult.model_validate(SAMPLE_EXTRACTION_JSON) + assert result.confidence == 0.88 + assert len(result.companies) == 1 + assert result.companies[0].catalyst_type.value == "earnings" + assert result.novelty_score == 0.7 + + def 
test_validate_extraction_with_document_text_checks_evidence(self): + """Evidence grounding check should warn if spans not found.""" + report = validate_extraction( + SAMPLE_EXTRACTION_JSON, + document_text="Completely unrelated text about weather.", + ) + # Should still be valid (evidence grounding is a warning, not error) + assert report.valid + assert any("evidence_span_not_found" in w for w in report.warnings) + + +# --------------------------------------------------------------------------- +# Stage 3: Extraction persistence (mocked infra) +# --------------------------------------------------------------------------- + + +class TestExtractionPersistence: + """Verify extraction artifacts are persisted correctly.""" + + @pytest.mark.asyncio + async def test_persist_successful_extraction_creates_all_artifacts(self): + result_obj = ExtractionResult.model_validate(SAMPLE_EXTRACTION_JSON) + validation = ValidationReport(valid=True, errors=[], warnings=[], parsed=result_obj) + attempt = ExtractionAttempt( + raw_output=json.dumps(SAMPLE_EXTRACTION_JSON), + validation=validation, + error=None, + duration_ms=450, + model="test-model", + ) + response = ExtractionResponse( + success=True, + result=result_obj, + attempts=[attempt], + prompt_metadata={"prompt_version": "document-intel-v2", "schema_version": "2.0.0"}, + model="test-model", + total_duration_ms=450, + ) + + pool = AsyncMock() + pool.fetchval = AsyncMock(side_effect=["intel-1", "impact-1", "metrics-1"]) + pool.execute = AsyncMock() + minio = MagicMock() + + persist_result = await persist_extraction( + pool=pool, + minio_client=minio, + document_id=str(uuid.uuid4()), + ticker="AAPL", + extraction_response=response, + company_id_map={"AAPL": "comp-1"}, + source_credibility=0.8, + timestamp=NOW, + ) + + assert persist_result.success + assert persist_result.intelligence_id == "intel-1" + assert persist_result.impact_ids == ["impact-1"] + # 4 MinIO uploads: prompt, raw_output, validation, intelligence + assert 
minio.put_object.call_count == 4 + + +# --------------------------------------------------------------------------- +# Stage 4: Aggregation +# --------------------------------------------------------------------------- + + +class TestAggregationStage: + """Verify trend summary assembly from document impact records.""" + + def _make_impacts_from_extraction(self) -> list[ImpactRow]: + """Build ImpactRows that mirror what the extraction stage would produce.""" + return [ + ImpactRow( + document_id="doc-1", + confidence=0.88, + novelty_score=0.7, + source_credibility=0.8, + sentiment="positive", + impact_score=0.75, + catalyst_type="earnings", + key_facts=["Record revenue $120B", "$100B buyback"], + risks=["EU regulatory scrutiny"], + published_at=NOW - timedelta(hours=2), + ), + ImpactRow( + document_id="doc-2", + confidence=0.72, + novelty_score=0.5, + source_credibility=0.7, + sentiment="positive", + impact_score=0.6, + catalyst_type="rating_change", + key_facts=["Goldman raised target to $250"], + risks=[], + published_at=NOW - timedelta(hours=4), + ), + ImpactRow( + document_id="doc-3", + confidence=0.65, + novelty_score=0.4, + source_credibility=0.6, + sentiment="negative", + impact_score=0.4, + catalyst_type="legal", + key_facts=["EU DMA investigation"], + risks=["Potential fines"], + published_at=NOW - timedelta(hours=6), + ), + ] + + def test_aggregation_produces_bullish_trend(self): + impacts = self._make_impacts_from_extraction() + signals = build_weighted_signals(impacts, NOW, "7d") + assembled = assemble_trend_with_evidence( + "AAPL", "7d", signals, impacts, reference_time=NOW, + ) + summary = assembled.summary + + assert summary.entity_id == "AAPL" + assert summary.window == TrendWindow.SEVEN_DAY + # Two positive, one negative → should be bullish + assert summary.trend_direction == TrendDirection.BULLISH + assert summary.trend_strength > 0 + assert summary.confidence > 0 + assert len(summary.top_supporting_evidence) >= 1 + assert 
len(summary.top_opposing_evidence) >= 1 + assert summary.contradiction_score > 0 # has opposing signal + + def test_aggregation_evidence_rankings_are_populated(self): + impacts = self._make_impacts_from_extraction() + signals = build_weighted_signals(impacts, NOW, "7d") + assembled = assemble_trend_with_evidence( + "AAPL", "7d", signals, impacts, reference_time=NOW, + ) + + # Supporting evidence should include the positive docs + supporting_ids = {e.document_id for e in assembled.supporting_evidence} + assert "doc-1" in supporting_ids + assert "doc-2" in supporting_ids + + # Opposing evidence should include the negative doc + opposing_ids = {e.document_id for e in assembled.opposing_evidence} + assert "doc-3" in opposing_ids + + def test_aggregation_extracts_catalysts_and_risks(self): + impacts = self._make_impacts_from_extraction() + signals = build_weighted_signals(impacts, NOW, "7d") + assembled = assemble_trend_with_evidence( + "AAPL", "7d", signals, impacts, reference_time=NOW, + ) + summary = assembled.summary + + assert len(summary.dominant_catalysts) > 0 + assert "earnings" in summary.dominant_catalysts + assert len(summary.material_risks) > 0 + + +# --------------------------------------------------------------------------- +# Stage 5: Recommendation +# --------------------------------------------------------------------------- + + +class TestRecommendationStage: + """Verify recommendation generation from trend summaries.""" + + def _make_trend_from_aggregation(self): + """Build a TrendSummary that mirrors aggregation output.""" + impacts = [ + ImpactRow( + document_id="doc-1", confidence=0.88, novelty_score=0.7, + source_credibility=0.8, sentiment="positive", impact_score=0.75, + catalyst_type="earnings", key_facts=["Record revenue"], + risks=["EU regulatory"], published_at=NOW - timedelta(hours=2), + ), + ImpactRow( + document_id="doc-2", confidence=0.72, novelty_score=0.5, + source_credibility=0.7, sentiment="positive", impact_score=0.6, + 
catalyst_type="rating_change", key_facts=["Target raised"], + risks=[], published_at=NOW - timedelta(hours=4), + ), + ImpactRow( + document_id="doc-3", confidence=0.65, novelty_score=0.4, + source_credibility=0.6, sentiment="negative", impact_score=0.4, + catalyst_type="legal", key_facts=["DMA investigation"], + risks=["Potential fines"], published_at=NOW - timedelta(hours=6), + ), + ] + signals = build_weighted_signals(impacts, NOW, "7d") + assembled = assemble_trend_with_evidence( + "AAPL", "7d", signals, impacts, reference_time=NOW, + ) + return assembled.summary + + def test_eligibility_produces_buy_for_bullish_trend(self): + summary = self._make_trend_from_aggregation() + result = evaluate_eligibility(summary) + assert result.action == ActionType.BUY + assert result.eligible + + def test_recommendation_has_thesis_and_evidence(self): + summary = self._make_trend_from_aggregation() + result = evaluate_eligibility(summary) + rec = build_recommendation(summary, result, reference_time=NOW) + + assert rec.ticker == "AAPL" + assert rec.action == ActionType.BUY + assert len(rec.thesis) > 0 + assert "[risk:" in rec.thesis + assert len(rec.evidence_refs) > 0 + assert rec.time_horizon == "swing_1d_10d" + + def test_recommendation_position_sizing_is_bounded(self): + summary = self._make_trend_from_aggregation() + result = evaluate_eligibility(summary) + rec = build_recommendation(summary, result, reference_time=NOW) + + assert 0 < rec.position_sizing.portfolio_pct <= 0.05 + assert 0 < rec.position_sizing.max_loss_pct <= 0.01 + + def test_recommendation_mode_reflects_confidence(self): + summary = self._make_trend_from_aggregation() + result = evaluate_eligibility(summary) + rec = build_recommendation(summary, result, reference_time=NOW) + + # With 3 impact records the aggregated confidence is moderate (~0.41), + # which is below the paper_confidence_threshold (0.50). 
The eligibility + # engine correctly assigns INFORMATIONAL mode for BUY actions with + # sub-threshold confidence. This validates Requirement 7.4. + if summary.confidence >= 0.50: + assert rec.mode in ( + RecommendationMode.PAPER_ELIGIBLE, + RecommendationMode.LIVE_ELIGIBLE, + ) + else: + assert rec.mode == RecommendationMode.INFORMATIONAL + + def test_suppression_blocks_low_quality_data(self): + summary = self._make_trend_from_aggregation() + low_quality_ctx = DataQualityContext( + total_documents=5, + valid_documents=1, + failed_documents=4, + avg_extraction_confidence=0.2, + newest_evidence_at=NOW - timedelta(days=14), + source_types=set(), + ) + suppression = evaluate_suppression( + summary, quality_ctx=low_quality_ctx, reference_time=NOW, + ) + assert suppression.suppressed + assert len(suppression.reasons) > 0 + + +# --------------------------------------------------------------------------- +# Full pipeline integration +# --------------------------------------------------------------------------- + + +class TestFullPipelineIntegration: + """End-to-end test wiring all stages together with real logic.""" + + def test_html_to_recommendation_pipeline(self): + """Walk a document through parse → validate extraction → aggregate → recommend.""" + + # --- Stage 1: Parse HTML --- + parsed = parse_html(SAMPLE_HTML, "https://example.com/apple-q2") + assert parsed.body_text + assert parsed.confidence != "low" + + mentions = detect_company_mentions(parsed.body_text, COMPANY_ALIASES) + assert any(m["ticker"] == "AAPL" for m in mentions) + + # --- Stage 2: Validate extraction output --- + report = validate_extraction( + SAMPLE_EXTRACTION_JSON, + document_text=parsed.body_text, + ) + assert report.valid + extraction = report.parsed + assert extraction is not None + assert extraction.companies[0].ticker == "AAPL" + + # --- Stage 3: Build impact records from extraction --- + company = extraction.companies[0] + impact = ImpactRow( + document_id="doc-pipeline-1", + 
confidence=extraction.confidence, + novelty_score=extraction.novelty_score, + source_credibility=0.8, + sentiment=company.sentiment.value, + impact_score=company.impact_score, + catalyst_type=company.catalyst_type.value, + key_facts=company.key_facts, + risks=company.risks, + published_at=NOW - timedelta(hours=1), + ) + + # Add a second supporting document for richer aggregation + impact2 = ImpactRow( + document_id="doc-pipeline-2", + confidence=0.75, + novelty_score=0.5, + source_credibility=0.7, + sentiment="positive", + impact_score=0.6, + catalyst_type="rating_change", + key_facts=["Analyst upgrade"], + risks=[], + published_at=NOW - timedelta(hours=3), + ) + + impacts = [impact, impact2] + + # --- Stage 4: Aggregate into trend summary --- + signals = build_weighted_signals(impacts, NOW, "7d") + assembled = assemble_trend_with_evidence( + "AAPL", "7d", signals, impacts, reference_time=NOW, + ) + summary = assembled.summary + + assert summary.trend_direction == TrendDirection.BULLISH + assert summary.confidence > 0 + assert len(summary.top_supporting_evidence) > 0 + + # --- Stage 5: Generate recommendation --- + eligibility = evaluate_eligibility(summary) + assert eligibility.action == ActionType.BUY + assert eligibility.eligible + + rec = build_recommendation(summary, eligibility, reference_time=NOW) + + # Final assertions: the recommendation is coherent end-to-end + assert rec.ticker == "AAPL" + assert rec.action == ActionType.BUY + assert rec.confidence == summary.confidence + assert len(rec.evidence_refs) > 0 + assert rec.thesis.startswith("[risk:") + assert "AAPL" in rec.thesis + assert "bullish" in rec.thesis + assert rec.time_horizon == "swing_1d_10d" + assert 0 < rec.position_sizing.portfolio_pct <= 0.05 + + def test_low_quality_document_is_blocked(self): + """A low-quality parse should not produce a trade-eligible recommendation.""" + + # Minimal HTML that produces a low-quality parse + bad_html = "

Ad. Subscribe now.

" + parsed = parse_html(bad_html, "https://example.com/junk") + + # Low quality parse → should not advance to extraction + # The parser worker checks confidence != "low" before enqueuing + if parsed.confidence == "low" or parsed.quality_score < 0.3: + # This is the expected path: document blocked at parse stage + return + + # If somehow it passes parsing, suppression should catch it + # Build a minimal trend with low data quality + from services.shared.schemas import TrendSummary + summary = TrendSummary( + entity_type="company", + entity_id="JUNK", + window=TrendWindow.SEVEN_DAY, + trend_direction=TrendDirection.BULLISH, + trend_strength=0.3, + confidence=0.3, + top_supporting_evidence=["doc-1"], + generated_at=NOW, + ) + suppression = evaluate_suppression(summary, reference_time=NOW) + # With only 1 evidence doc and low confidence, should be suppressed + assert suppression.suppressed + + def test_bearish_signal_produces_sell_recommendation(self): + """Negative sentiment documents should produce a SELL recommendation.""" + + impacts = [ + ImpactRow( + document_id="doc-bear-1", + confidence=0.82, + novelty_score=0.6, + source_credibility=0.8, + sentiment="negative", + impact_score=0.7, + catalyst_type="legal", + key_facts=["Major lawsuit filed"], + risks=["Potential $5B fine"], + published_at=NOW - timedelta(hours=1), + ), + ImpactRow( + document_id="doc-bear-2", + confidence=0.78, + novelty_score=0.5, + source_credibility=0.75, + sentiment="negative", + impact_score=0.65, + catalyst_type="earnings", + key_facts=["Revenue miss by 15%"], + risks=["Guidance lowered"], + published_at=NOW - timedelta(hours=3), + ), + ] + + signals = build_weighted_signals(impacts, NOW, "7d") + assembled = assemble_trend_with_evidence( + "TSLA", "7d", signals, impacts, reference_time=NOW, + ) + summary = assembled.summary + + assert summary.trend_direction == TrendDirection.BEARISH + + eligibility = evaluate_eligibility(summary) + assert eligibility.action == ActionType.SELL + + rec = 
build_recommendation(summary, eligibility, reference_time=NOW) + assert rec.ticker == "TSLA" + assert rec.action == ActionType.SELL + assert "SELL" in rec.thesis + + def test_contradictory_signals_produce_mixed_or_watch(self): + """Equal opposing signals should result in WATCH or MIXED direction.""" + + impacts = [ + ImpactRow( + document_id="doc-pos", + confidence=0.8, + novelty_score=0.5, + source_credibility=0.8, + sentiment="positive", + impact_score=0.6, + catalyst_type="earnings", + key_facts=["Beat expectations"], + risks=[], + published_at=NOW - timedelta(hours=1), + ), + ImpactRow( + document_id="doc-neg", + confidence=0.8, + novelty_score=0.5, + source_credibility=0.8, + sentiment="negative", + impact_score=0.6, + catalyst_type="legal", + key_facts=["Lawsuit filed"], + risks=["Regulatory risk"], + published_at=NOW - timedelta(hours=1), + ), + ] + + signals = build_weighted_signals(impacts, NOW, "7d") + assembled = assemble_trend_with_evidence( + "MSFT", "7d", signals, impacts, reference_time=NOW, + ) + summary = assembled.summary + + assert summary.trend_direction in (TrendDirection.MIXED, TrendDirection.NEUTRAL) + assert summary.contradiction_score > 0 + + eligibility = evaluate_eligibility(summary) + rec = build_recommendation(summary, eligibility, reference_time=NOW) + + # Contradictory signals → WATCH or HOLD, mode should be informational + assert rec.action in (ActionType.WATCH, ActionType.HOLD) + assert rec.mode == RecommendationMode.INFORMATIONAL diff --git a/tests/test_k8s_security.py b/tests/test_k8s_security.py new file mode 100644 index 0000000..e6d3692 --- /dev/null +++ b/tests/test_k8s_security.py @@ -0,0 +1,212 @@ +"""Tests for Kubernetes manifest security hardening. 
+ +Validates that all deployments in infra/k8s/ follow security best practices: +- Scoped secrets (no monolithic stonks-secrets) +- Pod security contexts (runAsNonRoot, seccompProfile) +- Container security contexts (no privilege escalation, drop ALL caps) +- automountServiceAccountToken disabled +- Broker secrets only on trading-tier pods +""" +from __future__ import annotations + +import glob +from pathlib import Path + +import yaml + + +K8S_DIR = Path("infra/k8s") + +# Services that legitimately need broker secrets +BROKER_SECRET_ALLOWED = {"broker-adapter", "risk-engine"} + +# Services that legitimately need market-data secrets +MARKET_SECRET_ALLOWED = {"ingestion-worker"} + + +def _load_deployments() -> list[tuple[str, dict]]: + """Load all Deployment objects from infra/k8s/*.yaml.""" + deployments = [] + for path in sorted(K8S_DIR.glob("*.yaml")): + with open(path) as f: + for doc in yaml.safe_load_all(f): + if doc and doc.get("kind") == "Deployment": + name = doc["metadata"]["name"] + deployments.append((name, doc)) + return deployments + + +def _get_secret_refs(spec: dict) -> list[str]: + """Extract all secretRef names from a pod spec's envFrom.""" + refs = [] + for container in spec.get("containers", []): + for env_from in container.get("envFrom", []): + secret = env_from.get("secretRef", {}) + if secret.get("name"): + refs.append(secret["name"]) + return refs + + +class TestSecretScoping: + """Verify that the monolithic stonks-secrets is no longer used.""" + + def test_no_monolithic_secret_ref(self): + """No deployment should reference the old stonks-secrets.""" + for name, dep in _load_deployments(): + pod_spec = dep["spec"]["template"]["spec"] + refs = _get_secret_refs(pod_spec) + assert "stonks-secrets" not in refs, ( + f"Deployment {name} still references monolithic stonks-secrets" + ) + + def test_broker_secrets_only_on_trading_tier(self): + """Only broker-adapter and risk-engine should have broker secrets.""" + for name, dep in _load_deployments(): 
+ pod_spec = dep["spec"]["template"]["spec"] + refs = _get_secret_refs(pod_spec) + if "stonks-broker-secrets" in refs: + assert name in BROKER_SECRET_ALLOWED, ( + f"Deployment {name} has broker secrets but is not in " + f"allowed set {BROKER_SECRET_ALLOWED}" + ) + + def test_market_secrets_only_on_ingestion(self): + """Only ingestion-worker should have market-data secrets.""" + for name, dep in _load_deployments(): + pod_spec = dep["spec"]["template"]["spec"] + refs = _get_secret_refs(pod_spec) + if "stonks-market-secrets" in refs: + assert name in MARKET_SECRET_ALLOWED, ( + f"Deployment {name} has market secrets but is not in " + f"allowed set {MARKET_SECRET_ALLOWED}" + ) + + +class TestPodSecurityContext: + """Verify pod-level security settings.""" + + def test_run_as_non_root(self): + for name, dep in _load_deployments(): + pod_sec = dep["spec"]["template"]["spec"].get("securityContext", {}) + assert pod_sec.get("runAsNonRoot") is True, ( + f"Deployment {name} missing runAsNonRoot: true" + ) + + def test_seccomp_profile(self): + for name, dep in _load_deployments(): + pod_sec = dep["spec"]["template"]["spec"].get("securityContext", {}) + seccomp = pod_sec.get("seccompProfile", {}) + assert seccomp.get("type") == "RuntimeDefault", ( + f"Deployment {name} missing seccompProfile RuntimeDefault" + ) + + def test_automount_service_account_disabled(self): + for name, dep in _load_deployments(): + pod_spec = dep["spec"]["template"]["spec"] + assert pod_spec.get("automountServiceAccountToken") is False, ( + f"Deployment {name} should set automountServiceAccountToken: false" + ) + + +class TestContainerSecurityContext: + """Verify container-level security settings.""" + + def test_no_privilege_escalation(self): + for name, dep in _load_deployments(): + for container in dep["spec"]["template"]["spec"]["containers"]: + sec = container.get("securityContext", {}) + assert sec.get("allowPrivilegeEscalation") is False, ( + f"Deployment {name}, container {container['name']} " + 
f"missing allowPrivilegeEscalation: false" + ) + + def test_drop_all_capabilities(self): + for name, dep in _load_deployments(): + for container in dep["spec"]["template"]["spec"]["containers"]: + sec = container.get("securityContext", {}) + caps = sec.get("capabilities", {}) + assert "ALL" in caps.get("drop", []), ( + f"Deployment {name}, container {container['name']} " + f"should drop ALL capabilities" + ) + + +class TestNetworkPolicies: + """Verify network policy manifests exist and cover key patterns.""" + + def _load_netpols(self) -> list[dict]: + policies = [] + for path in K8S_DIR.glob("*.yaml"): + with open(path) as f: + for doc in yaml.safe_load_all(f): + if doc and doc.get("kind") == "NetworkPolicy": + policies.append(doc) + return policies + + def test_default_deny_exists(self): + policies = self._load_netpols() + deny_policies = [ + p for p in policies + if p["metadata"]["name"] == "default-deny-ingress" + ] + assert len(deny_policies) == 1, "Missing default-deny-ingress NetworkPolicy" + + def test_broker_adapter_denied_ingress(self): + policies = self._load_netpols() + broker_policies = [ + p for p in policies + if p["spec"].get("podSelector", {}).get("matchLabels", {}).get("app") == "broker-adapter" + ] + assert len(broker_policies) >= 1, "Missing NetworkPolicy for broker-adapter" + # Should have empty ingress (deny all inbound) + for p in broker_policies: + assert p["spec"].get("ingress") == [] or p["spec"].get("ingress") is None, ( + "broker-adapter should deny all ingress" + ) + + def test_risk_engine_restricted_ingress(self): + policies = self._load_netpols() + risk_policies = [ + p for p in policies + if p["spec"].get("podSelector", {}).get("matchLabels", {}).get("app") == "risk-engine" + ] + assert len(risk_policies) >= 1, "Missing NetworkPolicy for risk-engine" + + +class TestSecretsManifest: + """Verify the secrets manifest uses scoped secrets.""" + + def _load_secrets(self) -> list[dict]: + secrets = [] + path = K8S_DIR / "secrets.yaml" + 
with open(path) as f: + for doc in yaml.safe_load_all(f): + if doc and doc.get("kind") == "Secret": + secrets.append(doc) + return secrets + + def test_scoped_secrets_exist(self): + secrets = self._load_secrets() + names = {s["metadata"]["name"] for s in secrets} + assert "stonks-core-secrets" in names + assert "stonks-broker-secrets" in names + assert "stonks-market-secrets" in names + assert "stonks-dashboard-secrets" in names + + def test_no_monolithic_secret(self): + secrets = self._load_secrets() + names = {s["metadata"]["name"] for s in secrets} + assert "stonks-secrets" not in names, ( + "Monolithic stonks-secrets should be replaced by scoped secrets" + ) + + def test_no_plaintext_defaults(self): + """Secret values should be REPLACE_ME placeholders, not real defaults.""" + secrets = self._load_secrets() + for secret in secrets: + for key, value in secret.get("stringData", {}).items(): + if value: # skip empty strings (e.g. REDIS_PASSWORD) + assert value != "changeme", ( + f"Secret {secret['metadata']['name']}.{key} " + f"still has 'changeme' default" + ) diff --git a/tests/test_lake_publication_validation.py b/tests/test_lake_publication_validation.py new file mode 100644 index 0000000..3be14bd --- /dev/null +++ b/tests/test_lake_publication_validation.py @@ -0,0 +1,603 @@ +"""Validate lake publication and Trino query correctness over partitioned MinIO datasets. 
+ +Ensures that: +- PyArrow schemas in worker.py match the lakehouse DDL column definitions +- Iceberg DDL generated from PyArrow schemas is consistent with lakehouse DDL +- Partition layouts are Hive-compatible and discoverable by Trino +- Published Parquet files embed partition columns in the data +- Cross-table join keys used by views are present and type-consistent +- All 12 analytical fact tables have aligned schema definitions across layers + +Requirements: 9.4, 9.5, 10.1, 10.3, N4, N6 +Design ref: Section 5.2, 5.3, 7, 8.4 +""" +from __future__ import annotations + +import io +import re +from datetime import date, datetime, timezone +from pathlib import Path +from unittest.mock import MagicMock + +import pyarrow as pa +import pyarrow.parquet as pq + +from services.lake_publisher.iceberg import ( + ICEBERG_CATALOG, + ICEBERG_SCHEMA, + TABLE_SCHEMAS, + IcebergTableDef, + _arrow_type_to_trino, + get_all_table_defs, + get_table_def, +) +from services.lake_publisher.partitions import ( + LAKEHOUSE_BUCKET, + TABLE_PARTITIONS, + WAREHOUSE_PREFIX, + partition_path, + partition_values, +) +from services.lake_publisher.worker import ( + COMPANY_EVENTS_SCHEMA, + DOCUMENTS_SCHEMA, + DOCUMENT_EXTRACTIONS_SCHEMA, + MARKET_BARS_SCHEMA, + MARKET_QUOTES_SCHEMA, + MODEL_PERFORMANCE_SCHEMA, + PNL_DAILY_SCHEMA, + POSITIONS_DAILY_SCHEMA, + PREDICTION_VS_OUTCOME_SCHEMA, + TRADE_FILLS_SCHEMA, + TRADE_ORDERS_SCHEMA, + TRADE_SIGNALS_SCHEMA, + publish_market_bar, + publish_document_fact, + publish_document_extraction, + publish_trade_signal, + publish_trade_order, + publish_trade_fill, + publish_position_daily, + publish_pnl_daily, + publish_company_event, + publish_market_quote, + publish_prediction_fact, + publish_model_performance, +) +from services.shared.schemas import ( + ActionType, + ModelMetadata, + PositionSizing, + Recommendation, + RecommendationMode, +) + +NOW = datetime(2026, 4, 11, 14, 30, 0, tzinfo=timezone.utc) +LAKEHOUSE_DDL_DIR = Path("lakehouse/schemas") + +# All 
12 expected analytical fact tables +ALL_TABLES = [ + "market_bars", + "market_quotes", + "company_events", + "documents", + "document_extractions", + "trade_signals", + "trade_orders", + "trade_fills", + "positions_daily", + "pnl_daily", + "prediction_vs_outcome", + "model_performance", +] + +# Map table names to their PyArrow schemas for direct reference +PYARROW_SCHEMAS: dict[str, pa.Schema] = { + "market_bars": MARKET_BARS_SCHEMA, + "market_quotes": MARKET_QUOTES_SCHEMA, + "company_events": COMPANY_EVENTS_SCHEMA, + "documents": DOCUMENTS_SCHEMA, + "document_extractions": DOCUMENT_EXTRACTIONS_SCHEMA, + "trade_signals": TRADE_SIGNALS_SCHEMA, + "trade_orders": TRADE_ORDERS_SCHEMA, + "trade_fills": TRADE_FILLS_SCHEMA, + "positions_daily": POSITIONS_DAILY_SCHEMA, + "pnl_daily": PNL_DAILY_SCHEMA, + "prediction_vs_outcome": PREDICTION_VS_OUTCOME_SCHEMA, + "model_performance": MODEL_PERFORMANCE_SCHEMA, +} + + +# --------------------------------------------------------------------------- +# Helpers: parse lakehouse DDL SQL files +# --------------------------------------------------------------------------- + +def _parse_ddl_columns(sql_path: Path) -> list[tuple[str, str]]: + """Parse column definitions from a lakehouse DDL SQL file. + + Returns list of (column_name, trino_type) tuples in declaration order. + Includes partition columns from the partitioned_by clause appended at the end, + since Hive DDL separates them but PyArrow/Iceberg schemas include them inline. + """ + text = sql_path.read_text() + # Extract the column block — match balanced parens for the CREATE TABLE body. + # The column block ends at the closing ) before WITH. 
+ match = re.search( + r"CREATE TABLE[^(]+\((.*)\)\s*WITH", + text, re.DOTALL | re.IGNORECASE, + ) + if not match: + return [] + col_block = match.group(1) + columns = [] + for line in col_block.strip().split("\n"): + line = line.strip().rstrip(",") + if not line or line.startswith("--"): + continue + # Split only on first whitespace to keep multi-word types intact + parts = line.split(None, 1) + if len(parts) >= 2: + col_name = parts[0].lower() + col_type = parts[1].upper().strip() + columns.append((col_name, col_type)) + return columns + + +def _parse_ddl_partitions(sql_path: Path) -> list[str]: + """Parse partition keys from a lakehouse DDL SQL file.""" + text = sql_path.read_text() + match = re.search(r"partitioned_by\s*=\s*ARRAY\[([^\]]+)\]", text, re.IGNORECASE) + if not match: + return [] + raw = match.group(1) + return [k.strip().strip("'\"") for k in raw.split(",")] + + +# --------------------------------------------------------------------------- +# 1. All 12 tables are registered across all layers +# --------------------------------------------------------------------------- + + +def test_all_tables_in_partition_registry(): + """Every expected analytical table is registered in TABLE_PARTITIONS.""" + for table in ALL_TABLES: + assert table in TABLE_PARTITIONS, f"{table} missing from TABLE_PARTITIONS" + + +def test_all_tables_in_schema_registry(): + """Every expected analytical table has a PyArrow schema in TABLE_SCHEMAS.""" + for table in ALL_TABLES: + assert table in TABLE_SCHEMAS, f"{table} missing from TABLE_SCHEMAS" + + +def test_all_tables_have_ddl_files(): + """Every expected analytical table has a lakehouse DDL SQL file.""" + for table in ALL_TABLES: + ddl_path = LAKEHOUSE_DDL_DIR / f"{table}.sql" + assert ddl_path.exists(), f"Missing DDL file: {ddl_path}" + + +def test_all_tables_have_iceberg_defs(): + """Every table in TABLE_PARTITIONS produces a valid IcebergTableDef.""" + defs = get_all_table_defs() + def_names = {d.table_name for d in defs} + 
for table in ALL_TABLES: + assert table in def_names, f"{table} missing from Iceberg table defs" + + +# --------------------------------------------------------------------------- +# 2. PyArrow schema ↔ Lakehouse DDL column alignment +# --------------------------------------------------------------------------- + + +def test_pyarrow_columns_match_ddl(): + """PyArrow schema column names and order match the lakehouse DDL for every table.""" + for table in ALL_TABLES: + ddl_path = LAKEHOUSE_DDL_DIR / f"{table}.sql" + if not ddl_path.exists(): + continue + ddl_cols = _parse_ddl_columns(ddl_path) + ddl_col_names = [c[0] for c in ddl_cols] + + arrow_schema = PYARROW_SCHEMAS[table] + arrow_col_names = [arrow_schema.field(i).name for i in range(len(arrow_schema))] + + assert arrow_col_names == ddl_col_names, ( + f"Column mismatch for {table}:\n" + f" PyArrow: {arrow_col_names}\n" + f" DDL: {ddl_col_names}" + ) + + +def test_pyarrow_types_compatible_with_ddl(): + """PyArrow types map to Trino types that match the lakehouse DDL.""" + for table in ALL_TABLES: + ddl_path = LAKEHOUSE_DDL_DIR / f"{table}.sql" + if not ddl_path.exists(): + continue + ddl_cols = _parse_ddl_columns(ddl_path) + ddl_type_map = {name: typ for name, typ in ddl_cols} + + arrow_schema = PYARROW_SCHEMAS[table] + for i in range(len(arrow_schema)): + col_name = arrow_schema.field(i).name + arrow_type = arrow_schema.field(i).type + trino_type = _arrow_type_to_trino(arrow_type) + + ddl_type = ddl_type_map.get(col_name, "") + assert trino_type == ddl_type, ( + f"Type mismatch for {table}.{col_name}: " + f"PyArrow→Trino={trino_type}, DDL={ddl_type}" + ) + + +# --------------------------------------------------------------------------- +# 3. 
Partition key alignment across layers +# --------------------------------------------------------------------------- + + +def test_partition_keys_match_ddl(): + """Partition keys in TABLE_PARTITIONS match the DDL partitioned_by clause.""" + for table in ALL_TABLES: + ddl_path = LAKEHOUSE_DDL_DIR / f"{table}.sql" + if not ddl_path.exists(): + continue + ddl_parts = _parse_ddl_partitions(ddl_path) + spec = TABLE_PARTITIONS[table] + arrow_parts = list(spec.all_keys) + + assert arrow_parts == ddl_parts, ( + f"Partition key mismatch for {table}: " + f"TABLE_PARTITIONS={arrow_parts}, DDL={ddl_parts}" + ) + + +def test_iceberg_partition_keys_match(): + """Iceberg DDL partition keys match TABLE_PARTITIONS for every table.""" + for td in get_all_table_defs(): + spec = TABLE_PARTITIONS[td.table_name] + expected_keys = list(spec.all_keys) + # Parse from the generated SQL + sql = td.create_table_sql() + match = re.search(r"partitioning = ARRAY\[([^\]]+)\]", sql) + if expected_keys: + assert match is not None, f"No partitioning clause for {td.table_name}" + parsed = [k.strip().strip("'") for k in match.group(1).split(",")] + assert parsed == expected_keys, ( + f"Iceberg partition mismatch for {td.table_name}: " + f"expected={expected_keys}, got={parsed}" + ) + + +# --------------------------------------------------------------------------- +# 4. Partition columns are embedded in PyArrow schemas +# --------------------------------------------------------------------------- + + +def test_partition_columns_in_pyarrow_schemas(): + """Partition columns (dt, model_version, etc.) 
appear in the PyArrow schema + so they are written into Parquet files, not just inferred from paths.""" + for table in ALL_TABLES: + schema = PYARROW_SCHEMAS[table] + spec = TABLE_PARTITIONS[table] + col_names = {schema.field(i).name for i in range(len(schema))} + for key in spec.all_keys: + assert key in col_names, ( + f"Partition column '{key}' missing from PyArrow schema for {table}" + ) + + +# --------------------------------------------------------------------------- +# 5. Hive-compatible partition path format +# --------------------------------------------------------------------------- + + +def test_partition_paths_are_hive_compatible(): + """Partition paths follow Hive key=value directory convention.""" + for table in ALL_TABLES: + spec = TABLE_PARTITIONS[table] + extras = {} + if spec.extra_keys: + extras = {k: "test_val" for k in spec.extra_keys} + path = partition_path(table, NOW, extras) + + # Must start with warehouse prefix + assert path.startswith(f"{WAREHOUSE_PREFIX}/{table}/"), ( + f"Path for {table} doesn't start with warehouse prefix: {path}" + ) + # Must contain dt= partition + assert "dt=2026-04-11" in path, f"Missing dt partition in path for {table}: {path}" + # Must end with .parquet + assert path.endswith(".parquet"), f"Path for {table} doesn't end with .parquet: {path}" + # Extra partition keys must appear + for key in spec.extra_keys: + assert f"{key}=test_val" in path, ( + f"Missing extra partition {key} in path for {table}: {path}" + ) + + +def test_partition_path_dt_from_date_object(): + """partition_path works with both datetime and date objects.""" + d = date(2026, 4, 11) + path = partition_path("market_bars", d) + assert "dt=2026-04-11" in path + + +# --------------------------------------------------------------------------- +# 6. 
Published Parquet files contain partition columns in data +# --------------------------------------------------------------------------- + + +def _capture_parquet(mock_client: MagicMock) -> pa.Table: + """Extract the Parquet table from a MagicMock MinIO client's put_object call.""" + put_call = mock_client.put_object.call_args + buf = put_call[0][2] + buf.seek(0) + return pq.read_table(buf) + + +def test_published_market_bar_has_dt_column(): + client = MagicMock() + publish_market_bar( + client, ticker="AAPL", open_price=150.0, high_price=155.0, + low_price=149.0, close_price=153.0, volume=1000000, + bar_timestamp=NOW, source="test", + ) + table = _capture_parquet(client) + assert "dt" in table.column_names + assert table.column("dt")[0].as_py() == date(2026, 4, 11) + + +def test_published_document_extraction_has_partition_columns(): + client = MagicMock() + publish_document_extraction( + client, document_id="doc-1", ticker="AAPL", sentiment="positive", + impact_score=0.7, catalyst_type="earnings", confidence=0.85, + extraction_at=NOW, model_name="test-model", prompt_version="v1", + schema_version="2.0.0", + ) + table = _capture_parquet(client) + assert "dt" in table.column_names + assert "model_version" in table.column_names + assert table.column("dt")[0].as_py() == date(2026, 4, 11) + assert table.column("model_version")[0].as_py() == "2.0.0" + + +def test_published_prediction_vs_outcome_has_partition_columns(): + client = MagicMock() + rec = Recommendation( + recommendation_id="rec-001", ticker="AAPL", action=ActionType.BUY, + mode=RecommendationMode.PAPER_ELIGIBLE, confidence=0.72, + time_horizon="swing_1d_10d", thesis="test", + invalidation_conditions=["x"], position_sizing=PositionSizing(portfolio_pct=0.02, max_loss_pct=0.005), + evidence_refs=["doc1"], model_metadata=ModelMetadata(provider="ollama", model_name="test-v1"), + generated_at=NOW, + ) + publish_prediction_fact(client, rec) + table = _capture_parquet(client) + assert "dt" in table.column_names + 
assert "model_version" in table.column_names + + +def test_published_model_performance_has_partition_columns(): + client = MagicMock() + publish_model_performance( + client, document_id="doc-1", model_name="gpt-oss:20b", + success=True, total_duration_ms=1500, recorded_at=NOW, + schema_version="2.0.0", + ) + table = _capture_parquet(client) + assert "dt" in table.column_names + assert "model_version" in table.column_names + assert table.column("model_version")[0].as_py() == "2.0.0" + + +# --------------------------------------------------------------------------- +# 7. Parquet schema matches PyArrow schema for every publisher +# --------------------------------------------------------------------------- + + +def _publish_and_verify_schema(table_name: str, publish_fn, expected_schema: pa.Schema): + """Helper: call a publish function, read back the Parquet, verify column names match.""" + client = MagicMock() + publish_fn(client) + table = _capture_parquet(client) + expected_names = [expected_schema.field(i).name for i in range(len(expected_schema))] + assert list(table.column_names) == expected_names, ( + f"Parquet column mismatch for {table_name}: " + f"got={list(table.column_names)}, expected={expected_names}" + ) + + +def test_parquet_schema_market_bars(): + _publish_and_verify_schema("market_bars", lambda c: publish_market_bar( + c, "AAPL", 150.0, 155.0, 149.0, 153.0, 1000000, NOW, "test", + ), MARKET_BARS_SCHEMA) + + +def test_parquet_schema_market_quotes(): + _publish_and_verify_schema("market_quotes", lambda c: publish_market_quote( + c, "AAPL", 150.0, 150.5, 150.25, NOW, "test", + ), MARKET_QUOTES_SCHEMA) + + +def test_parquet_schema_company_events(): + _publish_and_verify_schema("company_events", lambda c: publish_company_event( + c, "evt-1", "AAPL", "earnings", "Q1 Earnings", NOW, "test", + ), COMPANY_EVENTS_SCHEMA) + + +def test_parquet_schema_documents(): + _publish_and_verify_schema("documents", lambda c: publish_document_fact( + c, "doc-1", "article", 
"news_api", "AAPL", "Reuters", "Test", NOW, "hash123", + ), DOCUMENTS_SCHEMA) + + +def test_parquet_schema_trade_orders(): + _publish_and_verify_schema("trade_orders", lambda c: publish_trade_order( + c, "ord-1", "AAPL", "buy", "market", 10.0, None, "filled", "acct-1", NOW, + ), TRADE_ORDERS_SCHEMA) + + +def test_parquet_schema_trade_fills(): + _publish_and_verify_schema("trade_fills", lambda c: publish_trade_fill( + c, "fill-1", "ord-1", "AAPL", "buy", 150.25, 10.0, "acct-1", NOW, + ), TRADE_FILLS_SCHEMA) + + +def test_parquet_schema_positions_daily(): + _publish_and_verify_schema("positions_daily", lambda c: publish_position_daily( + c, "AAPL", 100.0, 145.0, 150.0, 500.0, "acct-1", NOW, + ), POSITIONS_DAILY_SCHEMA) + + +def test_parquet_schema_pnl_daily(): + _publish_and_verify_schema("pnl_daily", lambda c: publish_pnl_daily( + c, "AAPL", 200.0, 500.0, 700.0, "acct-1", NOW, + ), PNL_DAILY_SCHEMA) + + +# --------------------------------------------------------------------------- +# 8. Cross-table join keys for views +# --------------------------------------------------------------------------- + + +def test_prediction_accuracy_view_join_keys(): + """prediction_accuracy view joins prediction_vs_outcome with trade_signals + on recommendation_id and dt — both tables must have these columns.""" + pvo_cols = {PREDICTION_VS_OUTCOME_SCHEMA.field(i).name for i in range(len(PREDICTION_VS_OUTCOME_SCHEMA))} + ts_cols = {TRADE_SIGNALS_SCHEMA.field(i).name for i in range(len(TRADE_SIGNALS_SCHEMA))} + assert "recommendation_id" in pvo_cols + assert "recommendation_id" in ts_cols + assert "dt" in pvo_cols + assert "dt" in ts_cols + + +def test_paper_trade_scorecard_view_join_keys(): + """paper_trade_scorecard joins pnl_daily with trade_orders + on ticker, broker_account, and dt.""" + pnl_cols = {PNL_DAILY_SCHEMA.field(i).name for i in range(len(PNL_DAILY_SCHEMA))} + ord_cols = {TRADE_ORDERS_SCHEMA.field(i).name for i in range(len(TRADE_ORDERS_SCHEMA))} + for key in ["ticker", 
"broker_account", "dt"]: + assert key in pnl_cols, f"pnl_daily missing join key: {key}" + assert key in ord_cols, f"trade_orders missing join key: {key}" + + +def test_paper_trade_detail_view_join_keys(): + """paper_trade_detail joins trade_orders, trade_fills, and prediction_vs_outcome.""" + ord_cols = {TRADE_ORDERS_SCHEMA.field(i).name for i in range(len(TRADE_ORDERS_SCHEMA))} + fill_cols = {TRADE_FILLS_SCHEMA.field(i).name for i in range(len(TRADE_FILLS_SCHEMA))} + pvo_cols = {PREDICTION_VS_OUTCOME_SCHEMA.field(i).name for i in range(len(PREDICTION_VS_OUTCOME_SCHEMA))} + + # orders ↔ fills on order_id, dt + assert "order_id" in ord_cols + assert "order_id" in fill_cols + assert "dt" in ord_cols + assert "dt" in fill_cols + + # orders ↔ prediction_vs_outcome on recommendation_id, dt + assert "recommendation_id" in ord_cols + assert "recommendation_id" in pvo_cols + + +def test_signal_hit_rate_view_columns(): + """signal_hit_rate groups by dt and model_version from prediction_vs_outcome.""" + pvo_cols = {PREDICTION_VS_OUTCOME_SCHEMA.field(i).name for i in range(len(PREDICTION_VS_OUTCOME_SCHEMA))} + assert "dt" in pvo_cols + assert "model_version" in pvo_cols + assert "outcome" in pvo_cols + assert "predicted_confidence" in pvo_cols + assert "actual_move_pct" in pvo_cols + + +# --------------------------------------------------------------------------- +# 9. 
Iceberg DDL consistency with lakehouse DDL +# --------------------------------------------------------------------------- + + +def test_iceberg_ddl_columns_match_lakehouse_ddl(): + """Iceberg CREATE TABLE columns match the lakehouse DDL columns for every table.""" + for td in get_all_table_defs(): + ddl_path = LAKEHOUSE_DDL_DIR / f"{td.table_name}.sql" + if not ddl_path.exists(): + continue + ddl_cols = _parse_ddl_columns(ddl_path) + ddl_col_names = [c[0] for c in ddl_cols] + + iceberg_sql = td.create_table_sql() + # Extract column block from Iceberg DDL (greedy to handle nested parens) + match = re.search(r"CREATE TABLE[^(]+\((.*)\)\s*WITH", iceberg_sql, re.DOTALL) + assert match is not None, f"Could not parse Iceberg DDL for {td.table_name}" + iceberg_col_block = match.group(1) + iceberg_col_names = [] + for line in iceberg_col_block.strip().split("\n"): + line = line.strip().rstrip(",") + if line: + parts = line.split() + if parts: + iceberg_col_names.append(parts[0].lower()) + + assert iceberg_col_names == ddl_col_names, ( + f"Iceberg DDL column mismatch for {td.table_name}:\n" + f" Iceberg: {iceberg_col_names}\n" + f" DDL: {ddl_col_names}" + ) + + +# --------------------------------------------------------------------------- +# 10. 
MinIO bucket and path conventions +# --------------------------------------------------------------------------- + + +def test_lakehouse_bucket_name(): + assert LAKEHOUSE_BUCKET == "stonks-lakehouse" + + +def test_warehouse_prefix(): + assert WAREHOUSE_PREFIX == "warehouse" + + +def test_all_paths_use_warehouse_prefix(): + """Every table's partition path starts with warehouse/{table_name}/.""" + for table in ALL_TABLES: + spec = TABLE_PARTITIONS[table] + extras = {k: "v" for k in spec.extra_keys} + path = partition_path(table, NOW, extras) + assert path.startswith(f"warehouse/{table}/"), ( + f"Path for {table} doesn't follow convention: {path}" + ) + + +# --------------------------------------------------------------------------- +# 11. Iceberg table locations point to correct MinIO paths +# --------------------------------------------------------------------------- + + +def test_iceberg_locations_match_ddl_external_locations(): + """Iceberg table locations use s3a:// and match the lakehouse DDL external_location.""" + for td in get_all_table_defs(): + expected = f"s3a://{LAKEHOUSE_BUCKET}/{WAREHOUSE_PREFIX}/{td.table_name}/" + assert td.location == expected, ( + f"Iceberg location mismatch for {td.table_name}: " + f"got={td.location}, expected={expected}" + ) + + +# --------------------------------------------------------------------------- +# 12. 
Partition values are injected correctly +# --------------------------------------------------------------------------- + + +def test_partition_values_dt_only(): + pv = partition_values(NOW) + assert pv == {"dt": date(2026, 4, 11)} + + +def test_partition_values_with_model_version(): + pv = partition_values(NOW, {"model_version": "2.0.0"}) + assert pv == {"dt": date(2026, 4, 11), "model_version": "2.0.0"} + + +def test_partition_values_from_date(): + pv = partition_values(date(2026, 4, 11)) + assert pv == {"dt": date(2026, 4, 11)} diff --git a/tests/test_lake_publisher.py b/tests/test_lake_publisher.py new file mode 100644 index 0000000..740ecff --- /dev/null +++ b/tests/test_lake_publisher.py @@ -0,0 +1,596 @@ +"""Tests for lake publisher worker — writing prediction facts as Parquet to MinIO.""" +from datetime import date, datetime, timezone +from unittest.mock import MagicMock + +import pyarrow.parquet as pq + +from services.lake_publisher.partitions import ( + LAKEHOUSE_BUCKET, + TABLE_PARTITIONS, + PartitionSpec, + partition_path, + partition_values, + s3_uri, +) +from services.lake_publisher.worker import ( + _parse_horizon_days, + _partition_path, + build_trade_signal_row, + publish_trade_signal, + publish_prediction_fact, + publish_recommendation_facts, + build_trade_order_row, + publish_trade_order, + build_trade_fill_row, + publish_trade_fill, + build_position_daily_row, + publish_position_daily, + publish_positions_daily_batch, + build_model_performance_row, + publish_model_performance, + publish_market_bars_batch, + publish_trade_signals_batch, + publish_model_performance_batch, +) +from services.shared.schemas import ( + ActionType, + ModelMetadata, + PositionSizing, + Recommendation, + RecommendationMode, +) + +NOW = datetime(2026, 4, 11, 14, 30, 0, tzinfo=timezone.utc) + + +def _make_rec( + ticker: str = "AAPL", + action: ActionType = ActionType.BUY, + confidence: float = 0.72, + time_horizon: str = "swing_1d_10d", + rec_id: str = "rec-001", +) -> 
Recommendation: + return Recommendation( + recommendation_id=rec_id, + ticker=ticker, + action=action, + mode=RecommendationMode.PAPER_ELIGIBLE, + confidence=confidence, + time_horizon=time_horizon, + thesis="[risk:low] Test thesis", + invalidation_conditions=["price drops below support"], + position_sizing=PositionSizing(portfolio_pct=0.02, max_loss_pct=0.005), + evidence_refs=["doc1", "doc2"], + model_metadata=ModelMetadata(provider="deterministic", model_name="eligibility-v1"), + generated_at=NOW, + ) + + +# --------------------------------------------------------------------------- +# _parse_horizon_days +# --------------------------------------------------------------------------- + + +def test_parse_horizon_days_swing(): + assert _parse_horizon_days("swing_1d_10d") == 10 + + +def test_parse_horizon_days_position(): + assert _parse_horizon_days("position_10d_30d") == 30 + + +def test_parse_horizon_days_intraday(): + assert _parse_horizon_days("scalp_intraday") == 1 + + +def test_parse_horizon_days_empty(): + assert _parse_horizon_days("") == 0 + + +def test_parse_horizon_days_no_numbers(): + assert _parse_horizon_days("unknown") == 0 + + +# --------------------------------------------------------------------------- +# Partition module tests +# --------------------------------------------------------------------------- + + +def test_partition_spec_all_keys(): + spec = PartitionSpec("test_table", extra_keys=("model_version",)) + assert spec.all_keys == ("dt", "model_version") + + +def test_partition_spec_dt_only(): + spec = PartitionSpec("simple") + assert spec.all_keys == ("dt",) + + +def test_table_partitions_registry(): + assert "market_bars" in TABLE_PARTITIONS + assert "model_performance" in TABLE_PARTITIONS + assert TABLE_PARTITIONS["model_performance"].extra_keys == ("model_version",) + assert TABLE_PARTITIONS["prediction_vs_outcome"].extra_keys == ("model_version",) + assert TABLE_PARTITIONS["document_extractions"].extra_keys == ("model_version",) + 
assert TABLE_PARTITIONS["trade_signals"].extra_keys == () + + +def test_partition_values_dt_only(): + pv = partition_values(NOW) + assert pv == {"dt": date(2026, 4, 11)} + + +def test_partition_values_with_extras(): + pv = partition_values(NOW, {"model_version": "v2"}) + assert pv == {"dt": date(2026, 4, 11), "model_version": "v2"} + + +def test_s3_uri(): + assert s3_uri("warehouse/t/dt=2026-04-11/part-abc.parquet") == \ + "s3://stonks-lakehouse/warehouse/t/dt=2026-04-11/part-abc.parquet" + + +# --------------------------------------------------------------------------- +# _partition_path (via partitions module) +# --------------------------------------------------------------------------- + + +def test_partition_path_basic(): + path = partition_path("trade_signals", NOW) + assert path.startswith("warehouse/trade_signals/dt=2026-04-11/") + assert path.endswith(".parquet") + + +def test_partition_path_with_extra_partitions(): + path = partition_path("model_performance", NOW, {"model_version": "v1"}) + assert "model_version=v1" in path + + +def test_partition_path_custom_file_id(): + path = partition_path("trade_signals", NOW, file_id="custom123") + assert "part-custom123.parquet" in path + + +def test_partition_path_legacy_wrapper(): + """The _partition_path wrapper in worker.py still works.""" + path = _partition_path("trade_signals", NOW) + assert path.startswith("warehouse/trade_signals/dt=2026-04-11/") + + +# --------------------------------------------------------------------------- +# build_trade_signal_row +# --------------------------------------------------------------------------- + + +def test_build_trade_signal_row(): + rec = _make_rec() + row = build_trade_signal_row(rec, trend_direction="bullish", trend_strength=0.68) + assert row["signal_id"] == "rec-001" + assert row["ticker"] == "AAPL" + assert row["trend_direction"] == "bullish" + assert row["trend_strength"] == 0.68 + assert row["confidence"] == 0.72 + assert row["action"] == "buy" + assert 
row["time_horizon"] == "swing_1d_10d" + assert row["generated_at"] == NOW + assert row["dt"] == date(2026, 4, 11) + + +# --------------------------------------------------------------------------- +# publish_trade_signal +# --------------------------------------------------------------------------- + + +def test_publish_trade_signal_writes_parquet(): + client = MagicMock() + rec = _make_rec() + + ref = publish_trade_signal(client, rec, trend_direction="bullish", trend_strength=0.68) + + assert ref.startswith(f"s3://{LAKEHOUSE_BUCKET}/warehouse/trade_signals/") + assert client.put_object.call_count == 1 + + # Verify the written bytes are valid Parquet + put_call = client.put_object.call_args + assert put_call[0][0] == LAKEHOUSE_BUCKET + written_buf = put_call[0][2] + written_buf.seek(0) + table = pq.read_table(written_buf) + assert table.num_rows == 1 + assert table.column("ticker")[0].as_py() == "AAPL" + assert table.column("action")[0].as_py() == "buy" + + +# --------------------------------------------------------------------------- +# publish_prediction_fact +# --------------------------------------------------------------------------- + + +def test_publish_prediction_fact_writes_parquet(): + client = MagicMock() + rec = _make_rec() + + ref = publish_prediction_fact(client, rec) + + assert ref.startswith(f"s3://{LAKEHOUSE_BUCKET}/warehouse/prediction_vs_outcome/") + assert "model_version=" in ref + assert client.put_object.call_count == 1 + + put_call = client.put_object.call_args + written_buf = put_call[0][2] + written_buf.seek(0) + table = pq.read_table(written_buf) + assert table.num_rows == 1 + assert table.column("predicted_action")[0].as_py() == "buy" + assert table.column("outcome")[0].as_py() == "pending" + assert table.column("horizon_days")[0].as_py() == 10 + assert table.column("dt")[0].as_py() == date(2026, 4, 11) + + +# --------------------------------------------------------------------------- +# publish_recommendation_facts +# 
--------------------------------------------------------------------------- + + +def test_publish_recommendation_facts_writes_both_tables(): + client = MagicMock() + rec = _make_rec() + + refs = publish_recommendation_facts(client, rec, "bullish", 0.68) + + assert "trade_signals" in refs + assert "prediction_vs_outcome" in refs + assert client.put_object.call_count == 2 + + +# --------------------------------------------------------------------------- +# build_trade_order_row +# --------------------------------------------------------------------------- + + +def test_build_trade_order_row(): + row = build_trade_order_row( + order_id="ord-001", + ticker="AAPL", + side="buy", + order_type="market", + quantity=10.0, + limit_price=None, + status="filled", + broker_account="acct-001", + submitted_at=NOW, + ) + assert row["order_id"] == "ord-001" + assert row["ticker"] == "AAPL" + assert row["side"] == "buy" + assert row["order_type"] == "market" + assert row["quantity"] == 10.0 + assert row["limit_price"] is None + assert row["status"] == "filled" + assert row["broker_account"] == "acct-001" + assert row["submitted_at"] == NOW + assert row["dt"] == date(2026, 4, 11) + + +# --------------------------------------------------------------------------- +# publish_trade_order +# --------------------------------------------------------------------------- + + +def test_publish_trade_order_writes_parquet(): + client = MagicMock() + + ref = publish_trade_order( + client, + order_id="ord-001", + ticker="AAPL", + side="buy", + order_type="market", + quantity=10.0, + limit_price=None, + status="filled", + broker_account="acct-001", + submitted_at=NOW, + ) + + assert ref.startswith(f"s3://{LAKEHOUSE_BUCKET}/warehouse/trade_orders/") + assert client.put_object.call_count == 1 + + put_call = client.put_object.call_args + assert put_call[0][0] == LAKEHOUSE_BUCKET + written_buf = put_call[0][2] + written_buf.seek(0) + table = pq.read_table(written_buf) + assert table.num_rows == 1 + 
assert table.column("ticker")[0].as_py() == "AAPL" + assert table.column("side")[0].as_py() == "buy" + assert table.column("status")[0].as_py() == "filled" + + +# --------------------------------------------------------------------------- +# build_trade_fill_row +# --------------------------------------------------------------------------- + + +def test_build_trade_fill_row(): + row = build_trade_fill_row( + fill_id="fill-001", + order_id="ord-001", + ticker="AAPL", + side="buy", + fill_price=150.25, + fill_quantity=10.0, + broker_account="acct-001", + filled_at=NOW, + ) + assert row["fill_id"] == "fill-001" + assert row["order_id"] == "ord-001" + assert row["fill_price"] == 150.25 + assert row["fill_quantity"] == 10.0 + + +# --------------------------------------------------------------------------- +# publish_trade_fill +# --------------------------------------------------------------------------- + + +def test_publish_trade_fill_writes_parquet(): + client = MagicMock() + + ref = publish_trade_fill( + client, + fill_id="fill-001", + order_id="ord-001", + ticker="AAPL", + side="buy", + fill_price=150.25, + fill_quantity=10.0, + broker_account="acct-001", + filled_at=NOW, + ) + + assert ref.startswith(f"s3://{LAKEHOUSE_BUCKET}/warehouse/trade_fills/") + assert client.put_object.call_count == 1 + + put_call = client.put_object.call_args + written_buf = put_call[0][2] + written_buf.seek(0) + table = pq.read_table(written_buf) + assert table.num_rows == 1 + assert table.column("fill_price")[0].as_py() == 150.25 + assert table.column("ticker")[0].as_py() == "AAPL" + + +# --------------------------------------------------------------------------- +# build_position_daily_row +# --------------------------------------------------------------------------- + + +def test_build_position_daily_row(): + row = build_position_daily_row( + ticker="AAPL", + quantity=100.0, + avg_entry_price=145.00, + close_price=150.00, + unrealized_pnl=500.0, + broker_account="acct-001", + 
snapshot_at=NOW, + ) + assert row["ticker"] == "AAPL" + assert row["quantity"] == 100.0 + assert row["unrealized_pnl"] == 500.0 + + +# --------------------------------------------------------------------------- +# publish_position_daily +# --------------------------------------------------------------------------- + + +def test_publish_position_daily_writes_parquet(): + client = MagicMock() + + ref = publish_position_daily( + client, + ticker="AAPL", + quantity=100.0, + avg_entry_price=145.00, + close_price=150.00, + unrealized_pnl=500.0, + broker_account="acct-001", + snapshot_at=NOW, + ) + + assert ref.startswith(f"s3://{LAKEHOUSE_BUCKET}/warehouse/positions_daily/") + assert client.put_object.call_count == 1 + + put_call = client.put_object.call_args + written_buf = put_call[0][2] + written_buf.seek(0) + table = pq.read_table(written_buf) + assert table.num_rows == 1 + assert table.column("ticker")[0].as_py() == "AAPL" + assert table.column("close_price")[0].as_py() == 150.00 + + +# --------------------------------------------------------------------------- +# publish_positions_daily_batch +# --------------------------------------------------------------------------- + + +def test_publish_positions_daily_batch_writes_parquet(): + client = MagicMock() + + positions = [ + {"ticker": "AAPL", "quantity": 100.0, "avg_entry_price": 145.0, "close_price": 150.0, "unrealized_pnl": 500.0}, + {"ticker": "MSFT", "quantity": 50.0, "avg_entry_price": 300.0, "close_price": 310.0, "unrealized_pnl": 500.0}, + ] + + ref = publish_positions_daily_batch(client, positions, "acct-001", NOW) + + assert ref.startswith(f"s3://{LAKEHOUSE_BUCKET}/warehouse/positions_daily/") + assert client.put_object.call_count == 1 + + put_call = client.put_object.call_args + written_buf = put_call[0][2] + written_buf.seek(0) + table = pq.read_table(written_buf) + assert table.num_rows == 2 + + +def test_publish_positions_daily_batch_empty(): + client = MagicMock() + + ref = 
publish_positions_daily_batch(client, [], "acct-001", NOW) + + assert ref == "" + assert client.put_object.call_count == 0 + + +# --------------------------------------------------------------------------- +# build_model_performance_row +# --------------------------------------------------------------------------- + + +def test_build_model_performance_row(): + row = build_model_performance_row( + document_id="doc-001", + model_name="gpt-oss:20b", + success=True, + total_duration_ms=1500, + recorded_at=NOW, + ticker="AAPL", + prompt_version="document-intel-v2", + schema_version="2.0.0", + attempt_count=2, + confidence=0.86, + validation_status="valid", + retry_count=1, + input_token_estimate=500, + output_token_estimate=200, + company_count=3, + ) + assert row["document_id"] == "doc-001" + assert row["model_name"] == "gpt-oss:20b" + assert row["success"] is True + assert row["total_duration_ms"] == 1500 + assert row["attempt_count"] == 2 + assert row["confidence"] == 0.86 + assert row["company_count"] == 3 + assert row["dt"] == date(2026, 4, 11) + assert row["model_version"] == "2.0.0" + + +# --------------------------------------------------------------------------- +# publish_model_performance +# --------------------------------------------------------------------------- + + +def test_publish_model_performance_writes_parquet(): + client = MagicMock() + + ref = publish_model_performance( + client, + document_id="doc-001", + model_name="gpt-oss:20b", + success=True, + total_duration_ms=1500, + recorded_at=NOW, + ticker="AAPL", + prompt_version="document-intel-v2", + schema_version="2.0.0", + confidence=0.86, + validation_status="valid", + ) + + assert ref.startswith(f"s3://{LAKEHOUSE_BUCKET}/warehouse/model_performance/") + assert "model_version=2.0.0" in ref + assert client.put_object.call_count == 1 + + put_call = client.put_object.call_args + written_buf = put_call[0][2] + written_buf.seek(0) + table = pq.read_table(written_buf) + assert table.num_rows == 1 + 
assert table.column("model_name")[0].as_py() == "gpt-oss:20b" + assert table.column("success")[0].as_py() is True + assert table.column("confidence")[0].as_py() == 0.86 + + +# --------------------------------------------------------------------------- +# Batch publish helpers +# --------------------------------------------------------------------------- + + +def test_publish_market_bars_batch(): + client = MagicMock() + bars: list[dict[str, object]] = [ + { + "ticker": "AAPL", "open_price": 150.0, "high_price": 155.0, + "low_price": 149.0, "close_price": 153.0, "volume": 1000000, + "vwap": 152.0, "trade_count": 5000, "bar_timestamp": NOW, + "bar_interval": "1d", "source": "test", + "dt": date(2026, 4, 11), + }, + { + "ticker": "MSFT", "open_price": 300.0, "high_price": 310.0, + "low_price": 298.0, "close_price": 305.0, "volume": 800000, + "vwap": 304.0, "trade_count": 4000, "bar_timestamp": NOW, + "bar_interval": "1d", "source": "test", + "dt": date(2026, 4, 11), + }, + ] + + ref = publish_market_bars_batch(client, bars, NOW) + + assert ref.startswith(f"s3://{LAKEHOUSE_BUCKET}/warehouse/market_bars/") + assert client.put_object.call_count == 1 + + put_call = client.put_object.call_args + written_buf = put_call[0][2] + written_buf.seek(0) + table = pq.read_table(written_buf) + assert table.num_rows == 2 + + +def test_publish_batch_empty_returns_empty(): + client = MagicMock() + ref = publish_market_bars_batch(client, [], NOW) + assert ref == "" + assert client.put_object.call_count == 0 + + +def test_publish_trade_signals_batch(): + client = MagicMock() + rec = _make_rec() + rows = [build_trade_signal_row(rec, "bullish", 0.68)] + + ref = publish_trade_signals_batch(client, rows, NOW) + + assert ref.startswith(f"s3://{LAKEHOUSE_BUCKET}/warehouse/trade_signals/") + assert client.put_object.call_count == 1 + + +def test_publish_model_performance_batch(): + client = MagicMock() + rows = [ + build_model_performance_row( + document_id="doc-001", model_name="gpt-oss:20b", 
+ success=True, total_duration_ms=1500, recorded_at=NOW, + ), + build_model_performance_row( + document_id="doc-002", model_name="gpt-oss:20b", + success=False, total_duration_ms=3000, recorded_at=NOW, + ), + ] + + ref = publish_model_performance_batch(client, rows, NOW, model_version="v2") + + assert ref.startswith(f"s3://{LAKEHOUSE_BUCKET}/warehouse/model_performance/") + assert "model_version=v2" in ref + assert client.put_object.call_count == 1 + + put_call = client.put_object.call_args + written_buf = put_call[0][2] + written_buf.seek(0) + table = pq.read_table(written_buf) + assert table.num_rows == 2 diff --git a/tests/test_lake_publisher_jobs.py b/tests/test_lake_publisher_jobs.py new file mode 100644 index 0000000..5b9423a --- /dev/null +++ b/tests/test_lake_publisher_jobs.py @@ -0,0 +1,355 @@ +"""Tests for lake publisher job runner — dispatching operational data to analytical facts.""" +from __future__ import annotations + +import json +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from services.lake_publisher.jobs import ( + _jsonb_to_str, + dispatch_job, + publish_document_job, + publish_extraction_job, + publish_market_snapshot_job, + publish_order_job, + publish_fills_job, + publish_positions_job, + publish_pnl_job, + publish_bulk_documents_job, + publish_bulk_extractions_job, +) + +NOW = datetime(2026, 4, 11, 14, 30, 0, tzinfo=timezone.utc) + + +# --------------------------------------------------------------------------- +# _jsonb_to_str +# --------------------------------------------------------------------------- + + +def test_jsonb_to_str_list(): + assert _jsonb_to_str(["a", "b", "c"]) == "a, b, c" + + +def test_jsonb_to_str_json_string(): + assert _jsonb_to_str('["x", "y"]') == "x, y" + + +def test_jsonb_to_str_plain_string(): + assert _jsonb_to_str("hello") == "hello" + + +def test_jsonb_to_str_none(): + assert _jsonb_to_str(None) == "" + + +# 
--------------------------------------------------------------------------- +# publish_document_job +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_publish_document_job_found(): + pool = AsyncMock() + pool.fetchrow.return_value = { + "id": "doc-uuid-1", + "document_type": "article", + "source_type": "news_api", + "publisher": "Reuters", + "title": "Test Article", + "url": "https://example.com/article", + "canonical_url": "https://example.com/article", + "language": "en", + "published_at": NOW, + "retrieved_at": NOW, + "content_hash": "abc123", + "parse_quality_score": 0.85, + "ticker": "AAPL", + } + minio_client = MagicMock() + + ref = await publish_document_job(pool, minio_client, "doc-uuid-1") + + assert ref.startswith("s3://stonks-lakehouse/warehouse/documents/") + assert minio_client.put_object.call_count == 1 + + +@pytest.mark.asyncio +async def test_publish_document_job_not_found(): + pool = AsyncMock() + pool.fetchrow.return_value = None + minio_client = MagicMock() + + ref = await publish_document_job(pool, minio_client, "missing-uuid") + assert ref == "" + assert minio_client.put_object.call_count == 0 + + +# --------------------------------------------------------------------------- +# publish_extraction_job +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_publish_extraction_job(): + pool = AsyncMock() + pool.fetch.return_value = [ + { + "document_id": "doc-uuid-1", + "ticker": "AAPL", + "relevance": 0.9, + "sentiment": "positive", + "impact_score": 0.7, + "impact_horizon": "1d_30d", + "catalyst_type": "earnings", + "confidence": 0.85, + "novelty_score": 0.6, + "source_credibility": 0.8, + "key_facts": ["strong earnings"], + "risks": ["regulatory"], + "macro_themes": ["ai_capex"], + "model_name": "gpt-oss:20b", + "prompt_version": "document-intel-v2", + "schema_version": "2.0.0", + "extraction_at": NOW, + 
"company_name": "Apple Inc.", + }, + ] + minio_client = MagicMock() + + refs = await publish_extraction_job(pool, minio_client, "doc-uuid-1") + + assert len(refs) == 1 + assert refs[0].startswith("s3://stonks-lakehouse/warehouse/document_extractions/") + assert minio_client.put_object.call_count == 1 + + +@pytest.mark.asyncio +async def test_publish_extraction_job_empty(): + pool = AsyncMock() + pool.fetch.return_value = [] + minio_client = MagicMock() + + refs = await publish_extraction_job(pool, minio_client, "doc-uuid-1") + assert refs == [] + + +# --------------------------------------------------------------------------- +# publish_market_snapshot_job +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_publish_market_snapshot_bar(): + pool = AsyncMock() + pool.fetchrow.return_value = { + "ticker": "AAPL", + "snapshot_type": "bar", + "data": {"open": 150.0, "high": 155.0, "low": 149.0, "close": 153.0, + "volume": 1000000, "vwap": 152.0, "trade_count": 5000}, + "source_provider": "polygon", + "captured_at": NOW, + } + minio_client = MagicMock() + + refs = await publish_market_snapshot_job(pool, minio_client, "snap-uuid-1") + + assert len(refs) == 1 + assert refs[0].startswith("s3://stonks-lakehouse/warehouse/market_bars/") + + +@pytest.mark.asyncio +async def test_publish_market_snapshot_quote(): + pool = AsyncMock() + pool.fetchrow.return_value = { + "ticker": "AAPL", + "snapshot_type": "quote", + "data": {"bid_price": 150.0, "ask_price": 150.5, "last_price": 150.25}, + "source_provider": "polygon", + "captured_at": NOW, + } + minio_client = MagicMock() + + refs = await publish_market_snapshot_job(pool, minio_client, "snap-uuid-1") + + assert len(refs) == 1 + assert refs[0].startswith("s3://stonks-lakehouse/warehouse/market_quotes/") + + +# --------------------------------------------------------------------------- +# publish_order_job +# 
--------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_publish_order_job(): + pool = AsyncMock() + pool.fetchrow.return_value = { + "id": "ord-uuid-1", + "recommendation_id": "rec-uuid-1", + "ticker": "AAPL", + "side": "buy", + "order_type": "market", + "quantity": 10, + "limit_price": None, + "status": "filled", + "submitted_at": NOW, + "fill_price": 150.25, + "fill_quantity": 10, + "filled_at": NOW, + "broker_account": "acct-001", + "execution_mode": "paper", + } + minio_client = MagicMock() + + ref = await publish_order_job(pool, minio_client, "ord-uuid-1") + + assert ref.startswith("s3://stonks-lakehouse/warehouse/trade_orders/") + assert minio_client.put_object.call_count == 1 + + +# --------------------------------------------------------------------------- +# publish_fills_job +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_publish_fills_job(): + pool = AsyncMock() + pool.fetch.return_value = [ + { + "fill_id": "fill-uuid-1", + "order_id": "ord-uuid-1", + "data": {"fill_price": 150.25, "fill_quantity": 10, "commission": 0.5}, + "broker_timestamp": NOW, + "ticker": "AAPL", + "side": "buy", + "broker_account": "acct-001", + }, + ] + minio_client = MagicMock() + + refs = await publish_fills_job(pool, minio_client, "ord-uuid-1") + + assert len(refs) == 1 + assert refs[0].startswith("s3://stonks-lakehouse/warehouse/trade_fills/") + + +# --------------------------------------------------------------------------- +# publish_positions_job +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_publish_positions_job(): + pool = AsyncMock() + pool.fetch.return_value = [ + { + "ticker": "AAPL", + "quantity": 100, + "avg_entry_price": 145.0, + "current_price": 150.0, + "unrealized_pnl": 500.0, + "realized_pnl": 0, + "broker_account": "acct-001", + "execution_mode": 
"paper", + }, + ] + minio_client = MagicMock() + + ref = await publish_positions_job(pool, minio_client, "acct-uuid-1") + + assert ref.startswith("s3://stonks-lakehouse/warehouse/positions_daily/") + assert minio_client.put_object.call_count == 1 + + +# --------------------------------------------------------------------------- +# publish_pnl_job +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_publish_pnl_job(): + pool = AsyncMock() + pool.fetch.return_value = [ + { + "ticker": "AAPL", + "quantity": 100, + "avg_entry_price": 145.0, + "current_price": 150.0, + "unrealized_pnl": 500.0, + "realized_pnl": 200.0, + "broker_account": "acct-001", + "execution_mode": "paper", + }, + ] + minio_client = MagicMock() + + refs = await publish_pnl_job(pool, minio_client, "acct-uuid-1") + + assert len(refs) == 1 + assert refs[0].startswith("s3://stonks-lakehouse/warehouse/pnl_daily/") + + +# --------------------------------------------------------------------------- +# dispatch_job +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_dispatch_unknown_job_type(): + pool = AsyncMock() + minio_client = MagicMock() + + result = await dispatch_job(pool, minio_client, {"job_type": "unknown", "entity_id": "x"}) + + assert result["error"] is not None + assert "Unknown" in str(result["error"]) + + +@pytest.mark.asyncio +async def test_dispatch_document_job(): + pool = AsyncMock() + pool.fetchrow.return_value = { + "id": "doc-uuid-1", + "document_type": "article", + "source_type": "news_api", + "publisher": "Reuters", + "title": "Test", + "url": "", + "canonical_url": "", + "language": "en", + "published_at": NOW, + "retrieved_at": NOW, + "content_hash": "abc", + "parse_quality_score": 0.8, + "ticker": "AAPL", + } + minio_client = MagicMock() + + result = await dispatch_job( + pool, minio_client, + {"job_type": "document", "entity_id": "doc-uuid-1"}, 
+ ) + + assert result["error"] is None + refs = result["refs"] + assert isinstance(refs, list) + assert len(refs) == 1 + + +@pytest.mark.asyncio +async def test_dispatch_job_handles_exception(): + pool = AsyncMock() + pool.fetchrow.side_effect = Exception("DB down") + minio_client = MagicMock() + + result = await dispatch_job( + pool, minio_client, + {"job_type": "document", "entity_id": "doc-uuid-1"}, + ) + + assert result["error"] is not None + assert "DB down" in str(result["error"]) diff --git a/tests/test_logging.py b/tests/test_logging.py new file mode 100644 index 0000000..17c4751 --- /dev/null +++ b/tests/test_logging.py @@ -0,0 +1,136 @@ +"""Tests for structured logging and distributed tracing.""" +import json +import logging + +from services.shared.logging import ( + JSONFormatter, + Span, + extract_trace_context, + get_service_name, + get_span_id, + get_trace_id, + inject_trace_context, + new_span_id, + new_trace_id, + set_trace_context, + setup_logging, +) + + +def test_new_trace_id_format(): + tid = new_trace_id() + assert len(tid) == 16 + assert tid.isalnum() + + +def test_new_span_id_format(): + sid = new_span_id() + assert len(sid) == 8 + assert sid.isalnum() + + +def test_set_and_get_trace_context(): + set_trace_context(trace_id="abc123", span_id="sp01", service="test_svc") + assert get_trace_id() == "abc123" + assert get_span_id() == "sp01" + assert get_service_name() == "test_svc" + + +def test_json_formatter_output(): + set_trace_context(trace_id="trace42", span_id="span7", service="fmt_test") + formatter = JSONFormatter() + record = logging.LogRecord( + name="test_logger", level=logging.INFO, pathname="", lineno=0, + msg="hello world", args=(), exc_info=None, + ) + output = formatter.format(record) + parsed = json.loads(output) + assert parsed["message"] == "hello world" + assert parsed["level"] == "INFO" + assert parsed["trace_id"] == "trace42" + assert parsed["span_id"] == "span7" + assert parsed["service"] == "fmt_test" + assert "timestamp" 
in parsed + + +def test_json_formatter_extra_fields(): + set_trace_context(trace_id="t1", service="extra_test") + formatter = JSONFormatter() + record = logging.LogRecord( + name="test", level=logging.WARNING, pathname="", lineno=0, + msg="doc processed", args=(), exc_info=None, + ) + record.ticker = "AAPL" + record.document_id = "doc-123" + output = formatter.format(record) + parsed = json.loads(output) + assert parsed["ticker"] == "AAPL" + assert parsed["document_id"] == "doc-123" + + +def test_span_sets_and_restores_context(): + set_trace_context(trace_id="parent_trace", span_id="parent_span", service="span_test") + parent_span = get_span_id() + + with Span("test_op", ticker="MSFT") as span: + assert get_trace_id() == "parent_trace" + assert get_span_id() == span.span_id + assert span.span_id != parent_span + + # Context restored after span exits + assert get_span_id() == parent_span + + +def test_span_records_duration(): + set_trace_context(service="dur_test") + with Span("slow_op") as span: + pass # instant + assert span.duration_ms >= 0 + + +def test_span_generates_trace_id_if_missing(): + set_trace_context(trace_id="", service="gen_test") + with Span("auto_trace") as span: + assert len(span.trace_id) == 16 + + +def test_inject_trace_context(): + set_trace_context(trace_id="inject_trace") + payload = inject_trace_context({"ticker": "GOOG"}) + assert payload["_trace_id"] == "inject_trace" + assert payload["ticker"] == "GOOG" + + +def test_extract_trace_context(): + payload = {"ticker": "TSLA", "_trace_id": "extracted_trace"} + extract_trace_context(payload) + assert get_trace_id() == "extracted_trace" + + +def test_extract_trace_context_generates_new_if_missing(): + payload = {"ticker": "AMZN"} + extract_trace_context(payload) + assert len(get_trace_id()) == 16 + + +def test_setup_logging_json_mode(): + setup_logging("test_service", level="DEBUG", json_output=True) + root = logging.getLogger() + assert len(root.handlers) == 1 + assert 
isinstance(root.handlers[0].formatter, JSONFormatter) + assert root.level == logging.DEBUG + assert get_service_name() == "test_service" + + +def test_setup_logging_text_mode(): + setup_logging("text_service", level="WARNING", json_output=False) + root = logging.getLogger() + assert len(root.handlers) == 1 + assert not isinstance(root.handlers[0].formatter, JSONFormatter) + assert root.level == logging.WARNING + + +def test_config_json_logs_field(): + from services.shared.config import load_config + config = load_config() + assert isinstance(config.json_logs, bool) diff --git a/tests/test_market_adapter.py b/tests/test_market_adapter.py new file mode 100644 index 0000000..6c28111 --- /dev/null +++ b/tests/test_market_adapter.py @@ -0,0 +1,165 @@ +"""Tests for the Polygon.io market data adapter. + +Validates request building, response parsing, and error handling. +""" +from services.adapters.market_adapter import MarketDataAdapter, PolygonMarketAdapter + + +# --- Fake Polygon responses --- + +PREV_BARS_RESPONSE = { + "ticker": "AAPL", + "queryCount": 1, + "resultsCount": 1, + "adjusted": True, + "results": [ + { + "T": "AAPL", + "v": 58_350_544, + "vw": 171.5322, + "o": 171.0, + "c": 172.28, + "h": 173.1, + "l": 170.5, + "t": 1712793600000, + "n": 620_123, + } + ], + "status": "OK", + "request_id": "req-abc-123", +} + +TICKER_DETAILS_RESPONSE = { + "results": { + "ticker": "AAPL", + "name": "Apple Inc.", + "market": "stocks", + "locale": "us", + "primary_exchange": "XNAS", + "type": "CS", + "currency_name": "usd", + "market_cap": 2_700_000_000_000, + }, + "status": "OK", + "request_id": "req-def-456", +} + +RANGE_BARS_RESPONSE = { + "ticker": "AAPL", + "queryCount": 3, + "resultsCount": 3, + "adjusted": True, + "results": [ + {"T": "AAPL", "o": 170.0, "c": 171.0, "h": 172.0, "l": 169.5, "v": 50_000_000, "t": 1712620800000}, + {"T": "AAPL", "o": 171.0, "c": 172.0, "h": 173.0, "l": 170.0, "v": 55_000_000, "t": 1712707200000}, + {"T": "AAPL", "o": 172.0, "c": 172.5, 
"h": 174.0, "l": 171.0, "v": 48_000_000, "t": 1712793600000}, + ], + "status": "OK", + "request_id": "req-ghi-789", +} + + +class TestPolygonSourceType: + def test_source_type(self): + adapter = PolygonMarketAdapter(api_key="k") + assert adapter.source_type() == "market_api" + + def test_inherits_market_data_adapter(self): + assert issubclass(PolygonMarketAdapter, MarketDataAdapter) + + def test_bucket_name(self): + adapter = PolygonMarketAdapter(api_key="k") + assert adapter.bucket_name() == "stonks-raw-market" + + +class TestPolygonBuildRequest: + def setup_method(self): + self.adapter = PolygonMarketAdapter(api_key="test-key", base_url="https://api.polygon.io") + + def test_prev_bars_default(self): + url, params = self.adapter._build_request("AAPL", "prev_bars", {}) + assert url == "https://api.polygon.io/v2/aggs/ticker/AAPL/prev" + assert params["apiKey"] == "test-key" + + def test_prev_bars_with_adjusted(self): + url, params = self.adapter._build_request("AAPL", "prev_bars", {"adjusted": False}) + assert params["adjusted"] == "false" + + def test_range_bars(self): + config = { + "multiplier": 1, + "timespan": "day", + "from_date": "2026-04-01", + "to_date": "2026-04-10", + "adjusted": True, + "limit": 50, + "sort": "asc", + } + url, params = self.adapter._build_request("AAPL", "range_bars", config) + assert "/v2/aggs/ticker/AAPL/range/1/day/2026-04-01/2026-04-10" in url + assert params["adjusted"] == "true" + assert params["limit"] == "50" + assert params["sort"] == "asc" + + def test_ticker_details(self): + url, params = self.adapter._build_request("MSFT", "ticker_details", {}) + assert url == "https://api.polygon.io/v3/reference/tickers/MSFT" + assert params["apiKey"] == "test-key" + + def test_unknown_endpoint_defaults_to_prev(self): + url, _ = self.adapter._build_request("AAPL", "unknown_thing", {}) + assert "/v2/aggs/ticker/AAPL/prev" in url + + def test_trailing_slash_stripped(self): + adapter = PolygonMarketAdapter(api_key="k", 
base_url="https://api.polygon.io/") + url, _ = adapter._build_request("AAPL", "prev_bars", {}) + assert "//v2" not in url + + +class TestPolygonExtractItems: + def setup_method(self): + self.adapter = PolygonMarketAdapter(api_key="k") + + def test_extract_prev_bars(self): + items = self.adapter._extract_items(PREV_BARS_RESPONSE, "prev_bars") + assert len(items) == 1 + assert items[0]["T"] == "AAPL" + + def test_extract_range_bars(self): + items = self.adapter._extract_items(RANGE_BARS_RESPONSE, "range_bars") + assert len(items) == 3 + + def test_extract_ticker_details(self): + items = self.adapter._extract_items(TICKER_DETAILS_RESPONSE, "ticker_details") + assert len(items) == 1 + assert items[0]["ticker"] == "AAPL" + + def test_extract_empty_results_list(self): + items = self.adapter._extract_items({"results": [], "status": "OK"}, "prev_bars") + assert items == [] + + def test_extract_missing_results_key(self): + items = self.adapter._extract_items({"status": "OK"}, "prev_bars") + assert items == [] + + def test_extract_ticker_details_empty(self): + items = self.adapter._extract_items({"results": {}, "status": "OK"}, "ticker_details") + assert items == [] + + +class TestPolygonErrorResult: + def test_error_result_fields(self): + adapter = PolygonMarketAdapter(api_key="k") + result = adapter._error_result("AAPL", "something broke", 42.5, http_status=500, raw=b"err") + assert not result.ok + assert result.error == "something broke" + assert result.http_status == 500 + assert result.response_time_ms == 42.5 + assert result.raw_payload == b"err" + assert result.metadata["provider"] == "polygon" + + def test_error_result_defaults(self): + adapter = PolygonMarketAdapter(api_key="k") + result = adapter._error_result("AAPL", "timeout", 100.0) + assert result.http_status is None + assert result.raw_payload == b"" diff --git a/tests/test_metadata.py b/tests/test_metadata.py new file mode 100644 index 0000000..164068b --- /dev/null +++ b/tests/test_metadata.py @@ -0,0 +1,139 
@@ +"""Tests for metadata persistence helpers. + +Validates the helper functions in services.shared.metadata that don't +require a live database connection: type resolution, publisher extraction, +date parsing, market snapshot type inference, and retry/failure tracking +computations. + +Requirements: 3.3, 3.4, 9.2 +""" +from datetime import datetime, timezone + +from services.shared.metadata import ( + RETRY_BACKOFF_BASE, + RETRY_BACKOFF_MAX, + RETRY_MAX_COUNT, + _extract_publisher, + _infer_market_snapshot_type, + _parse_published_at, + _resolve_document_type, + compute_next_retry_at, +) + + +class TestResolveDocumentType: + def test_news_api(self): + assert _resolve_document_type("news_api") == "article" + + def test_filings_api(self): + assert _resolve_document_type("filings_api") == "filing" + + def test_web_scrape(self): + assert _resolve_document_type("web_scrape") == "press_release" + + def test_unknown_defaults_to_article(self): + assert _resolve_document_type("something_else") == "article" + + +class TestExtractPublisher: + def test_direct_publisher_field(self): + assert _extract_publisher({"publisher": "Reuters"}) == "Reuters" + + def test_source_dict_with_name(self): + assert _extract_publisher({"source": {"name": "Bloomberg"}}) == "Bloomberg" + + def test_source_string(self): + assert _extract_publisher({"source": "AP News"}) == "AP News" + + def test_empty_item(self): + assert _extract_publisher({}) == "" + + def test_publisher_takes_precedence(self): + item = {"publisher": "Reuters", "source": {"name": "Bloomberg"}} + assert _extract_publisher(item) == "Reuters" + + +class TestParsePublishedAt: + def test_iso_format_with_z(self): + result = _parse_published_at({"publishedAt": "2026-04-10T12:00:00Z"}) + assert result is not None + assert result.year == 2026 + assert result.month == 4 + + def test_iso_format_with_offset(self): + result = _parse_published_at({"published_at": "2026-04-10T12:00:00+00:00"}) + assert result is not None + + def 
test_none_when_missing(self): + assert _parse_published_at({}) is None + + def test_datetime_passthrough(self): + dt = datetime(2026, 1, 1, tzinfo=timezone.utc) + result = _parse_published_at({"publishedAt": dt}) + assert result is dt + + def test_invalid_string_returns_none(self): + assert _parse_published_at({"publishedAt": "not-a-date"}) is None + + +class TestInferMarketSnapshotType: + def test_bar_from_ohlc(self): + item = {"o": 100, "h": 105, "l": 99, "c": 103, "v": 1000} + assert _infer_market_snapshot_type(item) == "bar" + + def test_ticker_details_from_market_cap(self): + item = {"market_cap": 2_000_000_000, "name": "Apple"} + assert _infer_market_snapshot_type(item) == "ticker_details" + + def test_ticker_details_from_sic_code(self): + item = {"sic_code": "3674", "name": "NVIDIA"} + assert _infer_market_snapshot_type(item) == "ticker_details" + + def test_quote_from_bid_ask(self): + item = {"bid": 100.5, "ask": 101.0} + assert _infer_market_snapshot_type(item) == "quote" + + def test_generic_snapshot_fallback(self): + item = {"some_field": "value"} + assert _infer_market_snapshot_type(item) == "snapshot" + + +class TestComputeNextRetryAt: + def test_first_retry_uses_base_delay(self): + now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + result = compute_next_retry_at(0, now=now) + expected_seconds = RETRY_BACKOFF_BASE # 60s + delta = (result - now).total_seconds() + assert delta == expected_seconds + + def test_exponential_growth(self): + now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + d0 = (compute_next_retry_at(0, now=now) - now).total_seconds() + d1 = (compute_next_retry_at(1, now=now) - now).total_seconds() + d2 = (compute_next_retry_at(2, now=now) - now).total_seconds() + assert d1 == d0 * 2 + assert d2 == d1 * 2 + + def test_capped_at_max(self): + now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + result = compute_next_retry_at(20, now=now) + delta = (result - now).total_seconds() + assert delta == RETRY_BACKOFF_MAX + 
+ def test_defaults_to_utc_now(self): + before = datetime.now(timezone.utc) + result = compute_next_retry_at(0) + after = datetime.now(timezone.utc) + assert before <= result + assert (result - after).total_seconds() <= RETRY_BACKOFF_BASE + 1 + + +class TestRetryConstants: + def test_max_count_is_reasonable(self): + assert RETRY_MAX_COUNT == 10 + + def test_backoff_base_is_one_minute(self): + assert RETRY_BACKOFF_BASE == 60 + + def test_backoff_max_is_one_hour(self): + assert RETRY_BACKOFF_MAX == 3600 diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 0000000..a1699ca --- /dev/null +++ b/tests/test_metrics.py @@ -0,0 +1,151 @@ +"""Tests for Prometheus metrics definitions and instrumentation.""" +from prometheus_client import Counter, Gauge, Histogram, Info + +from services.shared.metrics import ( + ACTIVE_JOBS, + AGGREGATION_CONTRADICTION_SCORE, + AGGREGATION_DURATION, + AGGREGATION_SIGNALS_PROCESSED, + AGGREGATION_WINDOWS_COMPUTED, + ALERT_ACTIVE, + ALERT_CHECK_DURATION, + ALERTS_FIRED, + ALERTS_RESOLVED, + EXTRACTION_ATTEMPTS, + EXTRACTION_CONFIDENCE, + EXTRACTION_DURATION, + EXTRACTION_JOBS_TOTAL, + EXTRACTION_RETRIES, + EXTRACTION_TOKEN_ESTIMATE, + EXTRACTION_VALIDATION_ERRORS, + INGESTION_ADAPTER_DURATION, + INGESTION_ERRORS, + INGESTION_ITEMS_DEDUPED, + INGESTION_ITEMS_FETCHED, + INGESTION_ITEMS_NEW, + INGESTION_JOBS_TOTAL, + LAKE_FACTS_PUBLISHED, + LAKE_PUBLISH_BYTES, + LAKE_PUBLISH_DURATION, + LAKE_PUBLISH_ERRORS, + ORDERS_DUPLICATES_PREVENTED, + ORDERS_FILLED, + ORDERS_REJECTED, + ORDERS_SUBMITTED, + PARSE_DURATION, + PARSE_JOBS_TOTAL, + PARSE_LOW_QUALITY_TOTAL, + PARSE_QUALITY_SCORE, + POSITIONS_SYNCED, + RECOMMENDATION_CONFIDENCE, + RECOMMENDATION_GENERATED, + RECOMMENDATION_SUPPRESSED, + RISK_CHECK_FAILURES, + RISK_EVALUATIONS_TOTAL, + SERVICE_INFO, +) + + +def test_ingestion_metrics_are_correct_types(): + assert isinstance(INGESTION_JOBS_TOTAL, Counter) + assert isinstance(INGESTION_ITEMS_FETCHED, Counter) + assert 
isinstance(INGESTION_ITEMS_NEW, Counter) + assert isinstance(INGESTION_ITEMS_DEDUPED, Counter) + assert isinstance(INGESTION_ERRORS, Counter) + assert isinstance(INGESTION_ADAPTER_DURATION, Histogram) + + +def test_parse_metrics_are_correct_types(): + assert isinstance(PARSE_JOBS_TOTAL, Counter) + assert isinstance(PARSE_QUALITY_SCORE, Histogram) + assert isinstance(PARSE_LOW_QUALITY_TOTAL, Counter) + assert isinstance(PARSE_DURATION, Histogram) + + +def test_extraction_metrics_are_correct_types(): + assert isinstance(EXTRACTION_JOBS_TOTAL, Counter) + assert isinstance(EXTRACTION_ATTEMPTS, Counter) + assert isinstance(EXTRACTION_RETRIES, Counter) + assert isinstance(EXTRACTION_DURATION, Histogram) + assert isinstance(EXTRACTION_CONFIDENCE, Histogram) + assert isinstance(EXTRACTION_VALIDATION_ERRORS, Counter) + assert isinstance(EXTRACTION_TOKEN_ESTIMATE, Counter) + + +def test_aggregation_metrics_are_correct_types(): + assert isinstance(AGGREGATION_WINDOWS_COMPUTED, Counter) + assert isinstance(AGGREGATION_SIGNALS_PROCESSED, Counter) + assert isinstance(AGGREGATION_CONTRADICTION_SCORE, Histogram) + assert isinstance(AGGREGATION_DURATION, Histogram) + + +def test_recommendation_metrics_are_correct_types(): + assert isinstance(RECOMMENDATION_GENERATED, Counter) + assert isinstance(RECOMMENDATION_SUPPRESSED, Counter) + assert isinstance(RECOMMENDATION_CONFIDENCE, Histogram) + + +def test_lake_metrics_are_correct_types(): + assert isinstance(LAKE_FACTS_PUBLISHED, Counter) + assert isinstance(LAKE_PUBLISH_DURATION, Histogram) + assert isinstance(LAKE_PUBLISH_ERRORS, Counter) + assert isinstance(LAKE_PUBLISH_BYTES, Counter) + + +def test_trading_metrics_are_correct_types(): + assert isinstance(ORDERS_SUBMITTED, Counter) + assert isinstance(ORDERS_REJECTED, Counter) + assert isinstance(ORDERS_FILLED, Counter) + assert isinstance(ORDERS_DUPLICATES_PREVENTED, Counter) + assert isinstance(RISK_EVALUATIONS_TOTAL, Counter) + assert isinstance(RISK_CHECK_FAILURES, Counter) + 
assert isinstance(POSITIONS_SYNCED, Counter) + + +def test_active_jobs_gauge(): + assert isinstance(ACTIVE_JOBS, Gauge) + + +def test_alerting_metrics_are_correct_types(): + assert isinstance(ALERTS_FIRED, Counter) + assert isinstance(ALERTS_RESOLVED, Counter) + assert isinstance(ALERT_CHECK_DURATION, Histogram) + assert isinstance(ALERT_ACTIVE, Gauge) + + +def test_service_info(): + assert isinstance(SERVICE_INFO, Info) + + +def test_counter_labels_work(): + """Verify labeled counters can be incremented without error.""" + INGESTION_JOBS_TOTAL.labels(source_type="news_api", status="success").inc() + INGESTION_ITEMS_FETCHED.labels(source_type="market_api").inc(5) + EXTRACTION_JOBS_TOTAL.labels(status="success").inc() + AGGREGATION_WINDOWS_COMPUTED.labels(window="7d").inc() + RECOMMENDATION_GENERATED.labels(action="buy", mode="paper_eligible").inc() + LAKE_FACTS_PUBLISHED.labels(table_name="trade_signals").inc() + ORDERS_SUBMITTED.labels(side="buy", order_type="market", mode="paper").inc() + ORDERS_REJECTED.labels(reason_category="risk_engine").inc() + RISK_EVALUATIONS_TOTAL.labels(result="passed").inc() + + +def test_histogram_observe_works(): + """Verify histograms accept observations without error.""" + INGESTION_ADAPTER_DURATION.labels(source_type="news_api").observe(1.5) + PARSE_QUALITY_SCORE.observe(0.85) + PARSE_DURATION.observe(0.3) + EXTRACTION_DURATION.observe(5.2) + EXTRACTION_CONFIDENCE.observe(0.9) + AGGREGATION_CONTRADICTION_SCORE.observe(0.15) + AGGREGATION_DURATION.labels(window="7d").observe(0.8) + RECOMMENDATION_CONFIDENCE.observe(0.72) + LAKE_PUBLISH_DURATION.labels(table_name="market_bars").observe(0.05) + + +def test_metrics_endpoint_import(): + """Verify the prometheus_client generate_latest works.""" + from prometheus_client import generate_latest + output = generate_latest() + assert isinstance(output, bytes) + assert b"stonks_" in output diff --git a/tests/test_news_adapter.py b/tests/test_news_adapter.py new file mode 100644 index 
0000000..c91ac62 --- /dev/null +++ b/tests/test_news_adapter.py @@ -0,0 +1,143 @@ +"""Tests for the Polygon.io news adapter. + +Validates request building, response parsing, and error handling. +""" +from services.adapters.news_adapter import NewsDataAdapter, PolygonNewsAdapter + + +# --- Fake Polygon news responses --- + +NEWS_RESPONSE = { + "results": [ + { + "id": "abc123", + "publisher": {"name": "Reuters", "homepage_url": "https://reuters.com"}, + "title": "Apple Reports Record Revenue", + "article_url": "https://reuters.com/apple-record", + "tickers": ["AAPL"], + "published_utc": "2026-04-10T14:30:00Z", + "description": "Apple Inc. reported record quarterly revenue.", + "keywords": ["earnings", "apple", "revenue"], + }, + { + "id": "def456", + "publisher": {"name": "Bloomberg", "homepage_url": "https://bloomberg.com"}, + "title": "Apple Supply Chain Update", + "article_url": "https://bloomberg.com/apple-supply", + "tickers": ["AAPL", "TSM"], + "published_utc": "2026-04-10T12:00:00Z", + "description": "Supply chain adjustments for upcoming product cycle.", + "keywords": ["supply_chain", "apple"], + }, + ], + "status": "OK", + "request_id": "req-news-001", + "count": 2, + "next_url": "https://api.polygon.io/v2/reference/news?cursor=abc", +} + +EMPTY_NEWS_RESPONSE = { + "results": [], + "status": "OK", + "request_id": "req-news-002", + "count": 0, +} + + +class TestPolygonNewsSourceType: + def test_source_type(self): + adapter = PolygonNewsAdapter(api_key="k") + assert adapter.source_type() == "news_api" + + def test_inherits_news_data_adapter(self): + assert issubclass(PolygonNewsAdapter, NewsDataAdapter) + + def test_bucket_name(self): + adapter = PolygonNewsAdapter(api_key="k") + assert adapter.bucket_name() == "stonks-raw-news" + + +class TestPolygonNewsBuildRequest: + def setup_method(self): + self.adapter = PolygonNewsAdapter(api_key="test-key", base_url="https://api.polygon.io") + + def test_default_request(self): + url, params = 
self.adapter._build_request("AAPL", {}) + assert url == "https://api.polygon.io/v2/reference/news" + assert params["apiKey"] == "test-key" + assert params["ticker"] == "AAPL" + assert params["limit"] == "20" + + def test_custom_limit(self): + _, params = self.adapter._build_request("AAPL", {"limit": 50}) + assert params["limit"] == "50" + + def test_limit_capped_at_1000(self): + _, params = self.adapter._build_request("AAPL", {"limit": 5000}) + assert params["limit"] == "1000" + + def test_order_param(self): + _, params = self.adapter._build_request("AAPL", {"order": "asc"}) + assert params["order"] == "asc" + + def test_date_filters(self): + config = { + "published_utc_gte": "2026-04-01", + "published_utc_lte": "2026-04-10", + } + _, params = self.adapter._build_request("AAPL", config) + assert params["published_utc.gte"] == "2026-04-01" + assert params["published_utc.lte"] == "2026-04-10" + + def test_no_date_filters_when_absent(self): + _, params = self.adapter._build_request("AAPL", {}) + assert "published_utc.gte" not in params + assert "published_utc.lte" not in params + + def test_trailing_slash_stripped(self): + adapter = PolygonNewsAdapter(api_key="k", base_url="https://api.polygon.io/") + url, _ = adapter._build_request("AAPL", {}) + assert "//v2" not in url + + +class TestPolygonNewsExtractItems: + def setup_method(self): + self.adapter = PolygonNewsAdapter(api_key="k") + + def test_extract_articles(self): + items = self.adapter._extract_items(NEWS_RESPONSE) + assert len(items) == 2 + assert items[0]["title"] == "Apple Reports Record Revenue" + assert items[1]["tickers"] == ["AAPL", "TSM"] + + def test_extract_empty_results(self): + items = self.adapter._extract_items(EMPTY_NEWS_RESPONSE) + assert items == [] + + def test_extract_missing_results_key(self): + items = self.adapter._extract_items({"status": "OK"}) + assert items == [] + + def test_extract_non_list_results(self): + items = self.adapter._extract_items({"results": "unexpected"}) + assert items 
== [] + + +class TestPolygonNewsErrorResult: + def test_error_result_fields(self): + adapter = PolygonNewsAdapter(api_key="k") + result = adapter._error_result("AAPL", "rate limited", 150.0, http_status=429, raw=b"slow down") + assert not result.ok + assert result.error == "rate limited" + assert result.http_status == 429 + assert result.response_time_ms == 150.0 + assert result.raw_payload == b"slow down" + assert result.metadata["provider"] == "polygon" + assert result.source_type == "news_api" + + def test_error_result_defaults(self): + adapter = PolygonNewsAdapter(api_key="k") + result = adapter._error_result("MSFT", "timeout", 200.0) + assert result.http_status is None + assert result.raw_payload == b"" + assert result.ticker == "MSFT" diff --git a/tests/test_ollama_client.py b/tests/test_ollama_client.py new file mode 100644 index 0000000..bf3552b --- /dev/null +++ b/tests/test_ollama_client.py @@ -0,0 +1,388 @@ +"""Tests for the Ollama client wrapper.""" +import json +from unittest.mock import AsyncMock, patch + +import httpx +import pytest + +from services.extractor.client import ( + ExtractionResponse, + OllamaClient, + _compute_backoff, + _is_retryable, +) +from services.shared.config import OllamaConfig + + +def _valid_extraction_json() -> str: + """Minimal valid extraction result as JSON string.""" + return json.dumps({ + "summary": "Apple beat earnings expectations.", + "companies": [ + { + "ticker": "AAPL", + "company_name": "Apple Inc.", + "relevance": 0.95, + "sentiment": "positive", + "impact_score": 0.7, + "impact_horizon": "1d_30d", + "catalyst_type": "earnings", + "key_facts": ["Revenue up 12%"], + "risks": [], + "evidence_spans": ["Apple beat expectations"], + } + ], + "macro_themes": ["ai_capex"], + "novelty_score": 0.6, + "confidence": 0.85, + "extraction_warnings": [], + }) + + +def _ollama_response(content: str) -> httpx.Response: + """Build a fake Ollama /api/chat response.""" + body = {"message": {"role": "assistant", "content": content}} 
+ return httpx.Response(200, json=body) + + +def _make_config() -> OllamaConfig: + return OllamaConfig( + base_url="http://test:11434", + model="test-model", + timeout=10, + retry_base_delay=0.0, + retry_max_delay=0.0, + retry_backoff_multiplier=2.0, + ) + + +@pytest.mark.asyncio +async def test_extract_success(): + """Successful extraction on first attempt.""" + transport = httpx.MockTransport( + lambda req: _ollama_response(_valid_extraction_json()) + ) + http = httpx.AsyncClient(transport=transport) + client = OllamaClient(_make_config(), http_client=http) + + resp = await client.extract( + document_text="Apple reported record Q4 earnings.", + document_type="article", + document_id="doc-1", + ) + + assert resp.success + assert resp.result is not None + assert resp.result.companies[0].ticker == "AAPL" + assert len(resp.attempts) == 1 + assert resp.attempts[0].error is None + assert resp.model == "test-model" + assert resp.prompt_metadata["prompt_version"] + + await client.close() + + +@pytest.mark.asyncio +async def test_extract_retry_on_invalid_json(): + """Client retries when model returns invalid JSON, then succeeds.""" + call_count = 0 + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal call_count + call_count += 1 + if call_count == 1: + return _ollama_response("not valid json {{{") + return _ollama_response(_valid_extraction_json()) + + transport = httpx.MockTransport(handler) + http = httpx.AsyncClient(transport=transport) + client = OllamaClient(_make_config(), max_retries=2, http_client=http) + + resp = await client.extract(document_text="test", document_type="article") + + assert resp.success + assert len(resp.attempts) == 2 + assert resp.attempts[0].error is not None + assert resp.attempts[1].error is None + + await client.close() + + +@pytest.mark.asyncio +async def test_extract_all_retries_exhausted(): + """All retries fail — response indicates failure with all attempts recorded.""" + transport = httpx.MockTransport( + lambda req: 
_ollama_response("bad output") + ) + http = httpx.AsyncClient(transport=transport) + client = OllamaClient(_make_config(), max_retries=1, http_client=http) + + resp = await client.extract(document_text="test", document_type="article") + + assert not resp.success + assert resp.result is None + assert len(resp.attempts) == 2 # initial + 1 retry + + await client.close() + + +@pytest.mark.asyncio +async def test_extract_http_timeout(): + """HTTP timeout is captured as an error.""" + def handler(request: httpx.Request) -> httpx.Response: + raise httpx.ReadTimeout("timed out") + + transport = httpx.MockTransport(handler) + http = httpx.AsyncClient(transport=transport) + client = OllamaClient(_make_config(), max_retries=0, http_client=http) + + resp = await client.extract(document_text="test", document_type="article") + + assert not resp.success + assert resp.attempts[0].error == "timeout" + + await client.close() + + +@pytest.mark.asyncio +async def test_extract_http_500(): + """HTTP 500 is captured as an error.""" + transport = httpx.MockTransport( + lambda req: httpx.Response(500, text="Internal Server Error") + ) + http = httpx.AsyncClient(transport=transport) + client = OllamaClient(_make_config(), max_retries=0, http_client=http) + + resp = await client.extract(document_text="test", document_type="article") + + assert not resp.success + assert "500" in (resp.attempts[0].error or "") + + await client.close() + + +@pytest.mark.asyncio +async def test_extract_empty_model_response(): + """Empty content from model is treated as an error.""" + transport = httpx.MockTransport( + lambda req: _ollama_response("") + ) + http = httpx.AsyncClient(transport=transport) + client = OllamaClient(_make_config(), max_retries=0, http_client=http) + + resp = await client.extract(document_text="test", document_type="article") + + assert not resp.success + assert resp.attempts[0].error == "empty_model_response" + + await client.close() + + +@pytest.mark.asyncio +async def 
test_extract_schema_validation_failure(): + """Model returns valid JSON but missing required fields.""" + bad_extraction = json.dumps({"summary": "test"}) # missing companies, etc. + transport = httpx.MockTransport( + lambda req: _ollama_response(bad_extraction) + ) + http = httpx.AsyncClient(transport=transport) + client = OllamaClient(_make_config(), max_retries=0, http_client=http) + + resp = await client.extract(document_text="test", document_type="article") + + assert not resp.success + assert resp.attempts[0].validation is not None + assert not resp.attempts[0].validation.valid + + await client.close() + + +@pytest.mark.asyncio +async def test_extract_with_known_tickers(): + """Known tickers are passed through to the prompt builder.""" + transport = httpx.MockTransport( + lambda req: _ollama_response(_valid_extraction_json()) + ) + http = httpx.AsyncClient(transport=transport) + client = OllamaClient(_make_config(), http_client=http) + + resp = await client.extract( + document_text="test", + document_type="article", + known_tickers=["AAPL", "MSFT"], + ) + + assert resp.success + + await client.close() + + +@pytest.mark.asyncio +async def test_extract_sends_structured_format(): + """The request payload includes the JSON schema in the format field.""" + captured_payload: dict[str, object] = {} + + def handler(request: httpx.Request) -> httpx.Response: + captured_payload.update(json.loads(request.content)) + return _ollama_response(_valid_extraction_json()) + + transport = httpx.MockTransport(handler) + http = httpx.AsyncClient(transport=transport) + client = OllamaClient(_make_config(), http_client=http) + + await client.extract(document_text="test", document_type="article") + + assert "format" in captured_payload + assert isinstance(captured_payload["format"], dict) + assert captured_payload["stream"] is False + assert captured_payload["model"] == "test-model" + + await client.close() + + +@pytest.mark.asyncio +async def 
test_extract_non_retryable_http_400_stops_immediately(): + """HTTP 400 is non-retryable — client stops after first attempt.""" + call_count = 0 + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal call_count + call_count += 1 + return httpx.Response(400, text="Bad Request") + + transport = httpx.MockTransport(handler) + http = httpx.AsyncClient(transport=transport) + client = OllamaClient(_make_config(), max_retries=3, http_client=http) + + resp = await client.extract(document_text="test", document_type="article") + + assert not resp.success + assert len(resp.attempts) == 1 # no retries for 400 + assert resp.attempts[0].error == "http_400" + assert not resp.attempts[0].retryable + assert call_count == 1 + + await client.close() + + +@pytest.mark.asyncio +async def test_extract_retryable_http_500_retries(): + """HTTP 500 is retryable — client retries up to max_retries.""" + call_count = 0 + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal call_count + call_count += 1 + if call_count <= 2: + return httpx.Response(500, text="Internal Server Error") + return _ollama_response(_valid_extraction_json()) + + transport = httpx.MockTransport(handler) + http = httpx.AsyncClient(transport=transport) + client = OllamaClient(_make_config(), max_retries=3, http_client=http) + + resp = await client.extract(document_text="test", document_type="article") + + assert resp.success + assert len(resp.attempts) == 3 + assert resp.attempts[0].retryable is True + assert resp.attempts[1].retryable is True + assert call_count == 3 + + await client.close() + + +@pytest.mark.asyncio +async def test_extract_retryable_field_on_success(): + """Successful attempt has retryable=True (default).""" + transport = httpx.MockTransport( + lambda req: _ollama_response(_valid_extraction_json()) + ) + http = httpx.AsyncClient(transport=transport) + client = OllamaClient(_make_config(), http_client=http) + + resp = await client.extract(document_text="test", 
document_type="article") + + assert resp.success + assert resp.attempts[0].retryable is True + + await client.close() + + +@pytest.mark.asyncio +async def test_extract_backoff_is_called_between_retries(): + """asyncio.sleep is called with increasing delays between retries.""" + config = OllamaConfig( + base_url="http://test:11434", + model="test-model", + timeout=10, + retry_base_delay=1.0, + retry_max_delay=10.0, + retry_backoff_multiplier=2.0, + ) + transport = httpx.MockTransport( + lambda req: _ollama_response("bad output") + ) + http = httpx.AsyncClient(transport=transport) + client = OllamaClient(config, max_retries=2, http_client=http) + + with patch("services.extractor.client.asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + resp = await client.extract(document_text="test", document_type="article") + + assert not resp.success + assert len(resp.attempts) == 3 # initial + 2 retries + assert mock_sleep.call_count == 2 + # First backoff: 1.0 * 2^0 = 1.0 + assert mock_sleep.call_args_list[0].args[0] == pytest.approx(1.0) + # Second backoff: 1.0 * 2^1 = 2.0 + assert mock_sleep.call_args_list[1].args[0] == pytest.approx(2.0) + + await client.close() + + +@pytest.mark.asyncio +async def test_extract_uses_config_max_retries(): + """Client uses max_retries from config when not overridden.""" + config = OllamaConfig( + base_url="http://test:11434", + model="test-model", + timeout=10, + max_retries=1, + retry_base_delay=0.0, + ) + transport = httpx.MockTransport( + lambda req: _ollama_response("bad output") + ) + http = httpx.AsyncClient(transport=transport) + client = OllamaClient(config, http_client=http) + + resp = await client.extract(document_text="test", document_type="article") + + assert not resp.success + assert len(resp.attempts) == 2 # initial + 1 retry from config + + await client.close() + + +def test_compute_backoff(): + """Backoff computation respects multiplier and max delay.""" + assert _compute_backoff(0, 1.0, 10.0, 2.0) == 1.0 + assert 
_compute_backoff(1, 1.0, 10.0, 2.0) == 2.0 + assert _compute_backoff(2, 1.0, 10.0, 2.0) == 4.0 + assert _compute_backoff(3, 1.0, 10.0, 2.0) == 8.0 + assert _compute_backoff(4, 1.0, 10.0, 2.0) == 10.0 # capped at max + + +def test_is_retryable(): + """Error classification for retry decisions.""" + assert _is_retryable("timeout") is True + assert _is_retryable("http_500") is True + assert _is_retryable("connection_error: refused") is True + assert _is_retryable("empty_model_response") is True + assert _is_retryable("invalid_response_json") is True + assert _is_retryable("http_400") is False + assert _is_retryable("http_401") is False + assert _is_retryable("http_403") is False + assert _is_retryable("http_404") is False + assert _is_retryable("http_422") is False + assert _is_retryable(None) is False diff --git a/tests/test_operator_approval.py b/tests/test_operator_approval.py new file mode 100644 index 0000000..d405a34 --- /dev/null +++ b/tests/test_operator_approval.py @@ -0,0 +1,142 @@ +"""Tests for the operator approval workflow for live trading mode. 
+ +Validates: +- requires_approval logic for paper/live/disabled modes +- ApprovalRequest model behavior (pending, expired) +- compute_expiry calculation +- Integration with broker service process_order_job flow +""" +from datetime import datetime, timedelta, timezone + +from services.risk.approval import ( + ApprovalRequest, + ApprovalStatus, + compute_expiry, + requires_approval, +) +from services.risk.engine import ( + OperatorApproval, + PortfolioRiskConfig, + TradingMode, +) + + +# --------------------------------------------------------------------------- +# requires_approval tests +# --------------------------------------------------------------------------- + + +class TestRequiresApproval: + def test_paper_mode_auto_approved(self): + """Paper orders are auto-approved by default.""" + config = PortfolioRiskConfig(trading_mode=TradingMode.PAPER) + assert requires_approval(config) is False + + def test_paper_mode_approval_required_when_auto_approve_off(self): + """Paper orders need approval when auto_approve_paper is False.""" + config = PortfolioRiskConfig( + trading_mode=TradingMode.PAPER, + operator_approval=OperatorApproval(auto_approve_paper=False), + ) + assert requires_approval(config) is True + + def test_live_mode_requires_approval_by_default(self): + """Live orders require approval by default.""" + config = PortfolioRiskConfig(trading_mode=TradingMode.LIVE) + assert requires_approval(config) is True + + def test_live_mode_no_approval_when_disabled(self): + """Live orders skip approval when require_approval_for_live is False.""" + config = PortfolioRiskConfig( + trading_mode=TradingMode.LIVE, + operator_approval=OperatorApproval(require_approval_for_live=False), + ) + assert requires_approval(config) is False + + def test_disabled_mode_never_requires_approval(self): + """Disabled mode never requires approval (blocked upstream).""" + config = PortfolioRiskConfig(trading_mode=TradingMode.DISABLED) + assert requires_approval(config) is False + + def 
test_override_trading_mode_parameter(self): + """The trading_mode parameter overrides config.trading_mode.""" + config = PortfolioRiskConfig(trading_mode=TradingMode.PAPER) + # Override to live — should require approval + assert requires_approval(config, trading_mode=TradingMode.LIVE) is True + + +# --------------------------------------------------------------------------- +# compute_expiry tests +# --------------------------------------------------------------------------- + + +class TestComputeExpiry: + def test_default_30_minutes(self): + now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + config = PortfolioRiskConfig() + expiry = compute_expiry(config, now=now) + assert expiry == now + timedelta(minutes=30) + + def test_custom_timeout(self): + now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + config = PortfolioRiskConfig( + operator_approval=OperatorApproval(approval_timeout_minutes=60), + ) + expiry = compute_expiry(config, now=now) + assert expiry == now + timedelta(minutes=60) + + +# --------------------------------------------------------------------------- +# ApprovalRequest model tests +# --------------------------------------------------------------------------- + + +class TestApprovalRequest: + def test_defaults(self): + req = ApprovalRequest(ticker="AAPL") + assert req.ticker == "AAPL" + assert req.status == ApprovalStatus.PENDING + assert req.is_pending is True + assert req.approval_id # auto-generated UUID + + def test_is_expired_when_past_expiry(self): + past = datetime.now(timezone.utc) - timedelta(minutes=5) + req = ApprovalRequest(ticker="AAPL", expires_at=past) + assert req.is_expired is True + + def test_not_expired_when_future_expiry(self): + future = datetime.now(timezone.utc) + timedelta(minutes=30) + req = ApprovalRequest(ticker="AAPL", expires_at=future) + assert req.is_expired is False + + def test_approved_is_not_expired(self): + past = datetime.now(timezone.utc) - timedelta(minutes=5) + req = ApprovalRequest( + 
ticker="AAPL", + status=ApprovalStatus.APPROVED, + expires_at=past, + ) + assert req.is_expired is False + + def test_to_dict_roundtrip(self): + req = ApprovalRequest( + ticker="MSFT", + side="sell", + quantity=50.0, + estimated_value=15000.0, + recommendation_id="rec-123", + ) + d = req.to_dict() + assert d["ticker"] == "MSFT" + assert d["side"] == "sell" + assert d["quantity"] == 50.0 + assert d["status"] == "pending" + assert d["recommendation_id"] == "rec-123" + + def test_explicit_expired_status(self): + req = ApprovalRequest( + ticker="AAPL", + status=ApprovalStatus.EXPIRED, + ) + assert req.is_expired is True + assert req.is_pending is False diff --git a/tests/test_paper_trading.py b/tests/test_paper_trading.py new file mode 100644 index 0000000..96bfdec --- /dev/null +++ b/tests/test_paper_trading.py @@ -0,0 +1,339 @@ +"""Tests for the paper trading adapter - local order simulation and state sync. + +Validates position tracking, order fills, idempotency, cash management, +and the PaperAccount/PaperPosition data structures. 
+""" +import pytest + +from services.adapters.broker_adapter import ( + OrderRequest, + OrderResponse, + OrderSide, + OrderStatus, + OrderType, + PositionInfo, + TradingMode, +) +from services.adapters.paper_trading import ( + PaperAccount, + PaperPosition, + PaperTradingAdapter, +) + + +# --------------------------------------------------------------------------- +# PaperPosition tests +# --------------------------------------------------------------------------- + + +class TestPaperPosition: + def test_new_position_is_not_open(self): + pos = PaperPosition(ticker="AAPL") + assert not pos.is_open + assert pos.quantity == 0.0 + + def test_buy_fill_opens_position(self): + pos = PaperPosition(ticker="AAPL") + pnl = pos.apply_fill(OrderSide.BUY, 10, 150.0) + assert pos.is_open + assert pos.quantity == 10 + assert pos.avg_entry_price == 150.0 + assert pnl == 0.0 + + def test_sell_fill_realizes_pnl(self): + pos = PaperPosition(ticker="AAPL", quantity=10, avg_entry_price=150.0) + pnl = pos.apply_fill(OrderSide.SELL, 5, 160.0) + assert pos.quantity == 5 + assert pnl == 50.0 # 5 shares * $10 gain + assert pos.realized_pnl == 50.0 + + def test_sell_all_closes_position(self): + pos = PaperPosition(ticker="AAPL", quantity=10, avg_entry_price=150.0) + pos.apply_fill(OrderSide.SELL, 10, 140.0) + assert not pos.is_open + assert pos.quantity == 0 + assert pos.avg_entry_price == 0.0 + assert pos.realized_pnl == -100.0 # 10 * -$10 + + def test_buy_averages_up(self): + pos = PaperPosition(ticker="AAPL", quantity=10, avg_entry_price=100.0) + pos.apply_fill(OrderSide.BUY, 10, 200.0) + assert pos.quantity == 20 + assert pos.avg_entry_price == 150.0 # (1000 + 2000) / 20 + + def test_to_position_info(self): + pos = PaperPosition(ticker="AAPL", quantity=10, avg_entry_price=150.0) + info = pos.to_position_info(current_price=160.0) + assert isinstance(info, PositionInfo) + assert info.ticker == "AAPL" + assert info.quantity == 10 + assert info.unrealized_pnl == 100.0 # 10 * $10 + assert 
info.market_value == 1600.0 + + def test_to_position_info_no_current_price(self): + pos = PaperPosition(ticker="AAPL", quantity=10, avg_entry_price=150.0) + info = pos.to_position_info() + assert info.current_price == 150.0 + assert info.unrealized_pnl == 0.0 + + +# --------------------------------------------------------------------------- +# PaperAccount tests +# --------------------------------------------------------------------------- + + +class TestPaperAccount: + def test_default_account(self): + acct = PaperAccount() + assert acct.cash == 100_000.0 + assert acct.portfolio_value == 100_000.0 + assert acct.buying_power == 100_000.0 + + def test_custom_initial_cash(self): + acct = PaperAccount(initial_cash=50_000.0) + assert acct.cash == 50_000.0 + + def test_get_position_creates_new(self): + acct = PaperAccount() + pos = acct.get_position("AAPL") + assert pos.ticker == "AAPL" + assert pos.quantity == 0 + + def test_get_position_returns_existing(self): + acct = PaperAccount() + pos1 = acct.get_position("AAPL") + pos1.quantity = 10 + pos2 = acct.get_position("AAPL") + assert pos2.quantity == 10 + + def test_portfolio_value_includes_positions(self): + acct = PaperAccount(initial_cash=10_000.0) + acct.cash = 5_000.0 + pos = acct.get_position("AAPL") + pos.quantity = 10 + pos.avg_entry_price = 100.0 + # portfolio = cash + position value = 5000 + 1000 = 6000 + assert acct.portfolio_value == 6_000.0 + + def test_to_account_info(self): + acct = PaperAccount(account_id="test-acct") + info = acct.to_account_info() + assert info.account_id == "test-acct" + assert info.mode == TradingMode.PAPER + assert info.cash == 100_000.0 + + +# --------------------------------------------------------------------------- +# PaperTradingAdapter tests +# --------------------------------------------------------------------------- + + +class TestPaperTradingAdapterBasics: + def test_mode_is_paper(self): + adapter = PaperTradingAdapter() + assert adapter.mode == TradingMode.PAPER + + def 
test_source_type(self): + adapter = PaperTradingAdapter() + assert adapter.source_type() == "broker" + + def test_custom_initial_cash(self): + adapter = PaperTradingAdapter(initial_cash=50_000.0) + assert adapter.account.cash == 50_000.0 + + +@pytest.mark.asyncio +class TestPaperTradingSubmitOrder: + async def test_buy_market_order_fills(self): + adapter = PaperTradingAdapter(initial_cash=100_000.0) + order = OrderRequest( + ticker="AAPL", + side=OrderSide.BUY, + quantity=10, + order_type=OrderType.LIMIT, + limit_price=150.0, + ) + resp = await adapter.submit_order(order) + assert resp.status == OrderStatus.FILLED + assert resp.filled_quantity == 10 + assert resp.filled_avg_price == 150.0 + assert resp.ok + # Cash should decrease + assert adapter.account.cash < 100_000.0 + + async def test_sell_order_realizes_pnl(self): + adapter = PaperTradingAdapter(initial_cash=100_000.0) + # Buy first + buy = OrderRequest(ticker="AAPL", side=OrderSide.BUY, quantity=10, + order_type=OrderType.LIMIT, limit_price=150.0) + await adapter.submit_order(buy) + + # Sell at higher price + sell = OrderRequest(ticker="AAPL", side=OrderSide.SELL, quantity=10, + order_type=OrderType.LIMIT, limit_price=160.0) + resp = await adapter.submit_order(sell) + assert resp.status == OrderStatus.FILLED + assert resp.raw_response["realized_pnl"] == 100.0 # 10 * $10 + + async def test_insufficient_cash_rejects(self): + adapter = PaperTradingAdapter(initial_cash=1_000.0) + order = OrderRequest( + ticker="AAPL", + side=OrderSide.BUY, + quantity=100, + order_type=OrderType.LIMIT, + limit_price=150.0, + ) + resp = await adapter.submit_order(order) + assert resp.status == OrderStatus.REJECTED + assert "Insufficient cash" in resp.error + + async def test_insufficient_shares_rejects(self): + adapter = PaperTradingAdapter() + order = OrderRequest( + ticker="AAPL", + side=OrderSide.SELL, + quantity=10, + order_type=OrderType.LIMIT, + limit_price=150.0, + ) + resp = await adapter.submit_order(order) + assert 
resp.status == OrderStatus.REJECTED + assert "Insufficient shares" in resp.error + + async def test_idempotency_returns_same_response(self): + adapter = PaperTradingAdapter(initial_cash=100_000.0) + order = OrderRequest( + ticker="AAPL", + side=OrderSide.BUY, + quantity=10, + order_type=OrderType.LIMIT, + limit_price=150.0, + idempotency_key="test-key-1", + ) + resp1 = await adapter.submit_order(order) + resp2 = await adapter.submit_order(order) + assert resp1.broker_order_id == resp2.broker_order_id + assert resp1.status == resp2.status + # Cash should only be deducted once + assert adapter.account.cash == pytest.approx(100_000.0 - 1500.0) + + async def test_order_events_recorded(self): + adapter = PaperTradingAdapter(initial_cash=100_000.0) + order = OrderRequest( + ticker="AAPL", side=OrderSide.BUY, quantity=5, + order_type=OrderType.LIMIT, limit_price=100.0, + ) + await adapter.submit_order(order) + events = adapter.account.order_events + event_types = [e["event_type"] for e in events] + assert "submitted" in event_types + assert "accepted" in event_types + assert "fill" in event_types + + async def test_stop_order_fills_at_stop_price(self): + adapter = PaperTradingAdapter(initial_cash=100_000.0) + order = OrderRequest( + ticker="AAPL", side=OrderSide.BUY, quantity=10, + order_type=OrderType.STOP, stop_price=145.0, + ) + resp = await adapter.submit_order(order) + assert resp.filled_avg_price == 145.0 + + +@pytest.mark.asyncio +class TestPaperTradingCancelAndStatus: + async def test_cancel_filled_order_rejected(self): + adapter = PaperTradingAdapter(initial_cash=100_000.0) + order = OrderRequest( + ticker="AAPL", side=OrderSide.BUY, quantity=5, + order_type=OrderType.LIMIT, limit_price=100.0, + ) + resp = await adapter.submit_order(order) + cancel_resp = await adapter.cancel_order(resp.broker_order_id) + assert cancel_resp.status == OrderStatus.REJECTED + assert "filled" in cancel_resp.error.lower() + + async def test_cancel_unknown_order(self): + adapter = 
PaperTradingAdapter() + resp = await adapter.cancel_order("nonexistent-id") + assert resp.status == OrderStatus.REJECTED + assert "not found" in resp.error + + async def test_get_order_status(self): + adapter = PaperTradingAdapter(initial_cash=100_000.0) + order = OrderRequest( + ticker="AAPL", side=OrderSide.BUY, quantity=5, + order_type=OrderType.LIMIT, limit_price=100.0, + ) + resp = await adapter.submit_order(order) + status = await adapter.get_order_status(resp.broker_order_id) + assert status.status == OrderStatus.FILLED + + async def test_get_unknown_order_status(self): + adapter = PaperTradingAdapter() + resp = await adapter.get_order_status("nonexistent") + assert resp.status == OrderStatus.REJECTED + + +@pytest.mark.asyncio +class TestPaperTradingPositionsAndAccount: + async def test_get_positions_empty(self): + adapter = PaperTradingAdapter() + positions = await adapter.get_positions() + assert positions == [] + + async def test_get_positions_after_buy(self): + adapter = PaperTradingAdapter(initial_cash=100_000.0) + order = OrderRequest( + ticker="AAPL", side=OrderSide.BUY, quantity=10, + order_type=OrderType.LIMIT, limit_price=150.0, + ) + await adapter.submit_order(order) + positions = await adapter.get_positions() + assert len(positions) == 1 + assert positions[0].ticker == "AAPL" + assert positions[0].quantity == 10 + + async def test_get_account(self): + adapter = PaperTradingAdapter(initial_cash=50_000.0, account_id="test") + info = await adapter.get_account() + assert info.account_id == "test" + assert info.cash == 50_000.0 + assert info.mode == TradingMode.PAPER + + +@pytest.mark.asyncio +class TestPaperTradingFetch: + async def test_fetch_positions(self): + adapter = PaperTradingAdapter(initial_cash=100_000.0) + buy = OrderRequest( + ticker="AAPL", side=OrderSide.BUY, quantity=5, + order_type=OrderType.LIMIT, limit_price=100.0, + ) + await adapter.submit_order(buy) + result = await adapter.fetch("AAPL", {"endpoint": "positions"}) + assert 
result.ok + assert len(result.items) == 1 + assert result.metadata["provider"] == "paper" + + async def test_fetch_account(self): + adapter = PaperTradingAdapter() + result = await adapter.fetch("*", {"endpoint": "account"}) + assert result.ok + assert result.items[0]["mode"] == "paper" + + async def test_fetch_orders(self): + adapter = PaperTradingAdapter(initial_cash=100_000.0) + buy = OrderRequest( + ticker="AAPL", side=OrderSide.BUY, quantity=5, + order_type=OrderType.LIMIT, limit_price=100.0, + ) + await adapter.submit_order(buy) + result = await adapter.fetch("AAPL", {"endpoint": "orders"}) + assert len(result.items) == 1 + + async def test_fetch_empty_position(self): + adapter = PaperTradingAdapter() + result = await adapter.fetch("AAPL", {"endpoint": "positions"}) + assert len(result.items) == 0 diff --git a/tests/test_paper_trading_simulation.py b/tests/test_paper_trading_simulation.py new file mode 100644 index 0000000..9037805 --- /dev/null +++ b/tests/test_paper_trading_simulation.py @@ -0,0 +1,627 @@ +"""Paper trading simulation scenarios. + +End-to-end scenarios that exercise the full recommendation-to-execution +pipeline through the paper trading adapter, risk engine, and position +tracking. Each scenario simulates a realistic trading session using +real logic from all service modules — no mocked business logic. 
+ +Covers: +- Single-symbol buy-and-sell round trips with P&L verification +- Multi-symbol portfolio construction and diversification +- Risk engine rejection scenarios (position limits, daily loss, lockouts) +- Idempotent order submission under replay conditions +- Insufficient funds and insufficient shares edge cases +- Recommendation-driven order flow (bullish → buy, bearish → sell) +- Portfolio drawdown halting via daily loss limits +- News-shock lockout preventing trades during high-impact events + +Requirements: 7.1-7.4, 8.1-8.5 +""" +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + +import pytest + +from services.adapters.broker_adapter import ( + OrderRequest, + OrderSide, + OrderStatus, + OrderType, + TradingMode, +) +from services.adapters.paper_trading import PaperTradingAdapter +from services.aggregation.worker import ( + ImpactRow, + assemble_trend_with_evidence, + build_weighted_signals, +) +from services.recommendation.eligibility import evaluate_eligibility +from services.recommendation.worker import build_recommendation +from services.risk.engine import ( + AccountRiskState, + DailyLossLimits, + NewsShockLockout, + PortfolioRiskConfig, + PositionLimits, + ProposedOrder, + RiskCheckResult, + SectorExposureLimits, + SymbolCooldown, + evaluate_order, +) +from services.shared.schemas import ( + ActionType, + RecommendationMode, +) + +NOW = datetime(2026, 4, 11, 14, 0, 0, tzinfo=timezone.utc) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _bullish_impacts(ticker: str, count: int = 3) -> list[ImpactRow]: + """Generate bullish impact rows for aggregation.""" + return [ + ImpactRow( + document_id=f"doc-bull-{ticker}-{i}", + confidence=0.80 + i * 0.02, + novelty_score=0.6, + source_credibility=0.8, + sentiment="positive", + impact_score=0.70 + i * 0.03, + catalyst_type="earnings", + 
key_facts=[f"Strong Q{i+1} results for {ticker}"], + risks=[], + published_at=NOW - timedelta(hours=i + 1), + ) + for i in range(count) + ] + + +def _bearish_impacts(ticker: str, count: int = 3) -> list[ImpactRow]: + """Generate bearish impact rows for aggregation.""" + return [ + ImpactRow( + document_id=f"doc-bear-{ticker}-{i}", + confidence=0.78 + i * 0.02, + novelty_score=0.55, + source_credibility=0.75, + sentiment="negative", + impact_score=0.65 + i * 0.03, + catalyst_type="legal", + key_facts=[f"Regulatory action against {ticker}"], + risks=[f"Potential fine for {ticker}"], + published_at=NOW - timedelta(hours=i + 1), + ) + for i in range(count) + ] + + +def _build_trend_and_recommendation(impacts, ticker, window="7d"): + """Run aggregation + eligibility + recommendation for a set of impacts.""" + signals = build_weighted_signals(impacts, NOW, window) + assembled = assemble_trend_with_evidence( + ticker, window, signals, impacts, reference_time=NOW, + ) + summary = assembled.summary + eligibility = evaluate_eligibility(summary) + rec = build_recommendation(summary, eligibility, reference_time=NOW) + return summary, eligibility, rec + + +def _risk_state_from_adapter(adapter: PaperTradingAdapter) -> AccountRiskState: + """Build an AccountRiskState snapshot from the paper adapter's in-memory state.""" + acct = adapter.account + positions_by_symbol = { + t: p.quantity * p.avg_entry_price + for t, p in acct.positions.items() + if p.is_open + } + return AccountRiskState( + account_id=acct.account_id, + portfolio_value=acct.portfolio_value, + cash=acct.cash, + buying_power=acct.buying_power, + positions_by_symbol=positions_by_symbol, + open_position_count=sum(1 for p in acct.positions.values() if p.is_open), + ) + + +# --------------------------------------------------------------------------- +# Scenario 1: Single-symbol buy-sell round trip +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class 
TestSingleSymbolRoundTrip: + """Buy shares, sell at a profit, verify P&L and cash reconciliation.""" + + async def test_buy_hold_sell_profit(self): + adapter = PaperTradingAdapter(initial_cash=100_000.0) + + # Generate bullish recommendation + impacts = _bullish_impacts("AAPL") + summary, eligibility, rec = _build_trend_and_recommendation(impacts, "AAPL") + assert rec.action == ActionType.BUY + + # Execute buy + buy = OrderRequest( + ticker="AAPL", side=OrderSide.BUY, quantity=50, + order_type=OrderType.LIMIT, limit_price=180.0, + ) + buy_resp = await adapter.submit_order(buy) + assert buy_resp.status == OrderStatus.FILLED + assert adapter.account.cash == pytest.approx(100_000.0 - 50 * 180.0) + + # Verify position + positions = await adapter.get_positions() + assert len(positions) == 1 + assert positions[0].ticker == "AAPL" + assert positions[0].quantity == 50 + + # Sell at higher price + sell = OrderRequest( + ticker="AAPL", side=OrderSide.SELL, quantity=50, + order_type=OrderType.LIMIT, limit_price=195.0, + ) + sell_resp = await adapter.submit_order(sell) + assert sell_resp.status == OrderStatus.FILLED + assert sell_resp.raw_response["realized_pnl"] == pytest.approx(50 * 15.0) + + # Cash should be back to initial + profit + expected_cash = 100_000.0 + 50 * 15.0 + assert adapter.account.cash == pytest.approx(expected_cash) + + # Position should be closed + positions = await adapter.get_positions() + assert len(positions) == 0 + + async def test_buy_hold_sell_loss(self): + adapter = PaperTradingAdapter(initial_cash=100_000.0) + + buy = OrderRequest( + ticker="TSLA", side=OrderSide.BUY, quantity=20, + order_type=OrderType.LIMIT, limit_price=250.0, + ) + await adapter.submit_order(buy) + + sell = OrderRequest( + ticker="TSLA", side=OrderSide.SELL, quantity=20, + order_type=OrderType.LIMIT, limit_price=230.0, + ) + sell_resp = await adapter.submit_order(sell) + assert sell_resp.raw_response["realized_pnl"] == pytest.approx(-400.0) + + expected_cash = 100_000.0 - 400.0 
+ assert adapter.account.cash == pytest.approx(expected_cash) + + +# --------------------------------------------------------------------------- +# Scenario 2: Multi-symbol portfolio construction +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestMultiSymbolPortfolio: + """Build a diversified portfolio across multiple symbols.""" + + async def test_build_three_position_portfolio(self): + adapter = PaperTradingAdapter(initial_cash=100_000.0) + + orders = [ + ("AAPL", 20, 180.0), + ("MSFT", 15, 420.0), + ("GOOGL", 10, 175.0), + ] + total_cost = 0.0 + for ticker, qty, price in orders: + req = OrderRequest( + ticker=ticker, side=OrderSide.BUY, quantity=qty, + order_type=OrderType.LIMIT, limit_price=price, + ) + resp = await adapter.submit_order(req) + assert resp.status == OrderStatus.FILLED + total_cost += qty * price + + assert adapter.account.cash == pytest.approx(100_000.0 - total_cost) + + positions = await adapter.get_positions() + tickers = {p.ticker for p in positions} + assert tickers == {"AAPL", "MSFT", "GOOGL"} + + # Portfolio value = cash + position value at entry + assert adapter.account.portfolio_value == pytest.approx(100_000.0) + + async def test_partial_liquidation(self): + adapter = PaperTradingAdapter(initial_cash=50_000.0) + + # Buy two positions + await adapter.submit_order(OrderRequest( + ticker="AAPL", side=OrderSide.BUY, quantity=30, + order_type=OrderType.LIMIT, limit_price=150.0, + )) + await adapter.submit_order(OrderRequest( + ticker="MSFT", side=OrderSide.BUY, quantity=10, + order_type=OrderType.LIMIT, limit_price=400.0, + )) + + # Sell only AAPL + await adapter.submit_order(OrderRequest( + ticker="AAPL", side=OrderSide.SELL, quantity=30, + order_type=OrderType.LIMIT, limit_price=155.0, + )) + + positions = await adapter.get_positions() + assert len(positions) == 1 + assert positions[0].ticker == "MSFT" + + +# 
--------------------------------------------------------------------------- +# Scenario 3: Risk engine blocks unsafe orders +# --------------------------------------------------------------------------- + + +class TestRiskEngineBlocking: + """Verify risk engine prevents orders that violate configured limits.""" + + def test_position_size_limit_blocks_large_order(self): + config = PortfolioRiskConfig( + position_limits=PositionLimits(max_position_value=5_000.0), + ) + state = AccountRiskState( + portfolio_value=100_000.0, cash=100_000.0, + ) + order = ProposedOrder( + ticker="AAPL", sector="Technology", + estimated_value=10_000.0, quantity=50, + ) + result = evaluate_order(order, config, state) + assert not result.passed + assert any( + c.check_name == "max_position_value" and c.result == RiskCheckResult.FAIL + for c in result.checks + ) + + def test_sector_concentration_blocks_overweight(self): + config = PortfolioRiskConfig( + sector_exposure=SectorExposureLimits(max_sector_pct=0.20), + ) + state = AccountRiskState( + portfolio_value=100_000.0, + positions_by_sector={"Technology": 18_000.0}, + ) + order = ProposedOrder( + ticker="NVDA", sector="Technology", + estimated_value=5_000.0, quantity=20, + ) + result = evaluate_order(order, config, state) + assert not result.passed + + def test_daily_loss_halt_blocks_further_trading(self): + config = PortfolioRiskConfig( + daily_loss=DailyLossLimits( + max_daily_loss_pct=0.02, + max_daily_loss_value=2_000.0, + ), + ) + state = AccountRiskState( + portfolio_value=100_000.0, + daily_pnl=-2_500.0, + ) + order = ProposedOrder( + ticker="AAPL", sector="Technology", + estimated_value=1_000.0, quantity=5, + ) + result = evaluate_order(order, config, state) + assert not result.passed + loss_failures = [ + c for c in result.checks + if c.check_name.startswith("daily_loss") and c.result == RiskCheckResult.FAIL + ] + assert len(loss_failures) >= 1 + + def test_news_shock_lockout_blocks_trade(self): + lockout_expiry = NOW + 
timedelta(minutes=45) + config = PortfolioRiskConfig( + news_shock=NewsShockLockout(enabled=True, lockout_minutes=60), + ) + state = AccountRiskState( + portfolio_value=100_000.0, + locked_symbols={"AAPL": lockout_expiry}, + ) + order = ProposedOrder( + ticker="AAPL", sector="Technology", + estimated_value=1_000.0, quantity=5, + ) + result = evaluate_order(order, config, state, now=NOW) + assert not result.passed + assert any( + c.check_name == "news_shock_lockout" and c.result == RiskCheckResult.FAIL + for c in result.checks + ) + + def test_symbol_cooldown_blocks_rapid_retrade(self): + last_trade = NOW - timedelta(minutes=5) + config = PortfolioRiskConfig( + symbol_cooldown=SymbolCooldown(cooldown_minutes=15), + ) + state = AccountRiskState( + portfolio_value=100_000.0, + last_trade_times={"AAPL": last_trade}, + ) + order = ProposedOrder( + ticker="AAPL", sector="Technology", + estimated_value=1_000.0, quantity=5, + ) + result = evaluate_order(order, config, state, now=NOW) + assert not result.passed + + +# --------------------------------------------------------------------------- +# Scenario 4: Recommendation-driven order flow +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestRecommendationDrivenOrders: + """Simulate the full path: signals → recommendation → risk check → paper fill.""" + + async def test_bullish_recommendation_to_paper_buy(self): + adapter = PaperTradingAdapter(initial_cash=100_000.0) + impacts = _bullish_impacts("AAPL", count=4) + summary, eligibility, rec = _build_trend_and_recommendation(impacts, "AAPL") + + assert rec.action == ActionType.BUY + assert rec.confidence > 0 + + # Risk check the proposed order + risk_state = _risk_state_from_adapter(adapter) + proposed = ProposedOrder( + ticker="AAPL", sector="Technology", + estimated_value=rec.position_sizing.portfolio_pct * risk_state.portfolio_value, + quantity=10, + confidence=rec.confidence, + 
recommendation_id=rec.recommendation_id, + ) + risk_eval = evaluate_order(proposed, PortfolioRiskConfig(), risk_state) + assert risk_eval.passed + + # Execute the paper order + order = OrderRequest( + ticker="AAPL", side=OrderSide.BUY, quantity=10, + order_type=OrderType.LIMIT, limit_price=180.0, + ) + resp = await adapter.submit_order(order) + assert resp.status == OrderStatus.FILLED + + async def test_bearish_recommendation_to_paper_sell(self): + adapter = PaperTradingAdapter(initial_cash=100_000.0) + + # First buy a position to sell + await adapter.submit_order(OrderRequest( + ticker="TSLA", side=OrderSide.BUY, quantity=20, + order_type=OrderType.LIMIT, limit_price=250.0, + )) + + # Generate bearish recommendation + impacts = _bearish_impacts("TSLA", count=3) + summary, eligibility, rec = _build_trend_and_recommendation(impacts, "TSLA") + assert rec.action == ActionType.SELL + + # Execute the sell + sell = OrderRequest( + ticker="TSLA", side=OrderSide.SELL, quantity=20, + order_type=OrderType.LIMIT, limit_price=240.0, + ) + resp = await adapter.submit_order(sell) + assert resp.status == OrderStatus.FILLED + assert resp.raw_response["realized_pnl"] == pytest.approx(-200.0) + + async def test_low_confidence_recommendation_is_informational(self): + """Low-confidence signals should produce informational-only recommendations.""" + impacts = [ + ImpactRow( + document_id="doc-weak-1", + confidence=0.40, + novelty_score=0.3, + source_credibility=0.5, + sentiment="positive", + impact_score=0.3, + catalyst_type="other", + key_facts=["Minor update"], + risks=[], + published_at=NOW - timedelta(hours=1), + ), + ImpactRow( + document_id="doc-weak-2", + confidence=0.35, + novelty_score=0.2, + source_credibility=0.4, + sentiment="positive", + impact_score=0.25, + catalyst_type="other", + key_facts=["Routine filing"], + risks=[], + published_at=NOW - timedelta(hours=3), + ), + ] + _, _, rec = _build_trend_and_recommendation(impacts, "XYZ") + assert rec.mode == 
RecommendationMode.INFORMATIONAL + + +# --------------------------------------------------------------------------- +# Scenario 5: Idempotent order submission +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestIdempotentOrderSubmission: + """Verify duplicate orders with the same idempotency key are not double-executed.""" + + async def test_duplicate_buy_only_fills_once(self): + adapter = PaperTradingAdapter(initial_cash=100_000.0) + order = OrderRequest( + ticker="AAPL", side=OrderSide.BUY, quantity=10, + order_type=OrderType.LIMIT, limit_price=150.0, + idempotency_key="idem-buy-1", + ) + + resp1 = await adapter.submit_order(order) + resp2 = await adapter.submit_order(order) + + assert resp1.broker_order_id == resp2.broker_order_id + # Cash deducted only once + assert adapter.account.cash == pytest.approx(100_000.0 - 1_500.0) + # Only one position entry + pos = adapter.account.get_position("AAPL") + assert pos.quantity == 10 + + +# --------------------------------------------------------------------------- +# Scenario 6: Insufficient funds and shares +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestInsufficientResources: + """Verify the adapter rejects orders when resources are insufficient.""" + + async def test_buy_exceeding_cash_rejected(self): + adapter = PaperTradingAdapter(initial_cash=5_000.0) + order = OrderRequest( + ticker="AAPL", side=OrderSide.BUY, quantity=100, + order_type=OrderType.LIMIT, limit_price=180.0, + ) + resp = await adapter.submit_order(order) + assert resp.status == OrderStatus.REJECTED + assert resp.error is not None and "Insufficient cash" in resp.error + + async def test_sell_more_than_held_rejected(self): + adapter = PaperTradingAdapter(initial_cash=100_000.0) + await adapter.submit_order(OrderRequest( + ticker="AAPL", side=OrderSide.BUY, quantity=10, + order_type=OrderType.LIMIT, limit_price=150.0, + )) 
+ sell = OrderRequest( + ticker="AAPL", side=OrderSide.SELL, quantity=20, + order_type=OrderType.LIMIT, limit_price=155.0, + ) + resp = await adapter.submit_order(sell) + assert resp.status == OrderStatus.REJECTED + assert resp.error is not None and "Insufficient shares" in resp.error + + +# --------------------------------------------------------------------------- +# Scenario 7: Portfolio drawdown halts trading +# --------------------------------------------------------------------------- + + +class TestDrawdownHalt: + """Simulate a losing session that triggers the daily loss circuit breaker.""" + + def test_cumulative_losses_trigger_halt(self): + """After multiple losing trades, the risk engine should block new orders.""" + config = PortfolioRiskConfig( + daily_loss=DailyLossLimits( + max_daily_loss_pct=0.03, + max_daily_loss_value=3_000.0, + max_daily_trades=50, + ), + ) + + # Simulate state after several losing trades + state = AccountRiskState( + portfolio_value=97_000.0, + cash=47_000.0, + daily_pnl=-3_200.0, + daily_trade_count=8, + ) + + order = ProposedOrder( + ticker="NVDA", sector="Technology", + estimated_value=2_000.0, quantity=5, + ) + result = evaluate_order(order, config, state) + assert not result.passed + # Both pct and value limits should be breached + failed_checks = { + c.check_name for c in result.checks if c.result == RiskCheckResult.FAIL + } + assert "daily_loss_value" in failed_checks + + +# --------------------------------------------------------------------------- +# Scenario 8: Full session simulation +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestFullTradingSession: + """Simulate a realistic multi-trade session with mixed outcomes.""" + + async def test_morning_session_with_mixed_results(self): + adapter = PaperTradingAdapter(initial_cash=100_000.0) + initial_cash = 100_000.0 + + # Trade 1: Buy AAPL, sell at profit + await adapter.submit_order(OrderRequest( + 
ticker="AAPL", side=OrderSide.BUY, quantity=30, + order_type=OrderType.LIMIT, limit_price=180.0, + )) + resp1 = await adapter.submit_order(OrderRequest( + ticker="AAPL", side=OrderSide.SELL, quantity=30, + order_type=OrderType.LIMIT, limit_price=185.0, + )) + pnl_1 = resp1.raw_response["realized_pnl"] + assert pnl_1 == pytest.approx(150.0) + + # Trade 2: Buy TSLA, sell at loss + await adapter.submit_order(OrderRequest( + ticker="TSLA", side=OrderSide.BUY, quantity=10, + order_type=OrderType.LIMIT, limit_price=250.0, + )) + resp2 = await adapter.submit_order(OrderRequest( + ticker="TSLA", side=OrderSide.SELL, quantity=10, + order_type=OrderType.LIMIT, limit_price=242.0, + )) + pnl_2 = resp2.raw_response["realized_pnl"] + assert pnl_2 == pytest.approx(-80.0) + + # Trade 3: Buy MSFT, hold (don't sell) + await adapter.submit_order(OrderRequest( + ticker="MSFT", side=OrderSide.BUY, quantity=5, + order_type=OrderType.LIMIT, limit_price=420.0, + )) + + # Verify final state + positions = await adapter.get_positions() + assert len(positions) == 1 + assert positions[0].ticker == "MSFT" + + # Cash = initial + AAPL profit + TSLA loss - MSFT cost + expected_cash = initial_cash + 150.0 - 80.0 - (5 * 420.0) + assert adapter.account.cash == pytest.approx(expected_cash) + + # Audit trail should have events for all trades + event_count = len(adapter.account.order_events) + # 5 orders × 3 events each (submitted, accepted, fill) = 15 + # (rejected orders get fewer events, but all 5 here are fills) + assert event_count == 15 + + async def test_account_info_reflects_session(self): + adapter = PaperTradingAdapter(initial_cash=50_000.0, account_id="sim-session") + + await adapter.submit_order(OrderRequest( + ticker="AAPL", side=OrderSide.BUY, quantity=10, + order_type=OrderType.LIMIT, limit_price=180.0, + )) + + acct = await adapter.get_account() + assert acct.account_id == "sim-session" + assert acct.mode == TradingMode.PAPER + assert acct.cash == pytest.approx(50_000.0 - 1_800.0) + 
assert acct.portfolio_value == pytest.approx(50_000.0) diff --git a/tests/test_parser_worker.py b/tests/test_parser_worker.py new file mode 100644 index 0000000..0eb6c16 --- /dev/null +++ b/tests/test_parser_worker.py @@ -0,0 +1,80 @@ +"""Tests for parser worker helper functions. + +Validates build_parser_output_json produces the expected structure +from ParsedDocument and mention data. + +Requirements: 4.1, 4.2, 4.3, 9.1 +""" +from services.parser.html_parser import ParsedDocument, QualitySignals +from services.parser.worker import build_parser_output_json + + +class TestBuildParserOutputJson: + def test_includes_all_metadata_fields(self): + parsed = ParsedDocument( + body_text="Apple reported strong earnings.", + title="Apple Earnings", + author="Jane Reporter", + publisher="TechNews", + published_at="2026-04-10T14:00:00Z", + canonical_url="https://technews.example.com/apple", + language="en", + description="Apple Q2 results.", + document_type="article", + word_count=5, + outbound_links=["https://other.com/analysis"], + tags=["apple", "earnings"], + quality_score=0.75, + confidence="high", + low_quality_flag=False, + quality_warnings=[], + quality_signals=QualitySignals( + word_count_signal=0.8, + diversity_signal=0.9, + sentence_signal=1.0, + paragraph_signal=0.5, + body_found_signal=1.0, + metadata_signal=1.0, + ), + ) + mentions = [ + {"company_id": "1", "ticker": "AAPL", "mention_type": "ticker", "confidence": 0.9, "match_count": 2}, + ] + result = build_parser_output_json(parsed, mentions) + + assert result["title"] == "Apple Earnings" + assert result["author"] == "Jane Reporter" + assert result["publisher"] == "TechNews" + assert result["published_at"] == "2026-04-10T14:00:00Z" + assert result["canonical_url"] == "https://technews.example.com/apple" + assert result["language"] == "en" + assert result["description"] == "Apple Q2 results." 
+ assert result["document_type"] == "article" + assert result["word_count"] == 5 + assert result["outbound_links"] == ["https://other.com/analysis"] + assert result["tags"] == ["apple", "earnings"] + assert result["quality_score"] == 0.75 + assert result["confidence"] == "high" + assert result["low_quality_flag"] is False + assert result["quality_warnings"] == [] + assert result["mentioned_companies"] == mentions + + def test_quality_signals_serialized(self): + parsed = ParsedDocument( + quality_signals=QualitySignals( + word_count_signal=0.3, + diversity_signal=0.5, + ), + ) + result = build_parser_output_json(parsed, []) + signals = result["quality_signals"] + assert signals["word_count"] == 0.3 + assert signals["diversity"] == 0.5 + + def test_empty_parsed_document(self): + parsed = ParsedDocument() + result = build_parser_output_json(parsed, []) + assert result["title"] == "" + assert "body_text" not in result # body text stored separately in MinIO + assert result["mentioned_companies"] == [] + assert result["confidence"] == "low" diff --git a/tests/test_query_api.py b/tests/test_query_api.py new file mode 100644 index 0000000..6f403ce --- /dev/null +++ b/tests/test_query_api.py @@ -0,0 +1,105 @@ +"""Tests for the Query API app structure and helper functions.""" +import json +from datetime import datetime, timezone + +import pytest + +from services.api.app import _parse_jsonb, _row_to_dict, app + + +# --- _parse_jsonb --- + +def test_parse_jsonb_dict(): + assert _parse_jsonb({"a": 1}) == {"a": 1} + + +def test_parse_jsonb_list(): + assert _parse_jsonb([1, 2]) == [1, 2] + + +def test_parse_jsonb_string(): + assert _parse_jsonb('{"x": 1}') == {"x": 1} + + +def test_parse_jsonb_list_string(): + assert _parse_jsonb('["a", "b"]') == ["a", "b"] + + +def test_parse_jsonb_none(): + assert _parse_jsonb(None) is None + + +def test_parse_jsonb_invalid_string(): + assert _parse_jsonb("not json") == "not json" + + +# --- _row_to_dict --- + +class FakeRecord(dict): + 
"""Mimics asyncpg.Record enough for _row_to_dict.""" + def items(self): + return super().items() + + +def test_row_to_dict_converts_datetime(): + dt = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + row = FakeRecord({"created_at": dt, "name": "test"}) + result = _row_to_dict(row) + assert result["created_at"] == dt.isoformat() + assert result["name"] == "test" + + +def test_row_to_dict_passes_primitives(): + row = FakeRecord({"count": 42, "active": True, "label": "ok", "val": None}) + result = _row_to_dict(row) + assert result == {"count": 42, "active": True, "label": "ok", "val": None} + + +# --- App structure --- + +def test_app_has_expected_routes(): + paths = [route.path for route in app.routes] + assert "/health" in paths + assert "/api/companies" in paths + assert "/api/companies/{company_id}" in paths + assert "/api/documents" in paths + assert "/api/documents/{document_id}" in paths + assert "/api/trends" in paths + assert "/api/trends/{trend_id}" in paths + assert "/api/trends/{trend_id}/evidence" in paths + assert "/api/recommendations" in paths + assert "/api/recommendations/{recommendation_id}" in paths + assert "/api/recommendations/{recommendation_id}/evidence" in paths + assert "/api/orders" in paths + assert "/api/orders/{order_id}" in paths + assert "/api/positions" in paths + assert "/api/audit/{entity_type}/{entity_id}" in paths + + +def test_app_has_admin_routes(): + paths = [route.path for route in app.routes] + # Source health + assert "/api/admin/sources/health" in paths + assert "/api/admin/sources/{source_id}/runs" in paths + assert "/api/admin/sources/{source_id}/toggle" in paths + assert "/api/admin/sources/{source_id}/credibility" in paths + # Symbol configs + assert "/api/admin/companies/{company_id}/toggle" in paths + assert "/api/admin/companies/{company_id}/sector" in paths + assert "/api/admin/companies/coverage" in paths + # Trading mode + assert "/api/admin/trading/config" in paths + assert "/api/admin/trading/mode" in 
paths + assert "/api/admin/trading/approvals" in paths + assert "/api/admin/trading/approvals/{approval_id}" in paths + assert "/api/admin/trading/lockouts" in paths + + +def test_app_has_ops_dashboard_routes(): + paths = [route.path for route in app.routes] + assert "/api/ops/ingestion/throughput" in paths + assert "/api/ops/ingestion/summary" in paths + assert "/api/ops/model/failures" in paths + assert "/api/ops/model/performance" in paths + assert "/api/ops/pipeline/health" in paths + assert "/api/ops/sources/coverage-gaps" in paths diff --git a/tests/test_recommendation_eligibility.py b/tests/test_recommendation_eligibility.py new file mode 100644 index 0000000..9b01553 --- /dev/null +++ b/tests/test_recommendation_eligibility.py @@ -0,0 +1,283 @@ +"""Tests for deterministic recommendation eligibility logic.""" +from typing import Any + +from services.recommendation.eligibility import ( + DEFAULT_ELIGIBILITY_CONFIG, + EligibilityConfig, + RejectionReason, + evaluate_eligibility, +) +from services.shared.schemas import ( + ActionType, + RecommendationMode, + TrendDirection, + TrendSummary, + TrendWindow, +) + + +def _make_summary(**overrides: Any) -> TrendSummary: + """Build a TrendSummary with sensible defaults for testing.""" + defaults = dict( + entity_type="company", + entity_id="AAPL", + window=TrendWindow.SEVEN_DAY, + trend_direction=TrendDirection.BULLISH, + trend_strength=0.5, + confidence=0.6, + top_supporting_evidence=["doc1", "doc2", "doc3"], + top_opposing_evidence=[], + dominant_catalysts=["earnings"], + material_risks=["regulatory scrutiny"], + contradiction_score=0.1, + ) + defaults.update(overrides) + return TrendSummary(**defaults) + + +# --------------------------------------------------------------------------- +# Gate checks +# --------------------------------------------------------------------------- + + +def test_eligible_strong_bullish(): + """A strong bullish trend with good confidence passes all gates.""" + summary = _make_summary( + 
trend_strength=0.5, confidence=0.6, contradiction_score=0.1, + ) + result = evaluate_eligibility(summary) + assert result.eligible is True + assert result.rejection_reasons == [] + assert result.action == ActionType.BUY + + +def test_rejected_low_confidence(): + """Below min_confidence → rejected.""" + summary = _make_summary(confidence=0.2) + result = evaluate_eligibility(summary) + assert result.eligible is False + assert RejectionReason.LOW_CONFIDENCE in result.rejection_reasons + + +def test_rejected_low_strength(): + """Below min_trend_strength → rejected.""" + summary = _make_summary(trend_strength=0.05) + result = evaluate_eligibility(summary) + assert result.eligible is False + assert RejectionReason.LOW_TREND_STRENGTH in result.rejection_reasons + + +def test_rejected_high_contradiction(): + """Above max_contradiction_score → rejected.""" + summary = _make_summary(contradiction_score=0.7) + result = evaluate_eligibility(summary) + assert result.eligible is False + assert RejectionReason.HIGH_CONTRADICTION in result.rejection_reasons + + +def test_rejected_insufficient_evidence(): + """Too few evidence documents → rejected.""" + summary = _make_summary( + top_supporting_evidence=["doc1"], + top_opposing_evidence=[], + ) + result = evaluate_eligibility(summary) + assert result.eligible is False + assert RejectionReason.INSUFFICIENT_EVIDENCE in result.rejection_reasons + + +def test_rejected_neutral_direction(): + """Neutral trend direction → rejected.""" + summary = _make_summary(trend_direction=TrendDirection.NEUTRAL) + result = evaluate_eligibility(summary) + assert result.eligible is False + assert RejectionReason.NEUTRAL_DIRECTION in result.rejection_reasons + + +def test_rejected_forces_informational_mode(): + """Any rejection forces mode to informational (Req 7.4).""" + summary = _make_summary(confidence=0.2) + result = evaluate_eligibility(summary) + assert result.eligible is False + assert result.mode == RecommendationMode.INFORMATIONAL + + +# 
--------------------------------------------------------------------------- +# Action mapping +# --------------------------------------------------------------------------- + + +def test_action_buy_strong_bullish(): + summary = _make_summary( + trend_direction=TrendDirection.BULLISH, trend_strength=0.4, + ) + result = evaluate_eligibility(summary) + assert result.action == ActionType.BUY + + +def test_action_sell_strong_bearish(): + summary = _make_summary( + trend_direction=TrendDirection.BEARISH, trend_strength=0.4, + ) + result = evaluate_eligibility(summary) + assert result.action == ActionType.SELL + + +def test_action_hold_weak_bullish_decent_confidence(): + """Weak bullish with decent confidence → HOLD.""" + summary = _make_summary( + trend_direction=TrendDirection.BULLISH, + trend_strength=0.15, + confidence=0.55, + ) + result = evaluate_eligibility(summary) + assert result.action == ActionType.HOLD + + +def test_action_watch_weak_bullish_low_confidence(): + """Weak bullish with low confidence → WATCH.""" + summary = _make_summary( + trend_direction=TrendDirection.BULLISH, + trend_strength=0.15, + confidence=0.40, + ) + result = evaluate_eligibility(summary) + assert result.action == ActionType.WATCH + + +def test_action_watch_mixed(): + summary = _make_summary(trend_direction=TrendDirection.MIXED) + result = evaluate_eligibility(summary) + assert result.action == ActionType.WATCH + + +# --------------------------------------------------------------------------- +# Mode escalation +# --------------------------------------------------------------------------- + + +def test_mode_informational_for_hold(): + """HOLD actions are always informational.""" + summary = _make_summary( + trend_direction=TrendDirection.BULLISH, + trend_strength=0.15, + confidence=0.55, + ) + result = evaluate_eligibility(summary) + assert result.action == ActionType.HOLD + assert result.mode == RecommendationMode.INFORMATIONAL + + +def test_mode_paper_eligible(): + """BUY with 
confidence >= paper threshold → paper_eligible.""" + summary = _make_summary( + trend_strength=0.4, confidence=0.55, contradiction_score=0.1, + ) + result = evaluate_eligibility(summary) + assert result.action == ActionType.BUY + assert result.mode == RecommendationMode.PAPER_ELIGIBLE + + +def test_mode_live_eligible(): + """BUY with high confidence, low contradiction, enough evidence → live_eligible.""" + summary = _make_summary( + trend_strength=0.5, + confidence=0.75, + contradiction_score=0.1, + top_supporting_evidence=["d1", "d2", "d3", "d4"], + top_opposing_evidence=["d5"], + ) + result = evaluate_eligibility(summary) + assert result.action == ActionType.BUY + assert result.mode == RecommendationMode.LIVE_ELIGIBLE + + +def test_mode_not_live_high_contradiction(): + """High contradiction blocks live even with high confidence.""" + summary = _make_summary( + trend_strength=0.5, + confidence=0.75, + contradiction_score=0.4, + top_supporting_evidence=["d1", "d2", "d3", "d4", "d5"], + top_opposing_evidence=[], + ) + result = evaluate_eligibility(summary) + assert result.mode != RecommendationMode.LIVE_ELIGIBLE + + +def test_mode_informational_low_confidence_buy(): + """BUY with confidence below paper threshold → informational.""" + summary = _make_summary( + trend_strength=0.4, confidence=0.40, + ) + result = evaluate_eligibility(summary) + assert result.action == ActionType.BUY + assert result.mode == RecommendationMode.INFORMATIONAL + + +# --------------------------------------------------------------------------- +# Position sizing +# --------------------------------------------------------------------------- + + +def test_position_sizing_scales_with_confidence(): + """Higher confidence → larger portfolio allocation.""" + low = _make_summary(confidence=0.40, trend_strength=0.4) + high = _make_summary(confidence=0.80, trend_strength=0.4) + r_low = evaluate_eligibility(low) + r_high = evaluate_eligibility(high) + assert r_high.position_sizing.portfolio_pct > 
r_low.position_sizing.portfolio_pct + + +def test_position_sizing_penalised_by_contradiction(): + """Higher contradiction → smaller portfolio allocation.""" + clean = _make_summary(contradiction_score=0.05, trend_strength=0.4) + messy = _make_summary(contradiction_score=0.50, trend_strength=0.4) + r_clean = evaluate_eligibility(clean) + r_messy = evaluate_eligibility(messy) + assert r_clean.position_sizing.portfolio_pct > r_messy.position_sizing.portfolio_pct + + +def test_position_sizing_within_bounds(): + """Sizing should always stay within configured bounds.""" + cfg = DEFAULT_ELIGIBILITY_CONFIG + for conf in [0.35, 0.5, 0.7, 0.9]: + for contra in [0.0, 0.3, 0.55]: + summary = _make_summary(confidence=conf, contradiction_score=contra, trend_strength=0.4) + result = evaluate_eligibility(summary) + assert result.position_sizing.portfolio_pct >= cfg.base_portfolio_pct * 0.5 + assert result.position_sizing.portfolio_pct <= cfg.max_portfolio_pct + assert result.position_sizing.max_loss_pct >= cfg.base_max_loss_pct * 0.5 + assert result.position_sizing.max_loss_pct <= cfg.max_max_loss_pct + + +# --------------------------------------------------------------------------- +# Time horizon and invalidation +# --------------------------------------------------------------------------- + + +def test_time_horizon_mapped(): + summary = _make_summary(window=TrendWindow.SEVEN_DAY) + result = evaluate_eligibility(summary) + assert result.time_horizon == "swing_1d_10d" + + +def test_invalidation_conditions_present(): + summary = _make_summary() + result = evaluate_eligibility(summary) + assert len(result.invalidation_conditions) > 0 + assert any("AAPL" in c for c in result.invalidation_conditions) + + +# --------------------------------------------------------------------------- +# Custom config +# --------------------------------------------------------------------------- + + +def test_custom_config_stricter_gates(): + """A stricter config rejects what the default would 
accept.""" + strict = EligibilityConfig(min_confidence=0.80) + summary = _make_summary(confidence=0.60) + result = evaluate_eligibility(summary, config=strict) + assert result.eligible is False + assert RejectionReason.LOW_CONFIDENCE in result.rejection_reasons diff --git a/tests/test_recommendation_worker.py b/tests/test_recommendation_worker.py new file mode 100644 index 0000000..b396b0e --- /dev/null +++ b/tests/test_recommendation_worker.py @@ -0,0 +1,283 @@ +"""Tests for recommendation worker — generating recommendations from trend data. + +Tests the pure logic functions (no DB required). Async DB functions +are covered by integration tests. +""" +from datetime import datetime, timezone + +from services.recommendation.eligibility import evaluate_eligibility +from services.recommendation.worker import ( + _extract_risk_classification, + build_recommendation, + build_thesis, + classify_risk, +) +from services.shared.schemas import ( + ActionType, + RecommendationMode, + TrendDirection, + TrendSummary, + TrendWindow, +) + +NOW = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + + +def _make_summary( + ticker: str = "AAPL", + direction: TrendDirection = TrendDirection.BULLISH, + strength: float = 0.5, + confidence: float = 0.65, + contradiction: float = 0.1, + supporting: list[str] | None = None, + opposing: list[str] | None = None, + catalysts: list[str] | None = None, + risks: list[str] | None = None, + window: TrendWindow = TrendWindow.SEVEN_DAY, +) -> TrendSummary: + return TrendSummary( + entity_type="company", + entity_id=ticker, + window=window, + trend_direction=direction, + trend_strength=strength, + confidence=confidence, + top_supporting_evidence=supporting or ["doc1", "doc2", "doc3"], + top_opposing_evidence=opposing or [], + dominant_catalysts=catalysts or ["earnings"], + material_risks=risks or ["regulatory scrutiny"], + contradiction_score=contradiction, + generated_at=NOW, + ) + + +# 
--------------------------------------------------------------------------- +# build_thesis +# --------------------------------------------------------------------------- + + +def test_thesis_contains_ticker_and_direction(): + summary = _make_summary() + result = evaluate_eligibility(summary) + thesis = build_thesis(summary, result) + assert "AAPL" in thesis + assert "bullish" in thesis + + +def test_thesis_includes_catalysts(): + summary = _make_summary(catalysts=["product", "m_and_a"]) + result = evaluate_eligibility(summary) + thesis = build_thesis(summary, result) + assert "product" in thesis + + +def test_thesis_includes_contradiction_note(): + summary = _make_summary(contradiction=0.3) + result = evaluate_eligibility(summary) + thesis = build_thesis(summary, result) + assert "disagreement" in thesis + + +def test_thesis_includes_risks(): + summary = _make_summary(risks=["supply chain disruption"]) + result = evaluate_eligibility(summary) + thesis = build_thesis(summary, result) + assert "supply chain disruption" in thesis + + +def test_thesis_includes_evidence_counts(): + summary = _make_summary( + supporting=["d1", "d2"], + opposing=["d3"], + ) + result = evaluate_eligibility(summary) + thesis = build_thesis(summary, result) + assert "2 supporting" in thesis + assert "1 opposing" in thesis + + +def test_thesis_includes_action(): + summary = _make_summary() + result = evaluate_eligibility(summary) + thesis = build_thesis(summary, result) + assert "BUY" in thesis + + +# --------------------------------------------------------------------------- +# classify_risk +# --------------------------------------------------------------------------- + + +def test_risk_low_for_strong_signal(): + summary = _make_summary( + confidence=0.8, + contradiction=0.05, + supporting=["d1", "d2", "d3", "d4", "d5"], + ) + result = evaluate_eligibility(summary) + assert classify_risk(summary, result) == "low" + + +def test_risk_high_for_weak_signal(): + summary = _make_summary( + 
confidence=0.36, + contradiction=0.55, + supporting=["d1"], + opposing=[], + ) + result = evaluate_eligibility(summary) + risk = classify_risk(summary, result) + assert risk in ("high", "very_high") + + +def test_risk_moderate_for_mixed(): + summary = _make_summary( + confidence=0.5, + contradiction=0.2, + supporting=["d1", "d2"], + opposing=["d3"], + ) + result = evaluate_eligibility(summary) + assert classify_risk(summary, result) == "moderate" + + +# --------------------------------------------------------------------------- +# build_recommendation +# --------------------------------------------------------------------------- + + +def test_build_recommendation_basic(): + summary = _make_summary() + result = evaluate_eligibility(summary) + rec = build_recommendation(summary, result, reference_time=NOW) + + assert rec.ticker == "AAPL" + assert rec.action == ActionType.BUY + assert rec.confidence == summary.confidence + assert rec.time_horizon == "swing_1d_10d" + assert rec.generated_at == NOW + assert len(rec.evidence_refs) == 3 # 3 supporting + 0 opposing + assert rec.model_metadata.provider == "deterministic" + + +def test_build_recommendation_includes_risk_in_thesis(): + summary = _make_summary() + result = evaluate_eligibility(summary) + rec = build_recommendation(summary, result) + assert rec.thesis.startswith("[risk:") + + +def test_build_recommendation_with_llm_thesis(): + """When llm_thesis is provided, it replaces the deterministic body.""" + summary = _make_summary() + result = evaluate_eligibility(summary) + llm_text = "Apple exhibits a bullish posture driven by strong earnings." 
+ rec = build_recommendation(summary, result, llm_thesis=llm_text) + assert llm_text in rec.thesis + assert rec.thesis.startswith("[risk:") + assert rec.model_metadata.provider == "ollama" + assert rec.model_metadata.model_name == "thesis-rewrite" + + +def test_build_recommendation_without_llm_thesis_uses_deterministic(): + """When llm_thesis is None, the deterministic thesis is used.""" + summary = _make_summary() + result = evaluate_eligibility(summary) + rec = build_recommendation(summary, result) + assert rec.model_metadata.provider == "deterministic" + assert rec.model_metadata.model_name == "eligibility-v1" + + +def test_build_recommendation_combines_evidence(): + summary = _make_summary( + supporting=["s1", "s2"], + opposing=["o1"], + ) + result = evaluate_eligibility(summary) + rec = build_recommendation(summary, result) + assert rec.evidence_refs == ["s1", "s2", "o1"] + + +def test_build_recommendation_position_sizing(): + summary = _make_summary(confidence=0.7) + result = evaluate_eligibility(summary) + rec = build_recommendation(summary, result) + assert rec.position_sizing.portfolio_pct == result.position_sizing.portfolio_pct + assert rec.position_sizing.max_loss_pct == result.position_sizing.max_loss_pct + + +def test_build_recommendation_invalidation_conditions(): + summary = _make_summary() + result = evaluate_eligibility(summary) + rec = build_recommendation(summary, result) + assert len(rec.invalidation_conditions) > 0 + + +def test_build_recommendation_ineligible_is_informational(): + """When eligibility fails, mode should be informational (Req 7.4).""" + summary = _make_summary(confidence=0.2) + result = evaluate_eligibility(summary) + rec = build_recommendation(summary, result) + assert rec.mode == RecommendationMode.INFORMATIONAL + + +def test_build_recommendation_sell_action(): + summary = _make_summary(direction=TrendDirection.BEARISH, strength=0.5) + result = evaluate_eligibility(summary) + rec = build_recommendation(summary, result) + 
assert rec.action == ActionType.SELL + assert "SELL" in rec.thesis + + +# --------------------------------------------------------------------------- +# _extract_risk_classification +# --------------------------------------------------------------------------- + + +def test_extract_risk_classification_from_thesis(): + assert _extract_risk_classification("[risk:low] Some thesis text") == "low" + assert _extract_risk_classification("[risk:very_high] Bad signal") == "very_high" + + +def test_extract_risk_classification_missing_prefix(): + assert _extract_risk_classification("No risk prefix here") == "moderate" + + +def test_extract_risk_classification_empty(): + assert _extract_risk_classification("") == "moderate" + + +# --------------------------------------------------------------------------- +# build_recommendation stores full model metadata +# --------------------------------------------------------------------------- + + +def test_build_recommendation_model_metadata_deterministic(): + summary = _make_summary() + result = evaluate_eligibility(summary) + rec = build_recommendation(summary, result, reference_time=NOW) + assert rec.model_metadata.provider == "deterministic" + assert rec.model_metadata.model_name == "eligibility-v1" + assert rec.model_metadata.schema_version == "1.0.0" + + +def test_build_recommendation_model_metadata_llm(): + summary = _make_summary() + result = evaluate_eligibility(summary) + rec = build_recommendation( + summary, result, reference_time=NOW, + llm_thesis="Rewritten thesis text.", + ) + assert rec.model_metadata.provider == "ollama" + assert rec.model_metadata.model_name == "thesis-rewrite" + assert rec.model_metadata.prompt_version != "" + + +def test_build_recommendation_risk_classification_in_thesis(): + """The risk classification should be embedded in the thesis prefix.""" + summary = _make_summary(confidence=0.8, contradiction=0.05, + supporting=["d1", "d2", "d3", "d4", "d5"]) + result = evaluate_eligibility(summary) + rec = 
build_recommendation(summary, result, reference_time=NOW) + risk = _extract_risk_classification(rec.thesis) + assert risk == classify_risk(summary, result) diff --git a/tests/test_replay_extraction.py b/tests/test_replay_extraction.py new file mode 100644 index 0000000..203a43c --- /dev/null +++ b/tests/test_replay_extraction.py @@ -0,0 +1,208 @@ +"""Replay dataset tests for deterministic extraction validation. + +Loads archived document fixtures and validates that their expected +extraction outputs still pass the current schema and semantic checks. +This catches schema regressions, prompt contract changes, and +validation rule drift without requiring a live Ollama instance. + +Requirements: 5.1, 5.2, 5.3, 5.4, 5.5 +""" +from __future__ import annotations + +from pathlib import Path + +import pytest + +from services.extractor.replay import ( + FIXTURES_DIR, + compare_extraction, + load_all_fixtures, + load_fixture, + validate_all_fixtures, + validate_fixture, +) +from services.extractor.schemas import ( + ExtractionResult, + get_schema_version, + validate_extraction, +) + + +# --------------------------------------------------------------------------- +# Fixture loading +# --------------------------------------------------------------------------- + +FIXTURE_DIR = FIXTURES_DIR + + +def _fixture_paths() -> list[Path]: + """Collect all .json fixture files.""" + if not FIXTURE_DIR.is_dir(): + return [] + return sorted(FIXTURE_DIR.glob("*.json")) + + +def test_fixtures_directory_exists(): + """The replay fixtures directory exists and contains JSON files.""" + assert FIXTURE_DIR.is_dir(), f"Missing fixtures dir: {FIXTURE_DIR}" + paths = _fixture_paths() + assert len(paths) >= 3, f"Expected at least 3 fixtures, found {len(paths)}" + + +def test_load_all_fixtures(): + """All fixture files load without errors.""" + fixtures = load_all_fixtures() + assert len(fixtures) >= 3 + for f in fixtures: + assert f.document_id + assert f.document_text + assert f.expected_extraction + 
+ +def test_fixture_ids_unique(): + """Every fixture has a unique document_id.""" + fixtures = load_all_fixtures() + ids = [f.document_id for f in fixtures] + assert len(ids) == len(set(ids)), f"Duplicate fixture IDs: {ids}" + + +# --------------------------------------------------------------------------- +# Schema validation — the core deterministic test +# --------------------------------------------------------------------------- + +def test_all_expected_extractions_pass_schema(): + """Every fixture's expected_extraction passes current schema validation. + + This is the primary regression gate. If a fixture fails here, either + the fixture needs updating or the schema change is breaking. + """ + results = validate_all_fixtures() + assert len(results) >= 3 + + failures = [r for r in results if not r.schema_valid] + if failures: + msgs = [] + for f in failures: + errs = f.validation_report.errors if f.validation_report else [f.error or "unknown"] + msgs.append(f" {f.fixture_id}: {errs}") + pytest.fail( + f"{len(failures)} fixture(s) failed schema validation:\n" + "\n".join(msgs) + ) + + +@pytest.mark.parametrize("fixture_path", _fixture_paths(), ids=lambda p: p.stem) +def test_individual_fixture_schema_valid(fixture_path: Path): + """Each fixture individually passes schema and semantic validation.""" + fixture = load_fixture(fixture_path) + result = validate_fixture(fixture) + assert result.schema_valid, ( + f"Fixture {fixture.document_id} failed: " + f"{result.validation_report.errors if result.validation_report else result.error}" + ) + assert result.schema_version == get_schema_version() + + +# --------------------------------------------------------------------------- +# Expected extraction structural checks +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("fixture_path", _fixture_paths(), ids=lambda p: p.stem) +def test_expected_extraction_roundtrips(fixture_path: Path): + """Expected extraction can 
be parsed into ExtractionResult and back.""" + fixture = load_fixture(fixture_path) + parsed = fixture.expected_result + dumped = parsed.model_dump(mode="json") + reparsed = ExtractionResult.model_validate(dumped) + assert reparsed.summary == parsed.summary + assert len(reparsed.companies) == len(parsed.companies) + + +def test_low_quality_fixture_has_empty_companies(): + """The low-quality garbled fixture should have no companies.""" + fixtures = load_all_fixtures() + low_q = [f for f in fixtures if "low-quality" in f.document_id] + assert len(low_q) == 1 + fixture = low_q[0] + assert len(fixture.expected_result.companies) == 0 + assert fixture.expected_result.confidence <= 0.3 + + +def test_multi_company_fixture_has_multiple_tickers(): + """The multi-company fixture should reference multiple companies.""" + fixtures = load_all_fixtures() + multi = [f for f in fixtures if "multi-company" in f.document_id] + assert len(multi) == 1 + fixture = multi[0] + tickers = [c.ticker for c in fixture.expected_result.companies] + assert len(tickers) >= 3 + + +# --------------------------------------------------------------------------- +# Evidence grounding checks +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("fixture_path", _fixture_paths(), ids=lambda p: p.stem) +def test_evidence_spans_grounded_in_document(fixture_path: Path): + """Evidence spans in expected extractions appear in the document text.""" + fixture = load_fixture(fixture_path) + report = validate_extraction( + fixture.expected_extraction, + document_text=fixture.document_text, + ) + grounding_warnings = [ + w for w in report.warnings if "evidence_span_not_found" in w + ] + assert not grounding_warnings, ( + f"Fixture {fixture.document_id} has ungrounded evidence: {grounding_warnings}" + ) + + +# --------------------------------------------------------------------------- +# Comparison logic tests (using synthetic data, no Ollama needed) +# 
--------------------------------------------------------------------------- + +def test_compare_extraction_perfect_match(): + """Comparison reports match when actual equals expected.""" + fixtures = load_all_fixtures() + fixture = fixtures[0] + actual = fixture.expected_result # identical + result = compare_extraction(fixture, actual) + assert result.companies_match + assert result.sentiment_match + assert result.catalyst_match + assert result.actual_schema_valid + + +def test_compare_extraction_company_mismatch(): + """Comparison detects when actual has different companies.""" + fixtures = load_all_fixtures() + # Pick a fixture with companies + fixture = [f for f in fixtures if f.expected_result.companies][0] + # Build an actual result with no companies + actual = ExtractionResult( + summary="Different", + companies=[], + macro_themes=[], + novelty_score=0.5, + confidence=0.5, + extraction_warnings=[], + ) + result = compare_extraction(fixture, actual) + assert not result.companies_match + assert any("company_mismatch" in w for w in result.warnings) + + +def test_compare_extraction_sentiment_mismatch(): + """Comparison detects sentiment drift.""" + fixtures = load_all_fixtures() + fixture = [f for f in fixtures if f.expected_result.companies][0] + # Clone expected but flip sentiment + actual_data = fixture.expected_extraction.copy() + actual_data = {**actual_data} + companies = [dict(c) for c in actual_data["companies"]] + companies[0]["sentiment"] = "negative" if companies[0]["sentiment"] != "negative" else "positive" + actual_data["companies"] = companies + actual = ExtractionResult.model_validate(actual_data) + result = compare_extraction(fixture, actual) + assert result.companies_match # same tickers + assert not result.sentiment_match # different sentiment diff --git a/tests/test_resilient_adapter.py b/tests/test_resilient_adapter.py new file mode 100644 index 0000000..f3702ea --- /dev/null +++ b/tests/test_resilient_adapter.py @@ -0,0 +1,214 @@ +"""Tests for 
the resilient adapter wrapper. + +Validates retry logic, backoff computation, rate-limit coordination, +and retryable error classification. +""" +from datetime import datetime, timezone +from typing import Any + +import pytest + +from services.adapters.base import AdapterResult, BaseAdapter +from services.adapters.resilient import ( + ResilientAdapter, + RetryConfig, + compute_delay, +) + + +# --- Helpers --- + + +def _make_result( + ok: bool = True, + error: str | None = None, + http_status: int | None = None, + metadata: dict[str, Any] | None = None, +) -> AdapterResult: + return AdapterResult( + source_type="market_api", + ticker="AAPL", + items=[{"price": 150}] if ok else [], + raw_payload=b'{"ok":true}' if ok else b"", + content_hash="abc" if ok else "", + fetched_at=datetime.now(timezone.utc), + error=error, + http_status=http_status, + metadata=metadata or {}, + ) + + +class FakeAdapter(BaseAdapter): + """Adapter that returns a sequence of pre-configured results.""" + + def __init__(self, results: list[AdapterResult]) -> None: + self._results = list(results) + self._call_count = 0 + + @property + def call_count(self) -> int: + return self._call_count + + async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult: + idx = min(self._call_count, len(self._results) - 1) + self._call_count += 1 + return self._results[idx] + + def source_type(self) -> str: + return "market_api" + + +# --- Tests --- + + +class TestComputeDelay: + def test_first_attempt_is_base_delay_plus_jitter(self): + cfg = RetryConfig(base_delay=1.0, max_delay=60.0, jitter_factor=0.0) + delay = compute_delay(0, cfg) + assert delay == pytest.approx(1.0, abs=0.01) + + def test_exponential_growth(self): + cfg = RetryConfig(base_delay=1.0, max_delay=60.0, jitter_factor=0.0) + d0 = compute_delay(0, cfg) + d1 = compute_delay(1, cfg) + d2 = compute_delay(2, cfg) + assert d1 == pytest.approx(2.0, abs=0.01) + assert d2 == pytest.approx(4.0, abs=0.01) + assert d2 > d1 > d0 + + def 
test_capped_at_max_delay(self): + cfg = RetryConfig(base_delay=1.0, max_delay=10.0, jitter_factor=0.0) + delay = compute_delay(10, cfg) + assert delay <= 10.0 + + def test_jitter_adds_randomness(self): + cfg = RetryConfig(base_delay=1.0, max_delay=60.0, jitter_factor=1.0) + delays = {compute_delay(0, cfg) for _ in range(20)} + # With jitter_factor=1.0, we should see some variation + assert len(delays) > 1 + + +class TestRetryableClassification: + def setup_method(self) -> None: + adapter = FakeAdapter([_make_result()]) + self.resilient = ResilientAdapter(adapter) + + def test_ok_result_not_retryable(self): + result = _make_result(ok=True) + assert self.resilient._is_retryable(result) is False + + def test_429_is_retryable(self): + result = _make_result(ok=False, error="rate limited", http_status=429) + assert self.resilient._is_retryable(result) is True + + def test_500_is_retryable(self): + result = _make_result(ok=False, error="server error", http_status=500) + assert self.resilient._is_retryable(result) is True + + def test_503_is_retryable(self): + result = _make_result(ok=False, error="unavailable", http_status=503) + assert self.resilient._is_retryable(result) is True + + def test_400_not_retryable(self): + result = _make_result(ok=False, error="bad request", http_status=400) + assert self.resilient._is_retryable(result) is False + + def test_401_not_retryable(self): + result = _make_result(ok=False, error="unauthorized", http_status=401) + assert self.resilient._is_retryable(result) is False + + def test_timeout_error_retryable(self): + result = _make_result(ok=False, error="timeout: read timed out") + assert self.resilient._is_retryable(result) is True + + def test_connection_error_retryable(self): + result = _make_result(ok=False, error="Connection refused") + assert self.resilient._is_retryable(result) is True + + def test_generic_error_not_retryable(self): + result = _make_result(ok=False, error="invalid JSON response") + assert 
self.resilient._is_retryable(result) is False + + + +@pytest.mark.asyncio +class TestResilientFetch: + async def test_success_on_first_try(self): + adapter = FakeAdapter([_make_result(ok=True)]) + resilient = ResilientAdapter( + adapter, retry_config=RetryConfig(max_retries=2, base_delay=0.01) + ) + result = await resilient.fetch("AAPL", {}) + assert result.ok + assert adapter.call_count == 1 + assert result.metadata["retry_stats"]["attempts"] == 1 + + async def test_retries_on_retryable_then_succeeds(self): + results = [ + _make_result(ok=False, error="server error", http_status=500), + _make_result(ok=False, error="server error", http_status=500), + _make_result(ok=True), + ] + adapter = FakeAdapter(results) + resilient = ResilientAdapter( + adapter, retry_config=RetryConfig(max_retries=3, base_delay=0.01) + ) + result = await resilient.fetch("AAPL", {}) + assert result.ok + assert adapter.call_count == 3 + assert result.metadata["retry_stats"]["attempts"] == 3 + + async def test_exhausts_retries(self): + fail = _make_result(ok=False, error="server error", http_status=500) + adapter = FakeAdapter([fail, fail, fail, fail]) + resilient = ResilientAdapter( + adapter, retry_config=RetryConfig(max_retries=2, base_delay=0.01) + ) + result = await resilient.fetch("AAPL", {}) + assert not result.ok + assert adapter.call_count == 3 # initial + 2 retries + assert result.metadata["retry_stats"]["exhausted"] is True + + async def test_no_retry_on_non_retryable(self): + fail = _make_result(ok=False, error="bad request", http_status=400) + adapter = FakeAdapter([fail]) + resilient = ResilientAdapter( + adapter, retry_config=RetryConfig(max_retries=3, base_delay=0.01) + ) + result = await resilient.fetch("AAPL", {}) + assert not result.ok + assert adapter.call_count == 1 + + async def test_retry_after_respected_for_429(self): + fail_429 = _make_result( + ok=False, error="rate limited", http_status=429, + metadata={"retry_after": 0.05}, + ) + results = [fail_429, 
_make_result(ok=True)] + adapter = FakeAdapter(results) + resilient = ResilientAdapter( + adapter, retry_config=RetryConfig(max_retries=2, base_delay=0.01) + ) + result = await resilient.fetch("AAPL", {}) + assert result.ok + assert adapter.call_count == 2 + # Should have waited at least the retry_after amount + assert result.metadata["retry_stats"]["total_delay"] >= 0.04 + + async def test_source_type_passthrough(self): + adapter = FakeAdapter([_make_result()]) + resilient = ResilientAdapter(adapter) + assert resilient.source_type() == "market_api" + + async def test_default_config_for_known_source_type(self): + adapter = FakeAdapter([_make_result()]) + resilient = ResilientAdapter(adapter) + # market_api default is 30 rate limit max + assert resilient.config.rate_limit_max == 30 + + async def test_custom_config_overrides_default(self): + adapter = FakeAdapter([_make_result()]) + custom = RetryConfig(max_retries=5, rate_limit_max=100) + resilient = ResilientAdapter(adapter, retry_config=custom) + assert resilient.config.max_retries == 5 + assert resilient.config.rate_limit_max == 100 diff --git a/tests/test_retention.py b/tests/test_retention.py new file mode 100644 index 0000000..40db4eb --- /dev/null +++ b/tests/test_retention.py @@ -0,0 +1,172 @@ +"""Tests for data retention and lifecycle controls. + +Validates retention policy resolution, expired object detection, +cleanup logic, and DB record cleanup. 
+ +Requirements: N3 +""" +from datetime import datetime, timedelta, timezone +from unittest.mock import MagicMock + +from services.shared.config import RetentionConfig +from services.shared.retention import ( + RetentionPolicy, + cleanup_bucket, + cutoff_date, + default_retention_days, + delete_expired_objects, + list_expired_objects, + merge_policies, + resolve_policies, +) + + +class TestDefaultRetentionDays: + def test_known_buckets(self): + config = RetentionConfig() + assert default_retention_days("stonks-raw-market", config) == 90 + assert default_retention_days("stonks-raw-news", config) == 180 + assert default_retention_days("stonks-raw-filings", config) == 365 + assert default_retention_days("stonks-lakehouse", config) == 730 + assert default_retention_days("stonks-audit", config) == 730 + + def test_unknown_bucket_defaults_to_365(self): + config = RetentionConfig() + assert default_retention_days("unknown-bucket", config) == 365 + + def test_custom_config_values(self): + config = RetentionConfig(raw_market_days=30, audit_days=1000) + assert default_retention_days("stonks-raw-market", config) == 30 + assert default_retention_days("stonks-audit", config) == 1000 + + +class TestResolvePolicies: + def test_returns_policy_per_bucket(self): + config = RetentionConfig() + policies = resolve_policies(config) + bucket_names = [p.bucket_name for p in policies] + assert "stonks-raw-market" in bucket_names + assert "stonks-lakehouse" in bucket_names + assert len(policies) == 8 + + def test_uses_config_values(self): + config = RetentionConfig(raw_news_days=60) + policies = resolve_policies(config) + news_policy = next(p for p in policies if p.bucket_name == "stonks-raw-news") + assert news_policy.retention_days == 60 + + +class TestMergePolicies: + def test_db_overrides_config(self): + config_policies = [ + RetentionPolicy("stonks-raw-market", 90), + RetentionPolicy("stonks-raw-news", 180), + ] + db_policies = { + "stonks-raw-market": 
RetentionPolicy("stonks-raw-market", 30), + } + merged = merge_policies(config_policies, db_policies) + market = next(p for p in merged if p.bucket_name == "stonks-raw-market") + news = next(p for p in merged if p.bucket_name == "stonks-raw-news") + assert market.retention_days == 30 # DB override + assert news.retention_days == 180 # config default + + def test_empty_db_uses_all_config(self): + config_policies = [RetentionPolicy("stonks-audit", 730)] + merged = merge_policies(config_policies, {}) + assert len(merged) == 1 + assert merged[0].retention_days == 730 + + +class TestCutoffDate: + def test_calculates_cutoff(self): + now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + cutoff = cutoff_date(90, now) + expected = now - timedelta(days=90) + assert cutoff == expected + + def test_uses_current_time_when_none(self): + cutoff = cutoff_date(30) + assert cutoff < datetime.now(timezone.utc) + + +class TestListExpiredObjects: + def test_finds_expired_objects(self): + client = MagicMock() + now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + + old_obj = MagicMock() + old_obj.object_name = "old/file.json" + old_obj.last_modified = now - timedelta(days=100) + + new_obj = MagicMock() + new_obj.object_name = "new/file.json" + new_obj.last_modified = now - timedelta(days=10) + + client.list_objects.return_value = [old_obj, new_obj] + + expired = list_expired_objects(client, "stonks-raw-market", 90, now=now) + assert expired == ["old/file.json"] + + def test_respects_batch_size(self): + client = MagicMock() + now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + + objects = [] + for i in range(10): + obj = MagicMock() + obj.object_name = f"file_{i}.json" + obj.last_modified = now - timedelta(days=200) + objects.append(obj) + + client.list_objects.return_value = objects + expired = list_expired_objects(client, "test-bucket", 90, batch_size=3, now=now) + assert len(expired) == 3 + + def test_handles_list_error(self): + client = MagicMock() + 
client.list_objects.side_effect = Exception("connection error") + expired = list_expired_objects(client, "test-bucket", 90) + assert expired == [] + + +class TestDeleteExpiredObjects: + def test_deletes_all(self): + client = MagicMock() + count = delete_expired_objects(client, "test-bucket", ["a.json", "b.json"]) + assert count == 2 + assert client.remove_object.call_count == 2 + + def test_handles_partial_failure(self): + client = MagicMock() + client.remove_object.side_effect = [None, Exception("fail"), None] + count = delete_expired_objects(client, "test-bucket", ["a", "b", "c"]) + assert count == 2 + + +class TestCleanupBucket: + def test_full_cleanup_flow(self): + client = MagicMock() + now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + + old_obj = MagicMock() + old_obj.object_name = "expired.json" + old_obj.last_modified = now - timedelta(days=200) + client.list_objects.return_value = [old_obj] + + policy = RetentionPolicy("stonks-raw-market", 90) + result = cleanup_bucket(client, policy, now=now) + + assert result.bucket_name == "stonks-raw-market" + assert result.objects_scanned == 1 + assert result.objects_deleted == 1 + + def test_no_expired_objects(self): + client = MagicMock() + client.list_objects.return_value = [] + + policy = RetentionPolicy("stonks-raw-news", 180) + result = cleanup_bucket(client, policy) + + assert result.objects_scanned == 0 + assert result.objects_deleted == 0 diff --git a/tests/test_risk_engine.py b/tests/test_risk_engine.py new file mode 100644 index 0000000..a7b3835 --- /dev/null +++ b/tests/test_risk_engine.py @@ -0,0 +1,413 @@ +"""Tests for the portfolio and account risk configuration model and enforcement.""" + +from datetime import datetime, timedelta, timezone + +from services.risk.engine import ( + AccountRiskState, + DailyLossLimits, + DEFAULT_RISK_CONFIG, + NewsShockLockout, + OperatorApproval, + PortfolioRiskConfig, + PositionLimits, + ProposedOrder, + RiskCheckDetail, + RiskCheckResult, + RiskEvaluation, + 
SectorExposureLimits, + SymbolCooldown, + TradingMode, + evaluate_order, +) + + +def test_default_risk_config_is_paper_mode(): + """Default config should be paper trading mode.""" + cfg = PortfolioRiskConfig() + assert cfg.trading_mode == TradingMode.PAPER + assert cfg.active is True + + +def test_position_limits_defaults(): + limits = PositionLimits() + assert limits.max_position_pct == 0.05 + assert limits.max_position_value == 10_000.0 + assert limits.max_shares_per_order == 1000.0 + + +def test_sector_exposure_defaults(): + limits = SectorExposureLimits() + assert limits.max_sector_pct == 0.25 + assert limits.max_sectors == 10 + + +def test_daily_loss_defaults(): + limits = DailyLossLimits() + assert limits.max_daily_loss_pct == 0.02 + assert limits.max_daily_loss_value == 1_000.0 + assert limits.max_daily_trades == 20 + + +def test_news_shock_lockout_defaults(): + lockout = NewsShockLockout() + assert lockout.enabled is True + assert lockout.lockout_minutes == 60 + assert lockout.impact_threshold == 0.80 + assert "earnings" in lockout.catalyst_types + + +def test_operator_approval_defaults(): + approval = OperatorApproval() + assert approval.require_approval_for_live is True + assert approval.auto_approve_paper is True + assert approval.approval_timeout_minutes == 30 + + +def test_symbol_cooldown_defaults(): + cooldown = SymbolCooldown() + assert cooldown.cooldown_minutes == 15 + assert cooldown.max_open_positions_per_symbol == 1 + + +def test_portfolio_config_roundtrip_json(): + """Config should survive serialization to JSON and back.""" + cfg = PortfolioRiskConfig( + name="test-profile", + trading_mode=TradingMode.LIVE, + position_limits=PositionLimits(max_position_pct=0.10), + daily_loss=DailyLossLimits(max_daily_trades=5), + ) + data = cfg.to_db_json() + restored = PortfolioRiskConfig.from_db_json(data) + + assert restored.name == "test-profile" + assert restored.trading_mode == TradingMode.LIVE + assert restored.position_limits.max_position_pct == 0.10 + 
assert restored.daily_loss.max_daily_trades == 5 + # Nested defaults should survive + assert restored.sector_exposure.max_sector_pct == 0.25 + assert restored.news_shock.enabled is True + + +def test_account_risk_state_defaults(): + state = AccountRiskState(account_id="test-acct") + assert state.portfolio_value == 0.0 + assert state.daily_trade_count == 0 + assert state.positions_by_symbol == {} + assert state.positions_by_sector == {} + assert state.locked_symbols == {} + + +def test_account_risk_state_with_positions(): + state = AccountRiskState( + account_id="acct-1", + portfolio_value=100_000.0, + cash=50_000.0, + daily_pnl=-500.0, + daily_trade_count=3, + positions_by_symbol={"AAPL": 10_000.0, "MSFT": 5_000.0}, + positions_by_sector={"Technology": 15_000.0}, + ) + assert state.positions_by_symbol["AAPL"] == 10_000.0 + assert state.positions_by_sector["Technology"] == 15_000.0 + assert state.daily_pnl == -500.0 + + +def test_risk_evaluation_passed_property(): + """passed should be True only when eligible and no rejections.""" + passing = RiskEvaluation( + ticker="AAPL", + eligible=True, + allowed_mode=TradingMode.PAPER, + checks=[ + RiskCheckDetail(check_name="position_size", result=RiskCheckResult.PASS), + ], + ) + assert passing.passed is True + + failing = RiskEvaluation( + ticker="AAPL", + eligible=False, + allowed_mode=TradingMode.DISABLED, + rejection_reasons=["max_daily_loss_exceeded"], + checks=[ + RiskCheckDetail( + check_name="daily_loss", + result=RiskCheckResult.FAIL, + message="Daily loss limit exceeded", + threshold=0.02, + actual=0.03, + ), + ], + ) + assert failing.passed is False + + +def test_risk_evaluation_captures_config_snapshot(): + """Evaluation should be able to store the config used for reproducibility.""" + cfg = PortfolioRiskConfig(name="snapshot-test") + state = AccountRiskState(account_id="acct-1", portfolio_value=50_000.0) + + evaluation = RiskEvaluation( + ticker="TSLA", + eligible=True, + allowed_mode=TradingMode.PAPER, + 
config_snapshot=cfg, + state_snapshot=state, + ) + assert evaluation.config_snapshot is not None + assert evaluation.config_snapshot.name == "snapshot-test" + assert evaluation.state_snapshot is not None + assert evaluation.state_snapshot.portfolio_value == 50_000.0 + + +def test_trading_mode_disabled(): + """DISABLED mode should be available for halting all trading.""" + cfg = PortfolioRiskConfig(trading_mode=TradingMode.DISABLED) + assert cfg.trading_mode == TradingMode.DISABLED + + +def test_default_risk_config_singleton(): + """Module-level default should be a valid paper config.""" + assert DEFAULT_RISK_CONFIG.trading_mode == TradingMode.PAPER + assert DEFAULT_RISK_CONFIG.name == "default" + + +# =================================================================== +# Enforcement logic tests (hard blocks) +# =================================================================== + + +def _make_config(**overrides) -> PortfolioRiskConfig: + return PortfolioRiskConfig( + trading_mode=overrides.get("trading_mode", TradingMode.PAPER), + position_limits=overrides.get("position_limits", PositionLimits()), + sector_exposure=overrides.get("sector_exposure", SectorExposureLimits()), + daily_loss=overrides.get("daily_loss", DailyLossLimits()), + news_shock=overrides.get("news_shock", NewsShockLockout()), + symbol_cooldown=overrides.get("symbol_cooldown", SymbolCooldown()), + ) + + +def _make_state(**overrides) -> AccountRiskState: + return AccountRiskState( + account_id=overrides.get("account_id", "test-acct"), + portfolio_value=overrides.get("portfolio_value", 100_000.0), + cash=overrides.get("cash", 50_000.0), + daily_pnl=overrides.get("daily_pnl", 0.0), + daily_trade_count=overrides.get("daily_trade_count", 0), + positions_by_symbol=overrides.get("positions_by_symbol", {}), + positions_by_sector=overrides.get("positions_by_sector", {}), + last_trade_times=overrides.get("last_trade_times", {}), + locked_symbols=overrides.get("locked_symbols", {}), + ) + + +# --- Trading mode 
gate --- + + +def test_evaluate_order_disabled_mode_blocks(): + """Orders are rejected when trading mode is DISABLED.""" + config = _make_config(trading_mode=TradingMode.DISABLED) + order = ProposedOrder(ticker="AAPL", estimated_value=1000, quantity=10) + result = evaluate_order(order, config, _make_state()) + assert result.passed is False + assert any("disabled" in r.lower() for r in result.rejection_reasons) + + +def test_evaluate_order_paper_mode_passes(): + """A clean order in paper mode should pass all checks.""" + config = _make_config() + order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) + result = evaluate_order(order, config, _make_state()) + assert result.passed is True + assert result.allowed_mode == TradingMode.PAPER + + +# --- Max position size --- + + +def test_position_value_exceeded(): + config = _make_config(position_limits=PositionLimits(max_position_value=5000)) + order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=6000, quantity=10) + result = evaluate_order(order, config, _make_state()) + assert result.passed is False + assert any(c.check_name == "max_position_value" and c.result == RiskCheckResult.FAIL for c in result.checks) + + +def test_position_value_includes_existing(): + """Existing position value is added to the new order value.""" + config = _make_config(position_limits=PositionLimits(max_position_value=5000)) + state = _make_state(positions_by_symbol={"AAPL": 3000.0}) + order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=2500, quantity=5) + result = evaluate_order(order, config, state) + assert result.passed is False + fail_check = next(c for c in result.checks if c.check_name == "max_position_value") + assert fail_check.actual == 5500.0 + + +def test_position_pct_exceeded(): + config = _make_config(position_limits=PositionLimits(max_position_pct=0.05)) + state = _make_state(portfolio_value=100_000) + order = ProposedOrder(ticker="AAPL", 
sector="Technology", estimated_value=6000, quantity=10) + result = evaluate_order(order, config, state) + assert any(c.check_name == "max_position_pct" and c.result == RiskCheckResult.FAIL for c in result.checks) + + +def test_max_shares_exceeded(): + config = _make_config(position_limits=PositionLimits(max_shares_per_order=100)) + order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=200) + result = evaluate_order(order, config, _make_state()) + assert any(c.check_name == "max_shares_per_order" and c.result == RiskCheckResult.FAIL for c in result.checks) + + +# --- Sector exposure --- + + +def test_sector_exposure_exceeded(): + config = _make_config(sector_exposure=SectorExposureLimits(max_sector_pct=0.25)) + state = _make_state( + portfolio_value=100_000, + positions_by_sector={"Technology": 20_000.0}, + ) + order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=6000, quantity=10) + result = evaluate_order(order, config, state) + assert any(c.check_name == "sector_exposure" and c.result == RiskCheckResult.FAIL for c in result.checks) + + +def test_sector_exposure_no_sector_warns(): + """Missing sector on order produces a warning, not a failure.""" + config = _make_config() + order = ProposedOrder(ticker="AAPL", estimated_value=1000, quantity=10) + result = evaluate_order(order, config, _make_state()) + sector_check = next(c for c in result.checks if c.check_name == "sector_exposure") + assert sector_check.result == RiskCheckResult.WARN + + +# --- Daily loss limits --- + + +def test_daily_loss_pct_exceeded(): + config = _make_config(daily_loss=DailyLossLimits(max_daily_loss_pct=0.02)) + state = _make_state(portfolio_value=100_000, daily_pnl=-2500) + order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) + result = evaluate_order(order, config, state) + assert any(c.check_name == "daily_loss_pct" and c.result == RiskCheckResult.FAIL for c in result.checks) + + +def 
test_daily_loss_value_exceeded(): + config = _make_config(daily_loss=DailyLossLimits(max_daily_loss_value=500)) + state = _make_state(daily_pnl=-600) + order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) + result = evaluate_order(order, config, state) + assert any(c.check_name == "daily_loss_value" and c.result == RiskCheckResult.FAIL for c in result.checks) + + +def test_daily_trade_count_exceeded(): + config = _make_config(daily_loss=DailyLossLimits(max_daily_trades=5)) + state = _make_state(daily_trade_count=5) + order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) + result = evaluate_order(order, config, state) + assert any(c.check_name == "daily_trade_count" and c.result == RiskCheckResult.FAIL for c in result.checks) + + +def test_positive_pnl_does_not_trigger_loss_limit(): + """Positive P&L should not trigger daily loss checks.""" + config = _make_config(daily_loss=DailyLossLimits(max_daily_loss_pct=0.02)) + state = _make_state(portfolio_value=100_000, daily_pnl=5000) + order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) + result = evaluate_order(order, config, state) + loss_checks = [c for c in result.checks if c.check_name.startswith("daily_loss")] + assert all(c.result == RiskCheckResult.PASS for c in loss_checks) + + +# --- News-shock lockout --- + + +def test_news_shock_lockout_blocks(): + now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + lockout_expiry = now + timedelta(minutes=30) + state = _make_state(locked_symbols={"AAPL": lockout_expiry}) + order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) + result = evaluate_order(order, _make_config(), state, now=now) + assert any(c.check_name == "news_shock_lockout" and c.result == RiskCheckResult.FAIL for c in result.checks) + + +def test_news_shock_lockout_expired_passes(): + now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + 
lockout_expiry = now - timedelta(minutes=5) # already expired + state = _make_state(locked_symbols={"AAPL": lockout_expiry}) + order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) + result = evaluate_order(order, _make_config(), state, now=now) + lockout_check = next(c for c in result.checks if c.check_name == "news_shock_lockout") + assert lockout_check.result == RiskCheckResult.PASS + + +def test_news_shock_lockout_disabled_passes(): + now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + lockout_expiry = now + timedelta(minutes=30) + config = _make_config(news_shock=NewsShockLockout(enabled=False)) + state = _make_state(locked_symbols={"AAPL": lockout_expiry}) + order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) + result = evaluate_order(order, config, state, now=now) + lockout_check = next(c for c in result.checks if c.check_name == "news_shock_lockout") + assert lockout_check.result == RiskCheckResult.PASS + + +# --- Symbol cooldown --- + + +def test_symbol_cooldown_blocks(): + now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + last_trade = now - timedelta(minutes=5) # 5 min ago, default cooldown is 15 + state = _make_state(last_trade_times={"AAPL": last_trade}) + order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) + result = evaluate_order(order, _make_config(), state, now=now) + assert any(c.check_name == "symbol_cooldown" and c.result == RiskCheckResult.FAIL for c in result.checks) + + +def test_symbol_cooldown_expired_passes(): + now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + last_trade = now - timedelta(minutes=20) # 20 min ago, cooldown is 15 + state = _make_state(last_trade_times={"AAPL": last_trade}) + order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) + result = evaluate_order(order, _make_config(), state, now=now) + cooldown_check = next(c for c in 
result.checks if c.check_name == "symbol_cooldown") + assert cooldown_check.result == RiskCheckResult.PASS + + +# --- Combined scenarios --- + + +def test_multiple_failures_all_captured(): + """When multiple checks fail, all rejection reasons are captured.""" + config = _make_config( + position_limits=PositionLimits(max_position_value=500), + daily_loss=DailyLossLimits(max_daily_loss_value=100), + ) + state = _make_state(daily_pnl=-200) + order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) + result = evaluate_order(order, config, state) + assert result.passed is False + assert len(result.rejection_reasons) >= 2 + + +def test_evaluation_captures_snapshots(): + """Config and state snapshots are stored for reproducibility.""" + config = _make_config() + state = _make_state(portfolio_value=75_000) + order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) + result = evaluate_order(order, config, state) + assert result.config_snapshot is not None + assert result.state_snapshot is not None + assert result.state_snapshot.portfolio_value == 75_000 + + +def test_fail_closed_no_state(): + """With zero portfolio value, position pct check should fail-closed for non-zero orders.""" + config = _make_config() + state = _make_state(portfolio_value=0.0) + order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) + result = evaluate_order(order, config, state) + # position_pct = 1.0 when portfolio is 0 and order value > 0 → exceeds 0.05 + assert any(c.check_name == "max_position_pct" and c.result == RiskCheckResult.FAIL for c in result.checks) diff --git a/tests/test_rollups.py b/tests/test_rollups.py new file mode 100644 index 0000000..6fe3c99 --- /dev/null +++ b/tests/test_rollups.py @@ -0,0 +1,173 @@ +"""Tests for sector and market rollup aggregation. + +Tests the pure rollup logic (no DB required). 
+ +Requirements: 6.3, 6.4, 6.5 +""" +from datetime import datetime, timezone + +from services.aggregation.rollups import ( + CompanyTrendRow, + rollup_trends, + _build_rollup_disagreement, + _derive_rollup_direction, +) +from services.shared.schemas import TrendDirection, TrendWindow + +NOW = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + + +def _make_trend( + ticker: str = "AAPL", + sector: str = "Technology", + window: str = "7d", + direction: str = "bullish", + strength: float = 0.6, + confidence: float = 0.8, + contradiction: float = 0.1, + catalysts: list[str] | None = None, + risks: list[str] | None = None, + supporting: list[str] | None = None, + opposing: list[str] | None = None, +) -> CompanyTrendRow: + return CompanyTrendRow( + entity_id=ticker, + sector=sector, + window=window, + trend_direction=direction, + trend_strength=strength, + confidence=confidence, + contradiction_score=contradiction, + dominant_catalysts=catalysts or ["earnings"], + material_risks=risks or [], + top_supporting_evidence=supporting or ["doc-1"], + top_opposing_evidence=opposing or [], + ) + + +# --------------------------------------------------------------------------- +# rollup_trends +# --------------------------------------------------------------------------- + + +def test_rollup_empty(): + summary = rollup_trends([], "sector", "Technology", "7d", NOW) + assert summary.entity_type == "sector" + assert summary.entity_id == "Technology" + assert summary.trend_direction == TrendDirection.NEUTRAL + assert summary.trend_strength == 0.0 + assert summary.confidence == 0.0 + + +def test_rollup_single_bullish(): + trends = [_make_trend("AAPL", direction="bullish", strength=0.7, confidence=0.9)] + summary = rollup_trends(trends, "sector", "Technology", "7d", NOW) + assert summary.trend_direction == TrendDirection.BULLISH + assert summary.trend_strength > 0 + assert summary.confidence > 0 + assert summary.window == TrendWindow.SEVEN_DAY + + +def test_rollup_mixed_directions(): 
+ trends = [ + _make_trend("AAPL", direction="bullish", strength=0.6, confidence=0.8), + _make_trend("MSFT", direction="bearish", strength=0.6, confidence=0.8), + ] + summary = rollup_trends(trends, "sector", "Technology", "7d", NOW) + # Equal and opposite → neutral or mixed + assert summary.trend_direction in (TrendDirection.NEUTRAL, TrendDirection.MIXED) + + +def test_rollup_confidence_weighted(): + """Higher-confidence company should dominate the rollup direction.""" + trends = [ + _make_trend("AAPL", direction="bullish", strength=0.8, confidence=0.95), + _make_trend("MSFT", direction="bearish", strength=0.3, confidence=0.2), + ] + summary = rollup_trends(trends, "sector", "Technology", "7d", NOW) + assert summary.trend_direction == TrendDirection.BULLISH + + +def test_rollup_catalysts_aggregated(): + trends = [ + _make_trend("AAPL", catalysts=["earnings", "product"], confidence=0.8), + _make_trend("MSFT", catalysts=["product", "macro"], confidence=0.6), + ] + summary = rollup_trends(trends, "sector", "Technology", "7d", NOW) + # "product" appears in both → should be top catalyst + assert "product" in summary.dominant_catalysts + + +def test_rollup_risks_deduplicated(): + trends = [ + _make_trend("AAPL", risks=["regulatory risk", "supply chain"], confidence=0.8), + _make_trend("MSFT", risks=["Regulatory Risk", "tariffs"], confidence=0.6), + ] + summary = rollup_trends(trends, "sector", "Technology", "7d", NOW) + risk_lower = [r.lower() for r in summary.material_risks] + assert risk_lower.count("regulatory risk") == 1 + + +def test_rollup_evidence_collected(): + trends = [ + _make_trend("AAPL", supporting=["doc-1", "doc-2"], opposing=["doc-3"]), + _make_trend("MSFT", supporting=["doc-4"], opposing=["doc-5"]), + ] + summary = rollup_trends(trends, "market", "all", "7d", NOW) + assert "doc-1" in summary.top_supporting_evidence + assert "doc-4" in summary.top_supporting_evidence + assert "doc-3" in summary.top_opposing_evidence + + +def 
test_rollup_market_entity_type(): + trends = [_make_trend("AAPL"), _make_trend("JPM", sector="Financials")] + summary = rollup_trends(trends, "market", "all", "7d", NOW) + assert summary.entity_type == "market" + assert summary.entity_id == "all" + + +# --------------------------------------------------------------------------- +# _derive_rollup_direction +# --------------------------------------------------------------------------- + + +def test_derive_direction_bullish(): + assert _derive_rollup_direction(0.5, 0.0) == TrendDirection.BULLISH + + +def test_derive_direction_bearish(): + assert _derive_rollup_direction(-0.5, 0.0) == TrendDirection.BEARISH + + +def test_derive_direction_neutral(): + assert _derive_rollup_direction(0.05, 0.0) == TrendDirection.NEUTRAL + + +def test_derive_direction_mixed_high_contradiction(): + assert _derive_rollup_direction(0.1, 0.2) == TrendDirection.MIXED + + +# --------------------------------------------------------------------------- +# _build_rollup_disagreement +# --------------------------------------------------------------------------- + + +def test_disagreement_no_conflict(): + trends = [ + _make_trend("AAPL", direction="bullish"), + _make_trend("MSFT", direction="bullish"), + ] + details = _build_rollup_disagreement(trends, "Technology") + assert details == [] + + +def test_disagreement_with_conflict(): + trends = [ + _make_trend("AAPL", direction="bullish", confidence=0.8), + _make_trend("MSFT", direction="bearish", confidence=0.7), + ] + details = _build_rollup_disagreement(trends, "Technology") + assert len(details) == 1 + assert details[0].dimension == "company_direction" + assert "AAPL" in details[0].positive_doc_ids + assert "MSFT" in details[0].negative_doc_ids diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py new file mode 100644 index 0000000..105a578 --- /dev/null +++ b/tests/test_scheduler.py @@ -0,0 +1,131 @@ +"""Tests for scheduler polling logic.""" +from datetime import datetime, timedelta + 
+from services.scheduler.app import ( + DEFAULT_CADENCES, + MAX_RETRY_COUNT, + build_job_payload, + compute_backoff, + get_cadence_for_source, + is_source_due, +) + + +class TestGetCadenceForSource: + def test_default_cadence_market_api(self): + assert get_cadence_for_source("market_api", None) == 60 + + def test_default_cadence_news_api(self): + assert get_cadence_for_source("news_api", {}) == 300 + + def test_default_cadence_unknown_type(self): + assert get_cadence_for_source("unknown", None) == 600 + + def test_override_from_config(self): + config = {"polling_interval_seconds": 120} + assert get_cadence_for_source("market_api", config) == 120 + + def test_override_minimum_clamp(self): + config = {"polling_interval_seconds": 5} + assert get_cadence_for_source("market_api", config) == 10 + + def test_invalid_override_falls_back(self): + config = {"polling_interval_seconds": "not_a_number"} + assert get_cadence_for_source("news_api", config) == DEFAULT_CADENCES["news_api"] + + +class TestComputeBackoff: + def test_first_retry(self): + assert compute_backoff(0) == 60 + + def test_second_retry(self): + assert compute_backoff(1) == 120 + + def test_capped_at_max(self): + assert compute_backoff(20) == 3600 + + +class TestIsSourceDue: + def _now(self): + return datetime(2026, 4, 11, 12, 0, 0) + + def test_never_run_is_due(self): + assert is_source_due("market_api", None, None, None, 0, None, self._now()) + + def test_completed_within_cadence_not_due(self): + last = self._now() - timedelta(seconds=30) + assert not is_source_due("market_api", None, last, "completed", 0, None, self._now()) + + def test_completed_past_cadence_is_due(self): + last = self._now() - timedelta(seconds=120) + assert is_source_due("market_api", None, last, "completed", 0, None, self._now()) + + def test_running_not_due(self): + last = self._now() - timedelta(seconds=5) + assert not is_source_due("market_api", None, last, "running", 0, None, self._now()) + + def 
test_failed_within_backoff_not_due(self): + last = self._now() - timedelta(seconds=30) + next_retry = self._now() + timedelta(seconds=30) + assert not is_source_due("market_api", None, last, "failed", 1, next_retry, self._now()) + + def test_failed_past_backoff_is_due(self): + last = self._now() - timedelta(seconds=120) + next_retry = self._now() - timedelta(seconds=10) + assert is_source_due("market_api", None, last, "failed", 1, next_retry, self._now()) + + def test_failed_max_retries_not_due(self): + last = self._now() - timedelta(seconds=120) + assert not is_source_due( + "market_api", None, last, "failed", MAX_RETRY_COUNT, None, self._now() + ) + + def test_custom_cadence_respected(self): + config = {"polling_interval_seconds": 600} + last = self._now() - timedelta(seconds=300) + assert not is_source_due("market_api", config, last, "completed", 0, None, self._now()) + + last_old = self._now() - timedelta(seconds=700) + assert is_source_due("market_api", config, last_old, "completed", 0, None, self._now()) + + +class TestBuildJobPayload: + def test_payload_structure(self): + source = { + "source_id": "sid-1", + "company_id": "cid-1", + "ticker": "AAPL", + "legal_name": "Apple Inc.", + "source_type": "news_api", + "source_name": "NewsAPI", + "config": {"endpoint": "/v2/everything"}, + "credibility_score": 0.8, + } + now = datetime(2026, 4, 11, 12, 0, 0) + job = build_job_payload(source, ["Apple", "iPhone"], now) + + assert job["source_id"] == "sid-1" + assert job["company_id"] == "cid-1" + assert job["ticker"] == "AAPL" + assert job["legal_name"] == "Apple Inc." 
+ assert job["aliases"] == ["Apple", "iPhone"] + assert job["source_type"] == "news_api" + assert job["config"] == {"endpoint": "/v2/everything"} + assert job["credibility_score"] == 0.8 + assert job["scheduled_at"] == now.isoformat() + + def test_payload_null_config(self): + source = { + "source_id": "sid-2", + "company_id": "cid-2", + "ticker": "MSFT", + "legal_name": "Microsoft Corp.", + "source_type": "market_api", + "source_name": "Polygon", + "config": None, + "credibility_score": None, + } + job = build_job_payload(source, [], datetime(2026, 4, 11, 12, 0, 0)) + assert job["config"] == {} + assert job["credibility_score"] == 0.5 + assert job["aliases"] == [] diff --git a/tests/test_storage.py b/tests/test_storage.py new file mode 100644 index 0000000..0b834ca --- /dev/null +++ b/tests/test_storage.py @@ -0,0 +1,212 @@ +"""Tests for shared MinIO storage utilities. + +Validates bucket mapping, path building, storage refs, bucket creation, +artifact upload, and download from services.shared.storage. 
+ +Requirements: 3.1, 3.2, 3.3, 9.1 +""" +from datetime import datetime, timezone +from unittest.mock import MagicMock + +from services.shared.storage import ( + ALL_BUCKETS, + bucket_for_source, + build_artifact_path, + download_artifact, + ensure_buckets, + storage_ref, + upload_artifact, + upload_html_artifact, + upload_normalized_text, + upload_parser_output, + upload_raw_artifact, +) + + +class TestBucketForSource: + def test_market_api(self): + assert bucket_for_source("market_api") == "stonks-raw-market" + + def test_news_api(self): + assert bucket_for_source("news_api") == "stonks-raw-news" + + def test_filings_api(self): + assert bucket_for_source("filings_api") == "stonks-raw-filings" + + def test_web_scrape(self): + assert bucket_for_source("web_scrape") == "stonks-raw-news" + + def test_broker(self): + assert bucket_for_source("broker") == "stonks-raw-market" + + def test_unknown_defaults_to_market(self): + assert bucket_for_source("unknown_type") == "stonks-raw-market" + + +class TestBuildArtifactPath: + def test_default_path_format(self): + ts = datetime(2026, 4, 11, 14, 30, 0, tzinfo=timezone.utc) + path = build_artifact_path("news_api", "AAPL", "doc-123", timestamp=ts) + assert path == "news_api/AAPL/2026/04/11/doc-123/raw.json" + + def test_custom_artifact_name_and_ext(self): + ts = datetime(2026, 1, 5, 0, 0, 0, tzinfo=timezone.utc) + path = build_artifact_path( + "web_scrape", "MSFT", "doc-456", + artifact_name="raw", ext="html", timestamp=ts, + ) + assert path == "web_scrape/MSFT/2026/01/05/doc-456/raw.html" + + def test_uses_utc_now_when_no_timestamp(self): + path = build_artifact_path("market_api", "GOOG", "run-1") + # Just verify it has the expected structure + parts = path.split("/") + assert parts[0] == "market_api" + assert parts[1] == "GOOG" + assert len(parts) == 7 # source/ticker/yyyy/mm/dd/doc_id/file + + +class TestStorageRef: + def test_builds_s3_uri(self): + ref = storage_ref("stonks-raw-news", 
"news_api/AAPL/2026/04/11/doc-1/raw.json") + assert ref == "s3://stonks-raw-news/news_api/AAPL/2026/04/11/doc-1/raw.json" + + +class TestEnsureBuckets: + def test_creates_missing_buckets(self): + client = MagicMock() + client.bucket_exists.return_value = False + created = ensure_buckets(client, ["bucket-a", "bucket-b"]) + assert created == ["bucket-a", "bucket-b"] + assert client.make_bucket.call_count == 2 + + def test_skips_existing_buckets(self): + client = MagicMock() + client.bucket_exists.return_value = True + created = ensure_buckets(client, ["bucket-a"]) + assert created == [] + client.make_bucket.assert_not_called() + + def test_defaults_to_all_buckets(self): + client = MagicMock() + client.bucket_exists.return_value = True + ensure_buckets(client) + assert client.bucket_exists.call_count == len(ALL_BUCKETS) + + +class TestUploadArtifact: + def test_uploads_and_returns_ref(self): + client = MagicMock() + ref = upload_artifact( + client, "stonks-raw-news", "path/to/obj.json", + b'{"key": "value"}', content_type="application/json", + ) + assert ref == "s3://stonks-raw-news/path/to/obj.json" + client.put_object.assert_called_once() + args, kwargs = client.put_object.call_args + assert args[0] == "stonks-raw-news" + assert args[1] == "path/to/obj.json" + assert kwargs["length"] == len(b'{"key": "value"}') + assert kwargs["content_type"] == "application/json" + + def test_passes_metadata(self): + client = MagicMock() + upload_artifact( + client, "stonks-raw-market", "p.json", + b"data", metadata={"ticker": "AAPL"}, + ) + _, kwargs = client.put_object.call_args + assert kwargs["metadata"] == {"ticker": "AAPL"} + + +class TestUploadRawArtifact: + def test_market_api_json(self): + client = MagicMock() + ts = datetime(2026, 4, 11, 0, 0, 0, tzinfo=timezone.utc) + ref = upload_raw_artifact( + client, source_type="market_api", ticker="AAPL", + document_id="run-1", data=b'{"bars":[]}', + artifact_type="raw_json", timestamp=ts, + ) + assert "stonks-raw-market" in ref + 
assert "market_api/AAPL/2026/04/11/run-1/raw.json" in ref + _, kwargs = client.put_object.call_args + assert kwargs["content_type"] == "application/json" + + def test_web_scrape_html(self): + client = MagicMock() + ts = datetime(2026, 3, 1, 0, 0, 0, tzinfo=timezone.utc) + ref = upload_raw_artifact( + client, source_type="web_scrape", ticker="TSLA", + document_id="doc-5", data=b"", + artifact_type="raw_html", timestamp=ts, + ) + assert "stonks-raw-news" in ref + assert "raw.html" in ref + _, kwargs = client.put_object.call_args + assert kwargs["content_type"] == "text/html" + + +class TestUploadHtmlArtifact: + def test_stores_in_web_scrape_path(self): + client = MagicMock() + ts = datetime(2026, 6, 15, 0, 0, 0, tzinfo=timezone.utc) + ref = upload_html_artifact( + client, ticker="NVDA", document_id="page-1", + html_bytes=b"test", timestamp=ts, + ) + assert "stonks-raw-news" in ref + assert "web_scrape/NVDA/2026/06/15/page-1/raw.html" in ref + + +class TestDownloadArtifact: + def test_reads_and_returns_bytes(self): + client = MagicMock() + mock_response = MagicMock() + mock_response.read.return_value = b"file contents" + client.get_object.return_value = mock_response + + data = download_artifact(client, "stonks-raw-news", "path/to/obj.json") + assert data == b"file contents" + client.get_object.assert_called_once_with("stonks-raw-news", "path/to/obj.json") + mock_response.close.assert_called_once() + mock_response.release_conn.assert_called_once() + + +class TestUploadNormalizedText: + def test_stores_in_normalized_bucket(self): + client = MagicMock() + ts = datetime(2026, 4, 11, 0, 0, 0, tzinfo=timezone.utc) + ref = upload_normalized_text( + client, ticker="AAPL", document_id="doc-1", + text_bytes=b"Normalized article text here.", + timestamp=ts, + ) + assert "stonks-normalized" in ref + assert "parsed/AAPL/2026/04/11/doc-1/normalized.txt" in ref + _, kwargs = client.put_object.call_args + assert kwargs["content_type"] == "text/plain" + + def 
test_path_uses_current_time_when_no_timestamp(self): + client = MagicMock() + ref = upload_normalized_text( + client, ticker="MSFT", document_id="doc-2", + text_bytes=b"Some text.", + ) + assert "stonks-normalized" in ref + assert "normalized.txt" in ref + + +class TestUploadParserOutput: + def test_stores_json_in_normalized_bucket(self): + client = MagicMock() + ts = datetime(2026, 4, 11, 0, 0, 0, tzinfo=timezone.utc) + ref = upload_parser_output( + client, ticker="AAPL", document_id="doc-1", + output_bytes=b'{"quality_score": 0.8}', + timestamp=ts, + ) + assert "stonks-normalized" in ref + assert "parsed/AAPL/2026/04/11/doc-1/parser_output.json" in ref + _, kwargs = client.put_object.call_args + assert kwargs["content_type"] == "application/json" diff --git a/tests/test_suppression.py b/tests/test_suppression.py new file mode 100644 index 0000000..2442374 --- /dev/null +++ b/tests/test_suppression.py @@ -0,0 +1,190 @@ +"""Tests for recommendation suppression logic (data quality checks). 
+ +Requirements: 7.4 +""" +from datetime import datetime, timedelta, timezone + +from services.recommendation.suppression import ( + DataQualityContext, + SuppressionConfig, + SuppressionReason, + build_quality_context_from_summary, + evaluate_suppression, +) +from services.shared.schemas import TrendDirection, TrendSummary, TrendWindow + +NOW = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) + + +def _make_summary(**overrides) -> TrendSummary: + defaults = dict( + entity_type="company", + entity_id="AAPL", + window=TrendWindow.SEVEN_DAY, + trend_direction=TrendDirection.BULLISH, + trend_strength=0.5, + confidence=0.65, + top_supporting_evidence=["doc1", "doc2", "doc3"], + top_opposing_evidence=[], + dominant_catalysts=["earnings"], + material_risks=["regulatory scrutiny"], + contradiction_score=0.1, + generated_at=NOW, + ) + defaults.update(overrides) + return TrendSummary(**defaults) + + +def _make_quality_ctx(**overrides) -> DataQualityContext: + defaults = dict( + total_documents=5, + valid_documents=4, + failed_documents=1, + avg_extraction_confidence=0.7, + newest_evidence_at=NOW - timedelta(hours=6), + source_types={"news_api", "filings_api"}, + ) + defaults.update(overrides) + return DataQualityContext(**defaults) + + +# --------------------------------------------------------------------------- +# No suppression for good quality data +# --------------------------------------------------------------------------- + + +def test_no_suppression_good_quality(): + summary = _make_summary() + ctx = _make_quality_ctx() + result = evaluate_suppression(summary, ctx, reference_time=NOW) + assert result.suppressed is False + assert result.reasons == [] + assert result.data_quality_score > 0.3 + + +# --------------------------------------------------------------------------- +# Suppression triggers +# --------------------------------------------------------------------------- + + +def test_suppressed_low_extraction_confidence(): + summary = _make_summary() + ctx = 
_make_quality_ctx(avg_extraction_confidence=0.2) + result = evaluate_suppression(summary, ctx, reference_time=NOW) + assert result.suppressed is True + assert SuppressionReason.LOW_DATA_CONFIDENCE in result.reasons + + +def test_suppressed_stale_evidence(): + summary = _make_summary() + ctx = _make_quality_ctx(newest_evidence_at=NOW - timedelta(days=10)) + result = evaluate_suppression(summary, ctx, reference_time=NOW) + assert result.suppressed is True + assert SuppressionReason.STALE_EVIDENCE in result.reasons + + +def test_suppressed_high_failure_rate(): + summary = _make_summary() + ctx = _make_quality_ctx(total_documents=10, failed_documents=6, valid_documents=4) + result = evaluate_suppression(summary, ctx, reference_time=NOW) + assert result.suppressed is True + assert SuppressionReason.HIGH_EXTRACTION_FAILURE_RATE in result.reasons + + +def test_suppressed_insufficient_valid_documents(): + summary = _make_summary( + top_supporting_evidence=["doc1"], + top_opposing_evidence=[], + ) + ctx = _make_quality_ctx(total_documents=1, valid_documents=1, failed_documents=0) + result = evaluate_suppression(summary, ctx, reference_time=NOW) + assert result.suppressed is True + assert SuppressionReason.INSUFFICIENT_VALID_DOCUMENTS in result.reasons + + +def test_suppressed_low_source_diversity(): + """When min_source_types > available source types, suppression fires.""" + summary = _make_summary() + ctx = _make_quality_ctx(source_types=set()) + config = SuppressionConfig(min_source_types=2) + result = evaluate_suppression(summary, ctx, config=config, reference_time=NOW) + assert result.suppressed is True + assert SuppressionReason.LOW_SOURCE_DIVERSITY in result.reasons + + +# --------------------------------------------------------------------------- +# Fallback to summary-based context +# --------------------------------------------------------------------------- + + +def test_fallback_context_from_summary(): + summary = _make_summary(confidence=0.7) + ctx = 
build_quality_context_from_summary(summary) + assert ctx.total_documents == 3 # 3 supporting + 0 opposing + assert ctx.valid_documents == 3 + assert ctx.avg_extraction_confidence == 0.7 + + +def test_no_suppression_with_summary_fallback(): + """When no quality context is provided, summary-based fallback is used.""" + summary = _make_summary(confidence=0.7) + # Default config has min_source_types=1, but fallback has empty source_types. + # With min_source_types=1 and empty source_types, LOW_SOURCE_DIVERSITY fires + # only when total_documents > 0. But default min_source_types is 1 and + # len(set()) = 0 < 1, so it would fire. Let's use a config that relaxes this. + config = SuppressionConfig(min_source_types=0) + result = evaluate_suppression(summary, config=config, reference_time=NOW) + assert result.suppressed is False + + +# --------------------------------------------------------------------------- +# Data quality score +# --------------------------------------------------------------------------- + + +def test_quality_score_high_for_good_data(): + summary = _make_summary() + ctx = _make_quality_ctx( + avg_extraction_confidence=0.85, + newest_evidence_at=NOW - timedelta(hours=1), + total_documents=10, + valid_documents=10, + failed_documents=0, + ) + result = evaluate_suppression(summary, ctx, reference_time=NOW) + assert result.data_quality_score > 0.7 + + +def test_quality_score_low_for_bad_data(): + summary = _make_summary() + ctx = _make_quality_ctx( + avg_extraction_confidence=0.1, + newest_evidence_at=NOW - timedelta(days=14), + total_documents=3, + valid_documents=1, + failed_documents=2, + ) + result = evaluate_suppression(summary, ctx, reference_time=NOW) + assert result.data_quality_score < 0.3 + + +# --------------------------------------------------------------------------- +# Custom config +# --------------------------------------------------------------------------- + + +def test_custom_config_stricter_thresholds(): + summary = _make_summary() + 
ctx = _make_quality_ctx(avg_extraction_confidence=0.5) + strict = SuppressionConfig(min_avg_extraction_confidence=0.6) + result = evaluate_suppression(summary, ctx, config=strict, reference_time=NOW) + assert result.suppressed is True + assert SuppressionReason.LOW_DATA_CONFIDENCE in result.reasons + + +def test_custom_config_relaxed_thresholds(): + summary = _make_summary() + ctx = _make_quality_ctx(avg_extraction_confidence=0.3) + relaxed = SuppressionConfig(min_avg_extraction_confidence=0.2) + result = evaluate_suppression(summary, ctx, config=relaxed, reference_time=NOW) + assert SuppressionReason.LOW_DATA_CONFIDENCE not in result.reasons diff --git a/tests/test_thesis_llm.py b/tests/test_thesis_llm.py new file mode 100644 index 0000000..084f468 --- /dev/null +++ b/tests/test_thesis_llm.py @@ -0,0 +1,113 @@ +"""Tests for the optional LLM thesis rewriting layer. + +Tests prompt construction and the rewrite function's fallback behavior. +""" +from __future__ import annotations + +import pytest + +from services.recommendation.thesis_llm import ( + THESIS_SYSTEM_PROMPT, + build_thesis_rewrite_prompt, + rewrite_thesis_with_llm, +) +from services.shared.config import OllamaConfig +from services.shared.schemas import ( + TrendDirection, + TrendSummary, + TrendWindow, +) + + +def _make_summary( + ticker: str = "AAPL", + direction: TrendDirection = TrendDirection.BULLISH, + strength: float = 0.5, + confidence: float = 0.65, + contradiction: float = 0.1, + catalysts: list[str] | None = None, + risks: list[str] | None = None, +) -> TrendSummary: + return TrendSummary( + entity_type="company", + entity_id=ticker, + window=TrendWindow.SEVEN_DAY, + trend_direction=direction, + trend_strength=strength, + confidence=confidence, + top_supporting_evidence=["doc1", "doc2"], + top_opposing_evidence=[], + dominant_catalysts=catalysts or ["earnings"], + material_risks=risks or ["regulatory scrutiny"], + contradiction_score=contradiction, + ) + + +DETERMINISTIC_THESIS = ( + "AAPL 
shows a bullish trend over the 7d window with strength 0.50 " + "and confidence 0.65. Dominant catalysts: earnings. " + "Key risks: regulatory scrutiny. " + "Based on 2 supporting and 0 opposing evidence documents. " + "Recommendation: BUY (paper eligible)." +) + + +# --------------------------------------------------------------------------- +# Prompt construction +# --------------------------------------------------------------------------- + + +def test_prompt_contains_deterministic_thesis(): + summary = _make_summary() + prompts = build_thesis_rewrite_prompt(DETERMINISTIC_THESIS, summary) + assert DETERMINISTIC_THESIS in prompts["user"] + + +def test_prompt_system_is_thesis_system_prompt(): + summary = _make_summary() + prompts = build_thesis_rewrite_prompt(DETERMINISTIC_THESIS, summary) + assert prompts["system"] == THESIS_SYSTEM_PROMPT + + +def test_prompt_includes_ticker_context(): + summary = _make_summary(ticker="MSFT") + prompts = build_thesis_rewrite_prompt(DETERMINISTIC_THESIS, summary) + assert "MSFT" in prompts["user"] + + +def test_prompt_includes_catalysts(): + summary = _make_summary(catalysts=["product", "m_and_a"]) + prompts = build_thesis_rewrite_prompt(DETERMINISTIC_THESIS, summary) + assert "product" in prompts["user"] + + +def test_prompt_includes_risks(): + summary = _make_summary(risks=["supply chain disruption"]) + prompts = build_thesis_rewrite_prompt(DETERMINISTIC_THESIS, summary) + assert "supply chain disruption" in prompts["user"] + + +def test_prompt_includes_trend_metrics(): + summary = _make_summary(strength=0.72, confidence=0.88, contradiction=0.15) + prompts = build_thesis_rewrite_prompt(DETERMINISTIC_THESIS, summary) + assert "0.72" in prompts["user"] + assert "0.88" in prompts["user"] + assert "0.15" in prompts["user"] + + +# --------------------------------------------------------------------------- +# Fallback behavior — LLM failure returns deterministic thesis +# 
--------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_rewrite_falls_back_on_connection_error(): + """When Ollama is unreachable, the deterministic thesis is returned.""" + summary = _make_summary() + config = OllamaConfig(base_url="http://localhost:99999", timeout=2) + result = await rewrite_thesis_with_llm( + deterministic_thesis=DETERMINISTIC_THESIS, + summary=summary, + config=config, + ) + assert result == DETERMINISTIC_THESIS diff --git a/tests/test_web_scrape_adapter.py b/tests/test_web_scrape_adapter.py new file mode 100644 index 0000000..93f1fca --- /dev/null +++ b/tests/test_web_scrape_adapter.py @@ -0,0 +1,147 @@ +"""Tests for the web scrape adapter. + +Validates URL normalization, HTML metadata extraction, body text extraction, +and adapter result construction. +""" +import pytest + +from services.adapters.web_scrape_adapter import ( + WebScrapeAdapter, + extract_body_text, + extract_metadata_from_html, +) +from services.shared.content import normalize_url + + +SAMPLE_HTML = """ + + + Apple Q2 Earnings Beat Expectations + + + + + + + + + +
+

Apple Q2 Earnings Beat Expectations

+

Apple Inc. reported quarterly revenue of $95 billion, exceeding analyst estimates.

+

The company saw strong growth in its services division and iPhone sales.

+
+
Copyright 2026 TechNews
+ +""" + +MINIMAL_HTML = """

Short content.

""" + + +class TestNormalizeUrl: + def test_basic_normalization(self): + assert normalize_url("HTTPS://Example.COM/path") == "https://example.com/path" + + def test_strips_trailing_slash(self): + assert normalize_url("https://example.com/path/") == "https://example.com/path" + + def test_strips_fragment(self): + result = normalize_url("https://example.com/path#section") + assert "#" not in result + + def test_preserves_query(self): + result = normalize_url("https://example.com/path?q=test") + assert result == "https://example.com/path?q=test" + + def test_preserves_non_standard_port(self): + result = normalize_url("https://example.com:8443/path") + assert ":8443" in result + + def test_root_path(self): + result = normalize_url("https://example.com") + assert result == "https://example.com/" + + +class TestExtractMetadataFromHtml: + def test_extracts_title(self): + meta = extract_metadata_from_html(SAMPLE_HTML, "https://technews.example.com/article") + assert meta["title"] == "Apple Q2 Earnings Beat Expectations" + + def test_extracts_author(self): + meta = extract_metadata_from_html(SAMPLE_HTML, "https://technews.example.com/article") + assert meta["author"] == "Jane Reporter" + + def test_extracts_publisher(self): + meta = extract_metadata_from_html(SAMPLE_HTML, "https://technews.example.com/article") + assert meta["publisher"] == "TechNews" + + def test_extracts_published_at(self): + meta = extract_metadata_from_html(SAMPLE_HTML, "https://technews.example.com/article") + assert meta["published_at"] == "2026-04-10T14:00:00Z" + + def test_extracts_canonical_url(self): + meta = extract_metadata_from_html(SAMPLE_HTML, "https://technews.example.com/article") + assert meta["canonical_url"] == "https://technews.example.com/apple-q2-earnings" + + def test_extracts_language(self): + meta = extract_metadata_from_html(SAMPLE_HTML, "https://technews.example.com/article") + assert meta["language"] == "en" + + def test_fallback_publisher_from_hostname(self): + meta = 
extract_metadata_from_html(MINIMAL_HTML, "https://example.com/page") + assert meta["publisher"] == "example.com" + + def test_fallback_title_empty(self): + meta = extract_metadata_from_html(MINIMAL_HTML, "https://example.com/page") + assert meta["title"] == "" + + +class TestExtractBodyText: + def test_extracts_article_content(self): + text = extract_body_text(SAMPLE_HTML) + assert "Apple Inc. reported quarterly revenue" in text + assert "strong growth" in text + + def test_strips_nav_and_footer(self): + text = extract_body_text(SAMPLE_HTML) + assert "Navigation links here" not in text + assert "Copyright 2026" not in text + + def test_strips_script_and_style(self): + html = "

Content

" + text = extract_body_text(html) + assert "alert" not in text + assert "Content" in text + + def test_minimal_html(self): + text = extract_body_text(MINIMAL_HTML) + assert "Short content." in text + + +class TestWebScrapeAdapterSourceType: + def test_source_type(self): + adapter = WebScrapeAdapter() + assert adapter.source_type() == "web_scrape" + + def test_bucket_name(self): + adapter = WebScrapeAdapter() + assert adapter.bucket_name() == "stonks-raw-news" + + +class TestWebScrapeAdapterErrorResult: + def test_error_on_no_urls(self): + adapter = WebScrapeAdapter() + result = adapter._error_result("AAPL", "No URLs configured", 0) + assert not result.ok + assert result.error == "No URLs configured" + assert result.source_type == "web_scrape" + assert result.ticker == "AAPL" + + +@pytest.mark.asyncio +async def test_fetch_no_urls_configured(): + adapter = WebScrapeAdapter() + result = await adapter.fetch("AAPL", {}) + assert not result.ok + assert result.error is not None + assert "No URLs configured" in result.error