242 lines
8.3 KiB
Python
242 lines
8.3 KiB
Python
"""Resilient adapter wrapper with rate-limit coordination, retries, and backoff.
|
|
|
|
Wraps any BaseAdapter with:
|
|
- Per-source-type rate limiting via Redis (distributed across workers)
|
|
- Exponential backoff with jitter on retryable failures
|
|
- Configurable retry counts and retryable HTTP status codes
|
|
- Graceful degradation when Redis is unavailable
|
|
|
|
Requirements: 2.5, 3.4
|
|
"""
|
|
import asyncio
|
|
import logging
|
|
import random
|
|
import time
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
import redis.asyncio as aioredis
|
|
|
|
from services.shared.redis_keys import rate_limit_key
|
|
|
|
from .base import AdapterResult, BaseAdapter
|
|
|
|
logger = logging.getLogger("resilient_adapter")


# HTTP status codes that are safe to retry: rate limiting (429) and transient
# server-side failures (500, 502, 503, 504).
RETRYABLE_STATUS_CODES: frozenset[int] = frozenset({429, 500, 502, 503, 504})


@dataclass
class RetryConfig:
    """Configuration for retry and rate-limit behavior.

    Values are validated on construction so that a misconfiguration fails
    fast with ``ValueError`` instead of surfacing later as a failed assertion
    or a division by zero while fetching.
    """

    # Retries after the first attempt (total attempts = max_retries + 1).
    max_retries: int = 3
    # Seconds; the backoff before retry n is base_delay * 2**n (before jitter).
    base_delay: float = 1.0
    # Seconds; cap on the exponential backoff term and on Retry-After hints.
    max_delay: float = 60.0
    # Random extra delay as a fraction of the capped backoff (0 disables jitter).
    jitter_factor: float = 0.5
    # HTTP statuses treated as retryable. Any iterable of ints is accepted and
    # normalized to a frozenset in __post_init__.
    retryable_status_codes: frozenset[int] = RETRYABLE_STATUS_CODES
    # Rate limit: max requests per window per source type
    rate_limit_max: int = 30
    rate_limit_window_seconds: int = 60

    def __post_init__(self) -> None:
        """Normalize ``retryable_status_codes`` and validate numeric fields.

        Raises:
            ValueError: If ``max_retries``, ``base_delay``, ``max_delay`` or
                ``jitter_factor`` is negative, if ``rate_limit_max`` is below
                1, or if ``rate_limit_window_seconds`` is not positive.
        """
        self.retryable_status_codes = frozenset(self.retryable_status_codes)
        for name in ("max_retries", "base_delay", "max_delay", "jitter_factor"):
            value = getattr(self, name)
            if value < 0:
                raise ValueError(f"{name} must be >= 0, got {value!r}")
        if self.rate_limit_max < 1:
            raise ValueError(
                f"rate_limit_max must be >= 1, got {self.rate_limit_max!r}"
            )
        # The window length is used as a divisor when bucketing time.
        if self.rate_limit_window_seconds <= 0:
            raise ValueError(
                "rate_limit_window_seconds must be > 0, "
                f"got {self.rate_limit_window_seconds!r}"
            )


# Sensible defaults per source type (keyed by BaseAdapter.source_type()).
DEFAULT_RETRY_CONFIGS: dict[str, RetryConfig] = {
    "market_api": RetryConfig(max_retries=3, rate_limit_max=30),
    "news_api": RetryConfig(max_retries=3, rate_limit_max=20),
    "filings_api": RetryConfig(max_retries=2, rate_limit_max=10, base_delay=2.0),
    "web_scrape": RetryConfig(max_retries=2, rate_limit_max=10, base_delay=2.0),
    "broker": RetryConfig(max_retries=2, rate_limit_max=60, base_delay=0.5),
}
|
|
|
|
|
|
def compute_delay(attempt: int, config: RetryConfig) -> float:
    """Return the backoff delay in seconds before the retry after *attempt*.

    The exponential term ``base_delay * 2**attempt`` is capped, then a random
    jitter of up to ``jitter_factor`` times the capped value is added. The cap
    leaves headroom for that jitter, so the result never exceeds
    ``config.max_delay`` (previously jitter was added on top of the cap).

    Args:
        attempt: Zero-based index of the attempt that just failed.
        config: Supplies ``base_delay``, ``max_delay`` and ``jitter_factor``.

    Returns:
        The delay in seconds, at most ``config.max_delay``.
    """
    # Clamp the exponent: 2**63 is already far beyond any sensible cap, and
    # for attempt >= 1024 the int -> float conversion would overflow.
    exp_delay = config.base_delay * (2 ** min(attempt, 63))
    # Reserve room for the maximum jitter below max_delay.
    ceiling = config.max_delay / (1.0 + max(config.jitter_factor, 0.0))
    capped = min(exp_delay, ceiling)
    jitter = capped * config.jitter_factor * random.random()
    # The final clamp only absorbs floating-point rounding at the ceiling.
    return min(capped + jitter, config.max_delay)
