292 lines
11 KiB
Python
292 lines
11 KiB
Python
"""Extraction worker - sends documents to Ollama for structured intelligence extraction.

Orchestrates the full extraction pipeline for a single document:
1. Calls OllamaClient to get structured extraction
2. Uploads prompts, raw outputs, and validation reports to MinIO
3. Persists the final intelligence object and per-company impact records to PostgreSQL
4. Updates document status

Requirements: 5.1, 5.2, 5.3, 5.4, 5.5, 9.1, 9.2
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timezone
|
|
|
|
import asyncpg
|
|
from minio import Minio
|
|
|
|
from services.extractor.client import ExtractionResponse
|
|
from services.extractor.metrics import collect_metrics, persist_metrics
|
|
from services.shared.metadata import (
|
|
persist_document_impact,
|
|
persist_document_intelligence,
|
|
update_document_status,
|
|
)
|
|
from services.shared.storage import (
|
|
upload_extraction_intelligence,
|
|
upload_extraction_prompt,
|
|
upload_extraction_raw_output,
|
|
upload_extraction_validation,
|
|
)
|
|
from services.shared.logging import Span
|
|
from services.shared.metrics import (
|
|
EXTRACTION_ATTEMPTS,
|
|
EXTRACTION_CONFIDENCE,
|
|
EXTRACTION_DURATION,
|
|
EXTRACTION_JOBS_TOTAL,
|
|
EXTRACTION_RETRIES,
|
|
EXTRACTION_TOKEN_ESTIMATE,
|
|
EXTRACTION_VALIDATION_ERRORS,
|
|
)
|
|
|
|
# Module-level logger; registered under the fixed name "extractor_worker"
# rather than __name__, so the logger name does not follow the module path.
logger = logging.getLogger("extractor_worker")
|
|
|
|
|
|
@dataclass
class ExtractionPersistResult:
    """Outcome of persisting one extraction to object storage and the database.

    A freshly constructed instance describes "nothing persisted yet": every
    reference is ``None`` and ``success`` is ``False``. ``persist_extraction``
    fills the fields in as each artifact is written.
    """

    # --- Database identifiers -------------------------------------------
    # Value returned by persist_document_intelligence for this document.
    intelligence_id: str | None = None

    # --- Object-storage references (return values of the upload helpers) -
    # Reference returned when the prompt metadata was uploaded.
    prompt_ref: str | None = None
    # Reference returned when the per-attempt raw outputs were uploaded.
    raw_output_ref: str | None = None
    # Reference returned when the validation report was uploaded.
    validation_ref: str | None = None
    # Reference returned when the final intelligence object was uploaded;
    # stays None when the extraction did not succeed.
    intelligence_ref: str | None = None

    # --- Per-company impact rows ----------------------------------------
    # None until the success path runs; then one id per persisted impact row
    # (possibly an empty list).
    impact_ids: list[str] | None = None

    # --- Performance metrics --------------------------------------------
    # Value returned by persist_metrics; None if that step failed.
    metrics_id: str | None = None

    # True only after the success path has completed.
    success: bool = False
|
|
|
|
|
|
def _serialize_attempt(index, attempt) -> dict[str, object]:
    """Build the JSON-ready record describing a single extraction attempt."""
    record: dict[str, object] = {
        "attempt_index": index,
        "raw_output": attempt.raw_output,
        "error": attempt.error,
        "duration_ms": attempt.duration_ms,
        "model": attempt.model,
        "retryable": attempt.retryable,
    }
    # Validation details are only attached when the attempt carries them.
    if attempt.validation:
        record["validation"] = {
            "valid": attempt.validation.valid,
            "errors": attempt.validation.errors,
            "warnings": attempt.validation.warnings,
        }
    return record


def _final_validation_summary(final_attempt) -> dict[str, object] | None:
    """Summarise the last attempt's validation, or return None without attempts."""
    if not final_attempt:
        return None
    validation = final_attempt.validation
    if not validation:
        # An attempt without a validation report is reported as invalid.
        return {"valid": False, "errors": [], "warnings": []}
    return {
        "valid": validation.valid,
        "errors": validation.errors,
        "warnings": validation.warnings,
    }


