Files
stonks-oracle/tests/test_pbt_competitive.py
Celes Renata c85c0068a2 fix: clean up utcnow deprecation warnings, fix 12 failing tests, add CI/CD pipeline manifests
- Replace all datetime.utcnow() with datetime.now(tz=timezone.utc) across 8 files
- Fix 12 failing tests to match current implementation behavior
- Fix pytest_plugins in non-top-level conftest (moved to root conftest.py)
- Auto-fix 189 lint issues (import sorting, unused imports)
- Add CI/CD pipeline infrastructure (ARC, ArgoCD, Kargo manifests)
- Add values-beta.yaml and values-paper.yaml for staged deployments
- Update GitHub Actions workflow to use self-hosted-gremlin runners
- Add integration-test job to CI pipeline

Result: 1596 passed, 0 failed, 0 warnings
2026-04-18 03:59:28 +00:00

817 lines
30 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Property-based tests for the competitive intelligence layer.
Feature: competitive-historical-patterns
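Uses Hypothesis to validate correctness properties of the competitor registry
endpoints (persistence round-trip, query completeness/ordering, soft-delete;
Properties 1-3) and of the competitor auto-inference algorithm (valid
candidates, co-mention ranking, idempotence; Properties 4-6).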
"""
from __future__ import annotations

import copy
import uuid
from datetime import datetime, timezone
from typing import Any

from hypothesis import assume, given, settings
from hypothesis import strategies as st

from services.symbol_registry.competitors import (
    VALID_RELATIONSHIP_TYPES,
    VALID_SOURCES,
    CompetitorRelationship,
    CompetitorRelationshipCreate,
)
# ---------------------------------------------------------------------------
# Hypothesis strategies
# ---------------------------------------------------------------------------
# Give ``st.sampled_from`` a deterministic order. If these registry constants
# are sets/frozensets (TODO confirm), ``list()`` follows string-hash order,
# which changes between interpreter runs (PYTHONHASHSEED). A saved failing
# Hypothesis example could then replay as a different value. Sorting with
# ``key=str`` is stable across runs and also works for non-orderable members.
_RELATIONSHIP_TYPES = sorted(VALID_RELATIONSHIP_TYPES, key=str)
_SOURCES = sorted(VALID_SOURCES, key=str)
def _company_id_strategy() -> st.SearchStrategy[str]:
    """Return a strategy producing company IDs as canonical UUID strings."""
    uuid_values = st.uuids()
    return uuid_values.map(str)
def _competitor_relationship_create_strategy() -> st.SearchStrategy[dict[str, Any]]:
    """Return a strategy of keyword dicts for ``CompetitorRelationshipCreate``.

    Each field is drawn from its valid domain:

    - a UUID string for the counterpart company;
    - a registered relationship type and source;
    - a non-NaN strength in the closed interval [0, 1];
    - an arbitrary bidirectional flag.
    """
    field_strategies: dict[str, st.SearchStrategy[Any]] = {
        "company_b_id": _company_id_strategy(),
        "relationship_type": st.sampled_from(_RELATIONSHIP_TYPES),
        "strength": st.floats(min_value=0.0, max_value=1.0, allow_nan=False),
        "bidirectional": st.booleans(),
        "source": st.sampled_from(_SOURCES),
    }
    return st.fixed_dictionaries(field_strategies)
def _full_relationship_strategy() -> st.SearchStrategy[dict[str, Any]]:
    """Return a strategy of complete, active relationship rows shaped like DB output.

    Identifier columns are UUID strings, and the enum-like columns are drawn
    from the registry constants. ``created_at`` and ``updated_at`` are fixed
    values captured when this function runs, so every generated row shares
    them.
    """
    row_strategies: dict[str, st.SearchStrategy[Any]] = {}
    for column in ("id", "company_a_id", "company_b_id"):
        row_strategies[column] = _company_id_strategy()
    row_strategies["relationship_type"] = st.sampled_from(_RELATIONSHIP_TYPES)
    row_strategies["strength"] = st.floats(min_value=0.0, max_value=1.0, allow_nan=False)
    row_strategies["bidirectional"] = st.booleans()
    row_strategies["source"] = st.sampled_from(_SOURCES)
    row_strategies["active"] = st.just(True)
    row_strategies["created_at"] = st.just(datetime.now(tz=timezone.utc))
    row_strategies["updated_at"] = st.just(datetime.now(tz=timezone.utc))
    return st.fixed_dictionaries(row_strategies)
# ---------------------------------------------------------------------------
# Helper: simulate DB round-trip through Pydantic models
# ---------------------------------------------------------------------------
def _simulate_persist_and_read(
    company_a_id: str,
    create_data: dict[str, Any],
) -> tuple[dict[str, Any], CompetitorRelationship]:
    """Mimic writing a relationship to the DB and reading it back.

    Steps:

    1. Validate *create_data* with ``CompetitorRelationshipCreate``.
    2. Assemble the row an ``INSERT ... RETURNING`` would hand back: a fresh
       UUID id, ``active=True``, and identical created/updated timestamps.
    3. Parse that row with the ``CompetitorRelationship`` response model,
       mirroring the endpoint's own Pydantic round-trip.

    Returns:
        A ``(db_row, parsed_response)`` pair.
    """
    validated = CompetitorRelationshipCreate(**create_data)

    # Stand-in for the row produced by INSERT ... RETURNING.
    timestamp = datetime.now(tz=timezone.utc)
    db_row: dict[str, Any] = dict(
        id=str(uuid.uuid4()),
        company_a_id=company_a_id,
        company_b_id=validated.company_b_id,
        relationship_type=validated.relationship_type,
        strength=validated.strength,
        bidirectional=validated.bidirectional,
        source=validated.source,
        active=True,
        created_at=timestamp,
        updated_at=timestamp,
    )

    # The endpoint parses the returned row through the response model too.
    return db_row, CompetitorRelationship(**db_row)
# ---------------------------------------------------------------------------
# Property 1: Competitor relationship persistence round-trip
# ---------------------------------------------------------------------------
class TestProperty1CompetitorRelationshipPersistenceRoundTrip:
    """Feature: competitive-historical-patterns, Property 1: Competitor relationship persistence round-trip

    For any valid CompetitorRelationship object with valid company IDs,
    relationship_type, strength in [0, 1], bidirectional flag, and source,
    persisting it to PostgreSQL and reading it back SHALL produce an
    equivalent object with all fields preserved.

    **Validates: Requirements 1.1, 7.1**
    """

    @given(
        company_a_id=_company_id_strategy(),
        create_data=_competitor_relationship_create_strategy(),
    )
    @settings(max_examples=100)
    def test_round_trip_preserves_all_fields(
        self,
        company_a_id: str,
        create_data: dict[str, Any],
    ):
        """**Validates: Requirements 1.1, 7.1**

        Persisting a CompetitorRelationshipCreate and reading it back through
        the response model must preserve every field value.
        """
        # The DB rejects self-relationships (company_a != company_b). Use
        # ``assume`` so Hypothesis discards the draw rather than counting an
        # early ``return`` as a passing example.
        assume(company_a_id != create_data["company_b_id"])

        db_row, response = _simulate_persist_and_read(company_a_id, create_data)

        # All fields from the create payload are preserved
        assert response.company_a_id == company_a_id
        assert response.company_b_id == create_data["company_b_id"]
        assert response.relationship_type == create_data["relationship_type"]
        assert response.strength == create_data["strength"]
        assert response.bidirectional == create_data["bidirectional"]
        assert response.source == create_data["source"]

        # DB-generated fields are present and valid
        assert response.id is not None and len(response.id) > 0
        assert response.active is True
        assert response.created_at is not None
        assert response.updated_at is not None

        # Response matches the DB row exactly
        assert response.id == db_row["id"]
        assert response.created_at == db_row["created_at"]
        assert response.updated_at == db_row["updated_at"]

    @given(create_data=_competitor_relationship_create_strategy())
    @settings(max_examples=100)
    def test_create_model_validates_fields(self, create_data: dict[str, Any]):
        """**Validates: Requirements 1.1, 7.1**

        The CompetitorRelationshipCreate model must accept all valid
        relationship_type and source values, and strength in [0, 1].
        """
        model = CompetitorRelationshipCreate(**create_data)
        assert model.relationship_type in VALID_RELATIONSHIP_TYPES
        assert model.source in VALID_SOURCES
        assert 0.0 <= model.strength <= 1.0
        assert isinstance(model.bidirectional, bool)
        assert isinstance(model.company_b_id, str)
# ---------------------------------------------------------------------------
# Property 2: Competitor query completeness and ordering
# ---------------------------------------------------------------------------
def _build_relationship_row(
    company_a_id: str,
    company_b_id: str,
    strength: float,
    active: bool = True,
    **overrides: Any,
) -> dict[str, Any]:
    """Return a fake competitor-relationship DB row.

    Defaults: a fresh UUID id, ``relationship_type="direct_rival"``,
    ``source="manual"``, ``bidirectional=True``, and identical
    created/updated timestamps. Any keyword in *overrides* replaces (or adds)
    the matching column.
    """
    timestamp = datetime.now(tz=timezone.utc)
    defaults: dict[str, Any] = {
        "id": str(uuid.uuid4()),
        "company_a_id": company_a_id,
        "company_b_id": company_b_id,
        "relationship_type": "direct_rival",
        "strength": strength,
        "bidirectional": True,
        "source": "manual",
        "active": active,
        "created_at": timestamp,
        "updated_at": timestamp,
    }
    # Later keys win; existing columns keep their position, new ones append.
    return {**defaults, **overrides}
class TestProperty2CompetitorQueryCompletenessAndOrdering:
    """Feature: competitive-historical-patterns, Property 2: Competitor query completeness and ordering

    For any set of competitor relationships involving a company (as either
    company_a or company_b), querying competitors for that company SHALL
    return all active relationships containing that company, and the results
    SHALL be ordered by strength descending.

    **Validates: Requirements 1.2**
    """

    @given(
        target_company=_company_id_strategy(),
        strengths=st.lists(
            st.floats(min_value=0.0, max_value=1.0, allow_nan=False),
            min_size=1,
            max_size=15,
        ),
        as_company_a=st.lists(st.booleans(), min_size=1, max_size=15),
    )
    @settings(max_examples=100)
    def test_query_returns_all_active_relationships_sorted_by_strength(
        self,
        target_company: str,
        strengths: list[float],
        as_company_a: list[bool],
    ):
        """**Validates: Requirements 1.2**

        All active relationships for a company must be returned, ordered by
        strength descending, regardless of whether the company is company_a
        or company_b.
        """
        # The flag list is drawn independently of the strengths. Cycle it so
        # there is exactly one side-A/side-B flag per strength.
        repeats = len(strengths) // len(as_company_a) + 1
        flags = (as_company_a * repeats)[: len(strengths)]

        # Active relationships: the target sits on side A or B per its flag.
        active_rows: list[dict[str, Any]] = []
        for strength, is_a in zip(strengths, flags):
            other = str(uuid.uuid4())
            a_id, b_id = (target_company, other) if is_a else (other, target_company)
            active_rows.append(
                _build_relationship_row(a_id, b_id, strength, active=True)
            )

        # Inactive relationships involving the target must NOT appear.
        inactive_rows: list[dict[str, Any]] = [
            _build_relationship_row(target_company, str(uuid.uuid4()), 0.9, active=False)
            for _ in range(2)
        ]

        # Simulate the query: active rows involving target_company, sorted by
        # strength descending (matching the SQL ORDER BY).
        query_result = sorted(
            (
                r
                for r in active_rows + inactive_rows
                if target_company in (r["company_a_id"], r["company_b_id"])
                and r["active"] is True
            ),
            key=lambda r: r["strength"],
            reverse=True,
        )

        # Parse through response models
        results = [CompetitorRelationship(**r) for r in query_result]

        # 1. Completeness: exactly the active rows are returned. Comparing ids,
        #    not just the count, catches a dropped row replaced by another.
        assert len(results) == len(active_rows)
        assert {r.id for r in results} == {row["id"] for row in active_rows}

        # 2. No inactive relationships are included
        inactive_ids = {row["id"] for row in inactive_rows}
        for r in results:
            assert r.id not in inactive_ids

        # 3. Results are ordered by strength descending
        for prev, cur in zip(results, results[1:]):
            assert prev.strength >= cur.strength, (
                f"Ordering violated: strength {prev.strength} "
                f"should be >= {cur.strength}"
            )

        # 4. Every result involves the target company
        for r in results:
            assert target_company in (r.company_a_id, r.company_b_id)
# ---------------------------------------------------------------------------
# Property 3: Soft-delete preserves row
# ---------------------------------------------------------------------------
class TestProperty3SoftDeletePreservesRow:
    """Feature: competitive-historical-patterns, Property 3: Soft-delete preserves row

    For any active competitor relationship, deleting it SHALL set
    active = False while preserving the row in the database with all
    original field values intact.

    **Validates: Requirements 1.3**
    """

    # Columns a soft-delete must leave untouched. The simulated UPDATE writes
    # only ``active`` and ``updated_at``.
    _PRESERVED_FIELDS = (
        "id",
        "company_a_id",
        "company_b_id",
        "relationship_type",
        "strength",
        "bidirectional",
        "source",
        "created_at",
    )

    @staticmethod
    def _soft_delete(row: dict[str, Any]) -> None:
        """Apply the simulated soft-delete UPDATE to *row* in place."""
        row["active"] = False
        row["updated_at"] = datetime.now(tz=timezone.utc)

    @given(rel=_full_relationship_strategy())
    @settings(max_examples=100)
    def test_soft_delete_sets_active_false_preserves_fields(
        self,
        rel: dict[str, Any],
    ):
        """**Validates: Requirements 1.3**

        After soft-delete, the row must still exist with active=False and
        all original field values (id, company_a_id, company_b_id,
        relationship_type, strength, bidirectional, source, created_at)
        preserved.
        """
        # Snapshot the original state before deletion
        original = copy.deepcopy(rel)
        assert original["active"] is True

        self._soft_delete(rel)

        # The row still carries every column (nothing was dropped)
        assert rel.keys() == original.keys()

        # active is now False
        assert rel["active"] is False

        # All original fields are preserved (except active and updated_at)
        for field in self._PRESERVED_FIELDS:
            assert rel[field] == original[field], f"{field} changed by soft-delete"

        # updated_at is refreshed by the soft-delete (never moves backwards)
        assert rel["updated_at"] >= original["updated_at"]

    @given(rel=_full_relationship_strategy())
    @settings(max_examples=100)
    def test_soft_deleted_row_excluded_from_active_queries(
        self,
        rel: dict[str, Any],
    ):
        """**Validates: Requirements 1.3**

        After soft-delete, the relationship must not appear in queries
        filtered by active = TRUE, but the row data is still intact.
        """
        original = copy.deepcopy(rel)

        self._soft_delete(rel)

        # Simulate active-only query filter (WHERE active = TRUE)
        all_rows = [rel]
        active_results = [r for r in all_rows if r["active"] is True]

        # Soft-deleted row is excluded from active queries
        assert len(active_results) == 0

        # But the row still exists in the full table
        all_results = list(all_rows)
        assert len(all_results) == 1

        # And all original data is preserved, including created_at
        preserved = all_results[0]
        for field in self._PRESERVED_FIELDS:
            assert preserved[field] == original[field], f"{field} changed by soft-delete"
# ---------------------------------------------------------------------------
# Helpers for auto-inference property tests (Properties 4-6)
# ---------------------------------------------------------------------------
# Pure reimplementation of the inference strength formula from
# services/symbol_registry/competitor_inference.py so we can test the
# algorithm's properties without touching the DB.
def _compute_inference_strength(co_count: int, max_count: int) -> float:
    """Return the inferred relationship strength for one candidate.

    Implements ``0.3 * sector_match + 0.7 * (co_count / max_count)``.

    ``sector_match`` is fixed at 1.0 because candidates are pre-filtered to
    the same sector AND industry. A non-positive *max_count* is replaced by 1
    so the division is always defined. No clamping is applied, so inputs with
    ``co_count > max_count`` can exceed 1.0.
    """
    denominator = 1 if max_count <= 0 else max_count
    sector_match = 1.0
    return 0.3 * sector_match + 0.7 * (co_count / denominator)
def _run_inference_simulation(
    company_id: str,
    candidate_ids: list[str],
    co_mention_counts: dict[str, int],
) -> list[dict[str, Any]]:
    """Run a pure, in-memory model of the competitor auto-inference algorithm.

    Follows the steps of ``infer_competitors``:

    - candidates are assumed to be pre-filtered to the same sector/industry;
    - strengths are normalised by the largest co-mention count among the
      candidates (missing counts default to 0);
    - each candidate yields one active, bidirectional ``same_sector`` row with
      ``source='inferred'``, stored as the canonical ``(min, max)`` id pair;
    - rows are returned ordered by strength, highest first (stable for ties).
    """
    if not candidate_ids:
        return []

    counts = [co_mention_counts.get(cid, 0) for cid in candidate_ids]
    # ``counts`` is non-empty here. A zero maximum (no co-mentions at all)
    # is bumped to 1 so the normalisation never divides by zero.
    max_count = max(counts) or 1

    timestamp = datetime.now(tz=timezone.utc)
    relationships: list[dict[str, Any]] = [
        {
            "id": str(uuid.uuid4()),
            "company_a_id": min(company_id, cid),
            "company_b_id": max(company_id, cid),
            "relationship_type": "same_sector",
            "strength": _compute_inference_strength(count, max_count),
            "bidirectional": True,
            "source": "inferred",
            "active": True,
            "created_at": timestamp,
            "updated_at": timestamp,
        }
        for cid, count in zip(candidate_ids, counts)
    ]
    return sorted(relationships, key=lambda rel: rel["strength"], reverse=True)
# Strategies for auto-inference tests
def _sector_industry_strategy() -> st.SearchStrategy[str]:
    """Return a strategy that picks one of a fixed set of sector/industry labels."""
    labels = [
        "Technology",
        "Healthcare",
        "Finance",
        "Energy",
        "Consumer",
        "Industrial",
        "Materials",
        "Utilities",
    ]
    return st.sampled_from(labels)
def _co_mention_count_strategy() -> st.SearchStrategy[int]:
    """Return a strategy of co-mention counts between 0 and 500 inclusive."""
    lowest, highest = 0, 500
    return st.integers(min_value=lowest, max_value=highest)
# ---------------------------------------------------------------------------
# Property 4: Auto-inference produces valid candidates
# ---------------------------------------------------------------------------
class TestProperty4AutoInferenceProducesValidCandidates:
    """Feature: competitive-historical-patterns, Property 4: Auto-inference produces valid candidates

    For any company with a defined sector and industry, running
    auto-inference SHALL produce only candidate relationships where the
    candidate company shares the same sector and industry, and all
    produced relationships SHALL have source = 'inferred' with strength
    in [0, 1].

    **Validates: Requirements 2.1, 2.3**
    """

    @given(
        company_id=_company_id_strategy(),
        num_candidates=st.integers(min_value=1, max_value=20),
        co_counts=st.lists(
            _co_mention_count_strategy(), min_size=1, max_size=20,
        ),
    )
    @settings(max_examples=100)
    def test_all_inferred_relationships_have_valid_source_and_strength(
        self,
        company_id: str,
        num_candidates: int,
        co_counts: list[int],
    ):
        """**Validates: Requirements 2.1, 2.3**

        Every inferred relationship must have source='inferred' and a
        strength in [0.3, 1.0]. Since sector_match is always 1.0 for filtered
        candidates, the minimum is 0.3*1.0 + 0.7*0 = 0.3.
        """
        # Fresh candidate IDs, independent of company_id
        candidate_ids = [str(uuid.uuid4()) for _ in range(num_candidates)]

        # Repeat the drawn counts cyclically until there is one per candidate
        repeats = num_candidates // len(co_counts) + 1
        padded = (co_counts * repeats)[:num_candidates]

        results = _run_inference_simulation(
            company_id, candidate_ids, dict(zip(candidate_ids, padded))
        )
        assert len(results) == num_candidates

        for rel in results:
            source = rel["source"]
            strength = rel["strength"]
            assert source == "inferred", f"Expected source='inferred', got '{source}'"
            # General contract: strength lies in [0, 1]
            assert 0.0 <= strength <= 1.0, f"Strength {strength} out of [0, 1]"
            # Tighter bound: sector_match=1.0 puts the floor at 0.3
            assert strength >= 0.3 - 1e-9, (
                f"Strength {strength} below theoretical minimum 0.3"
            )
            assert rel["relationship_type"] == "same_sector"
            assert rel["bidirectional"] is True
            assert rel["active"] is True

    @given(
        company_id=_company_id_strategy(),
        co_count=_co_mention_count_strategy(),
        max_count=st.integers(min_value=1, max_value=1000),
    )
    @settings(max_examples=100)
    def test_strength_formula_always_in_valid_range(
        self,
        company_id: str,
        co_count: int,
        max_count: int,
    ):
        """**Validates: Requirements 2.1, 2.3**

        The strength formula 0.3 * 1.0 + 0.7 * (co_count / max_count) must
        always produce a value in [0.3, 1.0] when co_count <= max_count.
        """
        # Counts above the maximum are unrealistic; cap them at max_count
        clamped = co_count if co_count <= max_count else max_count
        strength = _compute_inference_strength(clamped, max_count)

        lower, upper = 0.3 - 1e-9, 1.0 + 1e-9
        assert lower <= strength <= upper, (
            f"Strength {strength} outside [0.3, 1.0] for "
            f"co_count={clamped}, max_count={max_count}"
        )

    @given(company_id=_company_id_strategy())
    @settings(max_examples=100)
    def test_empty_candidates_returns_empty(self, company_id: str):
        """**Validates: Requirements 2.1, 2.3**

        With no same-sector/industry candidates, inference yields an empty list.
        """
        assert _run_inference_simulation(company_id, [], {}) == []
# ---------------------------------------------------------------------------
# Property 5: Auto-inference ranks by co-mention frequency
# ---------------------------------------------------------------------------
class TestProperty5AutoInferenceRanksByCoMentionFrequency:
    """Feature: competitive-historical-patterns, Property 5: Auto-inference ranks by co-mention frequency

    For any set of candidate competitors with different co-mention counts
    in document_company_mentions, the auto-inferred relationships SHALL
    have strength scores that are monotonically non-decreasing with
    co-mention frequency — candidates with more co-mentions receive
    higher or equal strength scores.

    **Validates: Requirements 2.2**
    """

    @given(
        company_id=_company_id_strategy(),
        co_counts=st.lists(
            _co_mention_count_strategy(), min_size=2, max_size=20,
        ),
    )
    @settings(max_examples=100)
    def test_higher_co_mentions_yield_higher_or_equal_strength(
        self,
        company_id: str,
        co_counts: list[int],
    ):
        """**Validates: Requirements 2.2**

        Two checks on the inference simulation's output:

        1. Ordered by each candidate's co-mention count ascending, the
           strengths must be non-decreasing.
        2. The simulation's own output must be ranked by strength descending.
        """
        candidate_ids = [str(uuid.uuid4()) for _ in range(len(co_counts))]
        co_mention_map = dict(zip(candidate_ids, co_counts))

        # Exercise the actual inference path rather than re-deriving strengths
        results = _run_inference_simulation(company_id, candidate_ids, co_mention_map)
        assert len(results) == len(candidate_ids)

        # The simulation ranks relationships by strength, highest first
        strengths = [rel["strength"] for rel in results]
        assert strengths == sorted(strengths, reverse=True), (
            "Inferred relationships are not ranked by strength descending"
        )

        # Rows store the canonical (min, max) pair. The candidate is whichever
        # side is not company_id.
        pairs = []
        for rel in results:
            a_id, b_id = rel["company_a_id"], rel["company_b_id"]
            candidate = b_id if a_id == company_id else a_id
            pairs.append((co_mention_map[candidate], rel["strength"]))

        # Sort by co-mention count ascending
        pairs.sort(key=lambda p: p[0])

        # Strengths must be monotonically non-decreasing
        for (prev_count, prev_strength), (count, strength) in zip(pairs, pairs[1:]):
            assert strength >= prev_strength - 1e-9, (
                f"Monotonicity violated: co_count {count} has strength "
                f"{strength} < co_count {prev_count} strength {prev_strength}"
            )

    @given(
        company_id=_company_id_strategy(),
        low_count=st.integers(min_value=0, max_value=100),
        high_count=st.integers(min_value=101, max_value=500),
    )
    @settings(max_examples=100)
    def test_strictly_more_co_mentions_never_lower_strength(
        self,
        company_id: str,
        low_count: int,
        high_count: int,
    ):
        """**Validates: Requirements 2.2**

        Given two candidates where one has strictly more co-mentions,
        the one with more co-mentions must have >= strength.
        """
        # high_count is the largest count in play, so it is the normaliser
        max_count = high_count
        low_strength = _compute_inference_strength(low_count, max_count)
        high_strength = _compute_inference_strength(high_count, max_count)

        assert high_strength >= low_strength - 1e-9, (
            f"Candidate with {high_count} co-mentions has strength "
            f"{high_strength} < candidate with {low_count} co-mentions "
            f"strength {low_strength}"
        )
# ---------------------------------------------------------------------------
# Property 6: Auto-inference idempotence
# ---------------------------------------------------------------------------
class TestProperty6AutoInferenceIdempotence:
    """Feature: competitive-historical-patterns, Property 6: Auto-inference idempotence

    For any company, running auto-inference twice in succession SHALL
    produce the same set of relationships (no duplicates created), with
    strength scores updated to reflect the latest co-mention data.

    **Validates: Requirements 2.4**
    """

    @given(
        company_id=_company_id_strategy(),
        co_counts=st.lists(
            _co_mention_count_strategy(), min_size=1, max_size=15,
        ),
    )
    @settings(max_examples=100)
    def test_two_runs_produce_identical_results(
        self,
        company_id: str,
        co_counts: list[int],
    ):
        """**Validates: Requirements 2.4**

        Running inference twice with the same co-mention data must
        produce the exact same set of relationships with the same
        strengths — no duplicates, no missing entries.
        """
        candidate_ids = [str(uuid.uuid4()) for _ in range(len(co_counts))]
        co_mention_map = dict(zip(candidate_ids, co_counts))

        run1 = _run_inference_simulation(company_id, candidate_ids, co_mention_map)
        run2 = _run_inference_simulation(company_id, candidate_ids, co_mention_map)

        # Same number of relationships
        assert len(run1) == len(run2), (
            f"Run 1 produced {len(run1)} relationships, run 2 produced {len(run2)}"
        )

        # Same company pairs (by sorted (a, b) tuples)
        pairs1 = sorted((r["company_a_id"], r["company_b_id"]) for r in run1)
        pairs2 = sorted((r["company_a_id"], r["company_b_id"]) for r in run2)
        assert pairs1 == pairs2, "Company pairs differ between runs"

        # Same strengths for each pair (key sets are equal per the check above)
        strength_map1 = {
            (r["company_a_id"], r["company_b_id"]): r["strength"] for r in run1
        }
        strength_map2 = {
            (r["company_a_id"], r["company_b_id"]): r["strength"] for r in run2
        }
        for pair, strength in strength_map1.items():
            assert abs(strength - strength_map2[pair]) < 1e-9, (
                f"Strength mismatch for pair {pair}: "
                f"{strength} vs {strength_map2[pair]}"
            )

    @given(
        company_id=_company_id_strategy(),
        co_counts=st.lists(
            _co_mention_count_strategy(), min_size=1, max_size=15,
        ),
    )
    @settings(max_examples=100)
    def test_no_duplicate_pairs_in_single_run(
        self,
        company_id: str,
        co_counts: list[int],
    ):
        """**Validates: Requirements 2.4**

        A single inference run must never produce duplicate company pairs.
        The upsert logic ensures at most one active relationship per
        (company_a, company_b) pair.
        """
        candidate_ids = [str(uuid.uuid4()) for _ in range(len(co_counts))]
        co_mention_map = dict(zip(candidate_ids, co_counts))

        results = _run_inference_simulation(company_id, candidate_ids, co_mention_map)

        pairs = [(r["company_a_id"], r["company_b_id"]) for r in results]
        assert len(pairs) == len(set(pairs)), (
            f"Duplicate pairs found: {len(pairs)} total, {len(set(pairs))} unique"
        )

    @given(
        company_id=_company_id_strategy(),
        initial_counts=st.lists(
            _co_mention_count_strategy(), min_size=2, max_size=10,
        ),
        updated_counts=st.lists(
            _co_mention_count_strategy(), min_size=2, max_size=10,
        ),
    )
    @settings(max_examples=100)
    def test_re_inference_updates_strengths_to_latest_data(
        self,
        company_id: str,
        initial_counts: list[int],
        updated_counts: list[int],
    ):
        """**Validates: Requirements 2.4**

        When co-mention data changes between inference runs, the second
        run must produce strengths reflecting the updated data, not the
        original data.
        """
        # Use the shorter list length to keep candidates consistent
        n = min(len(initial_counts), len(updated_counts))
        candidate_ids = [str(uuid.uuid4()) for _ in range(n)]
        initial_map = dict(zip(candidate_ids, initial_counts[:n]))
        updated_map = dict(zip(candidate_ids, updated_counts[:n]))

        run1 = _run_inference_simulation(company_id, candidate_ids, initial_map)
        run2 = _run_inference_simulation(company_id, candidate_ids, updated_map)

        # Same set of company pairs
        pairs1 = sorted((r["company_a_id"], r["company_b_id"]) for r in run1)
        pairs2 = sorted((r["company_a_id"], r["company_b_id"]) for r in run2)
        assert pairs1 == pairs2, "Company pairs should be identical across re-inference"

        # Normaliser for the updated data. n >= 2 (both lists have
        # min_size=2), so the slice is never empty. A zero maximum becomes 1.
        max_updated = max(updated_counts[:n]) or 1

        # Map each canonical (min, max) pair back to its candidate. Every
        # relationship in run2 is then checked against the updated count; an
        # unknown pair fails loudly instead of being skipped.
        candidate_by_pair = {
            (min(company_id, cid), max(company_id, cid)): cid for cid in candidate_ids
        }
        for rel in run2:
            pair = (rel["company_a_id"], rel["company_b_id"])
            assert pair in candidate_by_pair, (
                f"Unexpected pair {pair} in re-inference output"
            )
            cid = candidate_by_pair[pair]
            expected = _compute_inference_strength(updated_map[cid], max_updated)
            assert abs(rel["strength"] - expected) < 1e-9, (
                f"Strength {rel['strength']} != expected {expected} "
                f"for updated co_count={updated_map[cid]}"
            )