|
|
|
|
|
|
|
|
@dataclass
class RetryStats:
    """Tracks retry statistics for observability.

    ``ResilientAdapter.fetch`` creates a fresh instance per call, updates it
    as attempts are made, and copies selected fields into the returned
    result's ``metadata["retry_stats"]``.
    """

    # Number of adapter fetch attempts made so far (1-based).
    attempts: int = 0
    # Total seconds slept: rate-limit waits plus retry backoff delays.
    total_delay: float = 0.0
    # How many rate-limit waits were taken before attempts.
    rate_limited_waits: int = 0
    # Error message from the most recent failed attempt, if any.
    last_error: str | None = None
    # Whether the most recent failure was classified as retryable.
    retryable: bool = False
|
|
|
|
|
|
class ResilientAdapter:
    """Wraps a BaseAdapter with rate-limit coordination, retries, and backoff.

    Usage:
        adapter = PolygonMarketAdapter(api_key="...")
        resilient = ResilientAdapter(adapter, redis=rds)
        result = await resilient.fetch(ticker, config)

    If redis is None, rate limiting is skipped (local dev / testing).
    """

    # Lower-cased substrings of ``AdapterResult.error`` that mark a transient
    # failure worth retrying. "connect" also matches "connection"; "timed out"
    # covers messages (e.g. socket timeouts) that never contain "timeout".
    _TRANSIENT_ERROR_MARKERS: tuple[str, ...] = (
        "timeout",
        "timed out",
        "connect",
        "reset",
        "refused",
    )

    # Statuses whose Retry-After hint replaces the computed backoff:
    # RFC 6585 defines it for 429, RFC 9110 for 503.
    _RETRY_AFTER_STATUS_CODES: frozenset[int] = frozenset({429, 503})

    # Safety valve: after this many consecutive rate-limit waits an attempt
    # proceeds anyway, so a saturated limiter cannot stall fetch() forever.
    _MAX_RATE_LIMIT_WAITS: int = 5

    def __init__(
        self,
        adapter: BaseAdapter,
        redis: aioredis.Redis | None = None,
        retry_config: RetryConfig | None = None,
    ) -> None:
        """Create the wrapper.

        Args:
            adapter: The adapter whose ``fetch`` is wrapped.
            redis: Client for the distributed rate limit; ``None`` disables
                rate limiting.
            retry_config: Explicit retry / rate-limit settings. When omitted,
                a copy of the per-source entry in ``DEFAULT_RETRY_CONFIGS`` is
                used, or ``RetryConfig()`` if the source type has none.
        """
        self._adapter = adapter
        self._redis = redis
        if retry_config is None:
            # Copy the shared module-level default so that mutating
            # ``self.config`` cannot leak into other adapters of this type.
            from dataclasses import replace

            default = DEFAULT_RETRY_CONFIGS.get(adapter.source_type())
            retry_config = replace(default) if default is not None else RetryConfig()
        self._config = retry_config

    @property
    def adapter(self) -> BaseAdapter:
        """Access the underlying adapter."""
        return self._adapter

    @property
    def config(self) -> RetryConfig:
        """The retry / rate-limit settings used by this wrapper."""
        return self._config

    def source_type(self) -> str:
        """Source type reported by the wrapped adapter."""
        return self._adapter.source_type()

    async def _check_rate_limit(self) -> float:
        """Record one request against the distributed fixed-window limit.

        Increments the Redis counter for the current time bucket of this
        source type (the key expires after two windows).

        Returns:
            0.0 if the request is allowed, Redis is not configured, or the
            Redis call fails; otherwise the seconds until the current window
            rolls over (at least 0.5).
        """
        if self._redis is None:
            return 0.0

        source_type = self._adapter.source_type()
        window_sec = self._config.rate_limit_window_seconds
        # Read the clock once so the bucket and the remaining-time computation
        # refer to the same instant; two reads could straddle a window boundary
        # and produce a needless near-full-window wait.
        now = time.time()
        # Use a time-bucketed key so counters auto-expire
        bucket = int(now) // window_sec
        key = rate_limit_key(source_type, str(bucket))

        try:
            count = await self._redis.incr(key)
            if count == 1:
                await self._redis.expire(key, window_sec * 2)
            if count > self._config.rate_limit_max:
                # Over limit — wait until the window rolls over
                return max(window_sec - now % window_sec, 0.5)
        except Exception:
            # Redis unavailable — degrade gracefully, allow the request
            logger.warning("Redis rate-limit check failed, allowing request")
        return 0.0

    async def _wait_for_rate_limit(self, ticker: str, stats: RetryStats) -> None:
        """Sleep until the rate limiter admits the next attempt.

        The limit is re-checked after every wait so the attempt is counted
        against the window it is actually sent in; otherwise requests released
        at a window rollover would bypass that window's budget. After
        ``_MAX_RATE_LIMIT_WAITS`` waits the attempt proceeds anyway. Waits are
        recorded in *stats*.
        """
        waits = 0
        while True:
            wait = await self._check_rate_limit()
            if wait <= 0:
                return
            if waits >= self._MAX_RATE_LIMIT_WAITS:
                logger.warning(
                    "Rate limit for %s/%s still exceeded after %d waits, proceeding",
                    self.source_type(), ticker, waits,
                )
                return
            waits += 1
            stats.rate_limited_waits += 1
            logger.info(
                "Rate limited for %s/%s, waiting %.1fs",
                self.source_type(), ticker, wait,
            )
            stats.total_delay += wait
            await asyncio.sleep(wait)

    def _is_retryable(self, result: AdapterResult) -> bool:
        """Determine if a failed result is worth retrying.

        True when the HTTP status is in the configured retryable set, or when
        the error text mentions a timeout or a connection problem.
        """
        if result.ok:
            return False
        if result.http_status and result.http_status in self._config.retryable_status_codes:
            return True
        error_text = (result.error or "").lower()
        return any(marker in error_text for marker in self._TRANSIENT_ERROR_MARKERS)

    def _extract_retry_after(self, result: AdapterResult) -> float | None:
        """Return the Retry-After hint (seconds) from ``result.metadata``.

        Returns None when the hint is absent, not convertible to float, or is
        negative, NaN or infinite.
        """
        retry_after = result.metadata.get("retry_after")
        if retry_after is None:
            return None
        try:
            seconds = float(retry_after)
        except (ValueError, TypeError, OverflowError):
            return None
        # NaN fails both comparisons; negative or infinite values would make
        # the sleep and the total_delay accounting meaningless.
        if not 0.0 <= seconds < float("inf"):
            return None
        return seconds

    def _backoff_delay(self, result: AdapterResult, attempt: int) -> float:
        """Return how long to sleep before retrying after *result*.

        A valid Retry-After hint on a 429/503 response wins (capped at
        ``max_delay``); otherwise exponential backoff with jitter is used.
        """
        retry_after = self._extract_retry_after(result)
        if retry_after is not None and result.http_status in self._RETRY_AFTER_STATUS_CODES:
            return min(retry_after, self._config.max_delay)
        return compute_delay(attempt, self._config)

    @staticmethod
    def _stats_metadata(stats: RetryStats) -> dict[str, Any]:
        """Build the ``retry_stats`` fields shared by success and failure."""
        return {
            "attempts": stats.attempts,
            "total_delay": round(stats.total_delay, 2),
            "rate_limited_waits": stats.rate_limited_waits,
        }

    async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult:
        """Fetch with rate-limit coordination, retries, and exponential backoff.

        Each attempt first waits for the distributed rate limiter, then calls
        the wrapped adapter. Retryable failures are retried up to
        ``max_retries`` times (a negative value is treated as 0), sleeping
        between attempts for a Retry-After hint (429/503) or an exponential
        backoff with jitter.

        The returned result's metadata includes retry stats under the
        ``"retry_stats"`` key. On failure it also carries ``"exhausted": True``,
        ``"last_error"`` and ``"retryable"`` (False when the loop stopped early
        on a non-retryable error).

        Args:
            ticker: Passed through unchanged to the adapter's ``fetch``.
            config: Passed through unchanged to the adapter's ``fetch``.

        Returns:
            The AdapterResult of the last attempt made.
        """
        stats = RetryStats()
        max_retries = max(self._config.max_retries, 0)
        result: AdapterResult | None = None

        for attempt in range(max_retries + 1):
            stats.attempts = attempt + 1
            await self._wait_for_rate_limit(ticker, stats)

            result = await self._adapter.fetch(ticker, config)
            if result.ok:
                result.metadata["retry_stats"] = self._stats_metadata(stats)
                return result

            stats.last_error = result.error
            stats.retryable = self._is_retryable(result)
            if not stats.retryable:
                break

            # Don't sleep after the last attempt
            if attempt < max_retries:
                delay = self._backoff_delay(result, attempt)
                logger.info(
                    "Retrying %s/%s (attempt %d/%d) after %.1fs: %s",
                    self.source_type(), ticker, attempt + 1,
                    max_retries + 1, delay, result.error,
                )
                stats.total_delay += delay
                await asyncio.sleep(delay)

        # No success: retries ran out, or a non-retryable error stopped the loop.
        assert result is not None  # the loop always runs at least once
        result.metadata["retry_stats"] = {
            **self._stats_metadata(stats),
            # Kept True for backward compatibility ("gave up without success");
            # "retryable" tells an exhausted retry budget from an early stop.
            "exhausted": True,
            "last_error": stats.last_error,
            "retryable": stats.retryable,
        }
        return result
|