def _record_prometheus_metrics(
    extraction_response,
    *,
    succeeded: bool,
    retry_count: int,
    final_attempt,
    document_text_length: int,
) -> None:
    """Update the Prometheus counters/histograms for one extraction job."""
    EXTRACTION_ATTEMPTS.inc(len(extraction_response.attempts))
    EXTRACTION_DURATION.observe(extraction_response.total_duration_ms / 1000.0)
    if retry_count > 0:
        EXTRACTION_RETRIES.inc(retry_count)

    if succeeded:
        EXTRACTION_JOBS_TOTAL.labels(status="success").inc()
        EXTRACTION_CONFIDENCE.observe(extraction_response.result.confidence)
    else:
        EXTRACTION_JOBS_TOTAL.labels(status="failed").inc()

    # Validation errors are counted from the final attempt only.
    if final_attempt and final_attempt.validation and final_attempt.validation.errors:
        EXTRACTION_VALIDATION_ERRORS.inc(len(final_attempt.validation.errors))

    # Token estimates use a fixed 4-characters-per-token approximation.
    if document_text_length > 0:
        EXTRACTION_TOKEN_ESTIMATE.labels(direction="input").inc(document_text_length // 4)
    if final_attempt and final_attempt.raw_output:
        EXTRACTION_TOKEN_ESTIMATE.labels(direction="output").inc(
            len(final_attempt.raw_output) // 4
        )


async def persist_extraction(
    *,
    pool: asyncpg.Pool,
    minio_client: Minio,
    document_id: str,
    ticker: str,
    extraction_response: ExtractionResponse,
    company_id_map: dict[str, str] | None = None,
    source_credibility: float = 0.5,
    timestamp: datetime | None = None,
    document_text_length: int = 0,
) -> ExtractionPersistResult:
    """Persist all extraction artifacts to MinIO and PostgreSQL.

    Uploads the prompt metadata, the raw output of every attempt, and a
    validation report. On success it additionally uploads the final
    intelligence object, persists the intelligence record plus one impact
    record per company that has an entry in ``company_id_map``, and sets the
    document status to ``"extracted"``. Otherwise it persists a placeholder
    intelligence record with ``validation_status="failed"`` and sets the
    status to ``"extraction_failed"``. Model performance metrics are then
    persisted on a best-effort basis and Prometheus metrics are updated.

    The job counts as successful only when ``extraction_response.success`` is
    true *and* a result is present; that single decision drives both the
    persisted status and the Prometheus ``status`` label. ``retry_count`` is
    ``max(0, attempts - 1)`` on both paths, matching the Prometheus retry
    counter.

    Args:
        pool: PostgreSQL connection pool.
        minio_client: MinIO client.
        document_id: UUID of the source document.
        ticker: Primary ticker for path construction.
        extraction_response: Full response from OllamaClient.extract().
        company_id_map: Optional mapping of ticker -> company UUID for impact
            records. Companies without an entry are skipped with a warning.
        source_credibility: Credibility score to attach to the intelligence record.
        timestamp: Override timestamp for MinIO paths (defaults to UTC now).
        document_text_length: Length of the input document text for token estimation.

    Returns:
        ExtractionPersistResult with references to all persisted artifacts.
        ``success`` is True only when the success path completed.

    Raises:
        Any exception raised by the upload or persistence helpers propagates,
        except failures while collecting/persisting performance metrics,
        which are logged and swallowed.
    """
    ts = timestamp or datetime.now(timezone.utc)
    result = ExtractionPersistResult()
    company_id_map = company_id_map or {}

    attempts = extraction_response.attempts
    final_attempt = attempts[-1] if attempts else None
    extraction = extraction_response.result
    # One success decision shared by persistence and Prometheus labelling.
    succeeded = bool(extraction_response.success and extraction)
    # Retries are the attempts after the first; never negative.
    retry_count = max(0, len(attempts) - 1)

    # NOTE(review): the upload_* helpers are called without ``await``; if they
    # perform blocking network I/O they will stall the event loop -- confirm.

    # 1. Upload prompt metadata to MinIO
    prompt_payload = json.dumps(
        {
            "prompt_metadata": extraction_response.prompt_metadata,
            "model": extraction_response.model,
        },
        indent=2,
    ).encode()
    result.prompt_ref = upload_extraction_prompt(
        minio_client, ticker, document_id, prompt_payload, timestamp=ts,
    )

    # 2. Upload raw outputs for each attempt
    raw_output_payload = json.dumps(
        {
            "document_id": document_id,
            "attempts": [
                _serialize_attempt(idx, attempt) for idx, attempt in enumerate(attempts)
            ],
            "total_duration_ms": extraction_response.total_duration_ms,
            "success": extraction_response.success,
        },
        indent=2,
    ).encode()
    result.raw_output_ref = upload_extraction_raw_output(
        minio_client, ticker, document_id, raw_output_payload, timestamp=ts,
    )

    # 3. Upload validation report
    validation_payload = json.dumps(
        {
            "document_id": document_id,
            "success": extraction_response.success,
            "attempt_count": len(attempts),
            "final_validation": _final_validation_summary(final_attempt),
        },
        indent=2,
    ).encode()
    result.validation_ref = upload_extraction_validation(
        minio_client, ticker, document_id, validation_payload, timestamp=ts,
    )

    # Arguments shared by the success and failure intelligence records.
    record_common = {
        "document_id": document_id,
        "source_credibility": source_credibility,
        "model_provider": "ollama",
        "model_name": extraction_response.model,
        "prompt_version": extraction_response.prompt_metadata.get("prompt_version", ""),
        "schema_version": extraction_response.prompt_metadata.get("schema_version", ""),
        "raw_output_ref": result.raw_output_ref,
        "prompt_ref": result.prompt_ref,
        "retry_count": retry_count,
    }

    # 4. Determine validation status and persist intelligence
    if succeeded:
        # Upload final intelligence object to MinIO
        intelligence_payload = json.dumps(
            extraction.model_dump(mode="json"), indent=2,
        ).encode()
        result.intelligence_ref = upload_extraction_intelligence(
            minio_client, ticker, document_id, intelligence_payload, timestamp=ts,
        )

        # Persist to PostgreSQL
        intel_id = await persist_document_intelligence(
            pool,
            summary=extraction.summary,
            macro_themes=extraction.macro_themes,
            novelty_score=extraction.novelty_score,
            extraction_warnings=extraction.extraction_warnings,
            confidence=extraction.confidence,
            validation_status="valid",
            validation_errors=[],
            **record_common,
        )
        result.intelligence_id = intel_id

        # Persist per-company impact records
        result.impact_ids = []
        for company in extraction.companies:
            cid = company_id_map.get(company.ticker)
            if not cid:
                logger.warning(
                    "No company_id for ticker %s in doc %s, skipping impact record",
                    company.ticker, document_id,
                )
                continue
            impact_id = await persist_document_impact(
                pool,
                intelligence_id=intel_id,
                company_id=cid,
                ticker=company.ticker,
                relevance=company.relevance,
                sentiment=company.sentiment,
                impact_score=company.impact_score,
                impact_horizon=company.impact_horizon,
                catalyst_type=company.catalyst_type,
                key_facts=company.key_facts,
                risks=company.risks,
                evidence_spans=company.evidence_spans,
            )
            result.impact_ids.append(impact_id)

        await update_document_status(pool, document_id=document_id, status="extracted")
        result.success = True
        logger.info(
            "Extraction persisted for doc %s: intel=%s, impacts=%d",
            document_id, intel_id, len(result.impact_ids),
        )
    else:
        # Failed extraction — still persist the attempt data
        all_errors: list[str] = [attempt.error for attempt in attempts if attempt.error]

        intel_id = await persist_document_intelligence(
            pool,
            summary="",
            macro_themes=[],
            novelty_score=0.0,
            extraction_warnings=["extraction_failed"],
            confidence=0.0,
            validation_status="failed",
            validation_errors=all_errors,
            **record_common,
        )
        result.intelligence_id = intel_id

        await update_document_status(pool, document_id=document_id, status="extraction_failed")
        logger.warning(
            "Extraction failed for doc %s after %d attempts: %s",
            document_id, len(attempts), "; ".join(all_errors),
        )

    # Collect and persist model performance metrics. Deliberately best-effort:
    # a metrics failure must not undo or mask the persistence done above.
    try:
        metrics = collect_metrics(
            extraction_response,
            document_id=document_id,
            ticker=ticker,
            document_text_length=document_text_length,
        )
        metrics.recorded_at = ts
        result.metrics_id = await persist_metrics(pool, metrics)
    except Exception:
        logger.exception("Failed to persist extraction metrics for doc %s", document_id)

    # Prometheus metrics
    _record_prometheus_metrics(
        extraction_response,
        succeeded=succeeded,
        retry_count=retry_count,
        final_attempt=final_attempt,
        document_text_length=document_text_length,
    )

    return result
|