feat: sell execution, correlation matrix from market data, US market holiday awareness
- Sell path: looks up existing position, sells full quantity, returns proceeds to pool - Correlation matrix: computed from 30-day market_snapshots on startup + every 5min - Holidays: 10 major US market holidays for 2026 checked in trading window functions
This commit is contained in:
@@ -22,6 +22,13 @@ from datetime import datetime, timedelta, timezone
|
|||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
try:
|
||||||
|
import numpy as np # noqa: F401
|
||||||
|
|
||||||
|
_HAS_NUMPY = True
|
||||||
|
except ImportError:
|
||||||
|
_HAS_NUMPY = False
|
||||||
|
|
||||||
from services.shared.config import TradingConfig
|
from services.shared.config import TradingConfig
|
||||||
from services.shared.redis_keys import (
|
from services.shared.redis_keys import (
|
||||||
QUEUE_BROKER,
|
QUEUE_BROKER,
|
||||||
@@ -588,6 +595,9 @@ class TradingEngine:
|
|||||||
initial_capital, invested, available, reserve_balance, open_count,
|
initial_capital, invested, available, reserve_balance, open_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Compute initial correlation matrix from market data
|
||||||
|
await self._compute_correlation_matrix()
|
||||||
|
|
||||||
async def _decision_loop(self) -> None:
|
async def _decision_loop(self) -> None:
|
||||||
"""Poll recommendations and evaluate them in a continuous loop.
|
"""Poll recommendations and evaluate them in a continuous loop.
|
||||||
|
|
||||||
@@ -679,6 +689,57 @@ class TradingEngine:
|
|||||||
if self.portfolio_state is None:
|
if self.portfolio_state is None:
|
||||||
self.portfolio_state = PortfolioState()
|
self.portfolio_state = PortfolioState()
|
||||||
|
|
||||||
|
action = rec.get("action", "buy")
|
||||||
|
|
||||||
|
# --- Sell path: skip position sizing, look up existing position ---
|
||||||
|
if action == "sell":
|
||||||
|
pos_row = None
|
||||||
|
try:
|
||||||
|
pos_row = await self.pool.fetchrow(
|
||||||
|
"SELECT quantity, avg_entry_price, current_price "
|
||||||
|
"FROM positions WHERE ticker = $1 AND quantity > 0",
|
||||||
|
ticker,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Could not look up position for sell: %s", ticker)
|
||||||
|
|
||||||
|
if pos_row is None:
|
||||||
|
logger.info("Sell recommendation for %s but no open position — skipping", ticker)
|
||||||
|
continue
|
||||||
|
|
||||||
|
sell_qty = int(pos_row["quantity"])
|
||||||
|
sell_price = rec.get("current_price", 0.0)
|
||||||
|
estimated_proceeds = sell_qty * sell_price
|
||||||
|
|
||||||
|
order_job = {
|
||||||
|
"trading_decision_id": str(uuid.uuid4()),
|
||||||
|
"ticker": ticker,
|
||||||
|
"action": "sell",
|
||||||
|
"quantity": sell_qty,
|
||||||
|
"order_type": "market",
|
||||||
|
"source": "trading_engine",
|
||||||
|
}
|
||||||
|
if self.redis is not None:
|
||||||
|
broker_queue = queue_key(QUEUE_BROKER)
|
||||||
|
await self.redis.rpush(broker_queue, json.dumps(order_job))
|
||||||
|
logger.info(
|
||||||
|
"Pushed sell order for %s (%d shares, ~$%.2f) to broker queue",
|
||||||
|
ticker, sell_qty, estimated_proceeds,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update portfolio state
|
||||||
|
if self.portfolio_state:
|
||||||
|
self.portfolio_state.open_position_count = max(
|
||||||
|
0, self.portfolio_state.open_position_count - 1
|
||||||
|
)
|
||||||
|
self.portfolio_state.active_pool += estimated_proceeds
|
||||||
|
|
||||||
|
# Mark as processed
|
||||||
|
if rec_id:
|
||||||
|
self.processed_recommendation_ids.add(rec_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# --- Buy path: evaluate recommendation through position sizer ---
|
||||||
# Evaluate recommendation
|
# Evaluate recommendation
|
||||||
decision = self.evaluate_recommendation(
|
decision = self.evaluate_recommendation(
|
||||||
rec=rec,
|
rec=rec,
|
||||||
@@ -888,6 +949,12 @@ class TradingEngine:
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.debug("Could not compute performance metrics")
|
logger.debug("Could not compute performance metrics")
|
||||||
|
|
||||||
|
# Refresh correlation matrix every 5 minutes
|
||||||
|
try:
|
||||||
|
await self._compute_correlation_matrix()
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Could not refresh correlation matrix")
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
break
|
break
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -1164,6 +1231,135 @@ class TradingEngine:
|
|||||||
if self.running:
|
if self.running:
|
||||||
await asyncio.sleep(60)
|
await asyncio.sleep(60)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Correlation matrix computation
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _compute_correlation_matrix(self) -> None:
|
||||||
|
"""Compute pairwise price correlations from market_snapshots and load into self.correlation_matrix.
|
||||||
|
|
||||||
|
Queries the last 30 days of daily close prices, computes daily returns,
|
||||||
|
then calculates Pearson correlation coefficients between each ticker pair.
|
||||||
|
Uses numpy when available, otherwise falls back to a manual computation.
|
||||||
|
"""
|
||||||
|
if self.pool is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
rows = await self.pool.fetch(
|
||||||
|
"SELECT ticker, captured_at::date AS dt, (data->>'c')::float AS close "
|
||||||
|
"FROM market_snapshots "
|
||||||
|
"WHERE snapshot_type = 'bar' AND captured_at > NOW() - INTERVAL '30 days' "
|
||||||
|
"ORDER BY ticker, captured_at"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Could not query market_snapshots for correlation matrix")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Group close prices by ticker, keyed by date
|
||||||
|
ticker_prices: dict[str, dict] = {}
|
||||||
|
for row in rows:
|
||||||
|
ticker = row["ticker"]
|
||||||
|
dt = row["dt"]
|
||||||
|
close = row["close"]
|
||||||
|
if close is None:
|
||||||
|
continue
|
||||||
|
if ticker not in ticker_prices:
|
||||||
|
ticker_prices[ticker] = {}
|
||||||
|
ticker_prices[ticker][dt] = close
|
||||||
|
|
||||||
|
# Compute daily returns for each ticker
|
||||||
|
ticker_returns: dict[str, list[float]] = {}
|
||||||
|
all_dates: set = set()
|
||||||
|
for ticker, prices_by_date in ticker_prices.items():
|
||||||
|
sorted_dates = sorted(prices_by_date.keys())
|
||||||
|
all_dates.update(sorted_dates)
|
||||||
|
returns = []
|
||||||
|
for i in range(1, len(sorted_dates)):
|
||||||
|
prev = prices_by_date[sorted_dates[i - 1]]
|
||||||
|
curr = prices_by_date[sorted_dates[i]]
|
||||||
|
if prev > 0:
|
||||||
|
returns.append((curr - prev) / prev)
|
||||||
|
if returns:
|
||||||
|
ticker_returns[ticker] = returns
|
||||||
|
|
||||||
|
tickers = list(ticker_returns.keys())
|
||||||
|
if len(tickers) < 2:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Align returns to common dates for proper pairwise comparison
|
||||||
|
sorted_all_dates = sorted(all_dates)
|
||||||
|
aligned_returns: dict[str, list[float]] = {}
|
||||||
|
for ticker in tickers:
|
||||||
|
prices_by_date = ticker_prices[ticker]
|
||||||
|
aligned = []
|
||||||
|
for i in range(1, len(sorted_all_dates)):
|
||||||
|
prev_dt = sorted_all_dates[i - 1]
|
||||||
|
curr_dt = sorted_all_dates[i]
|
||||||
|
if prev_dt in prices_by_date and curr_dt in prices_by_date:
|
||||||
|
prev_p = prices_by_date[prev_dt]
|
||||||
|
curr_p = prices_by_date[curr_dt]
|
||||||
|
if prev_p > 0:
|
||||||
|
aligned.append((curr_p - prev_p) / prev_p)
|
||||||
|
else:
|
||||||
|
aligned.append(0.0)
|
||||||
|
else:
|
||||||
|
aligned.append(None) # type: ignore[arg-type]
|
||||||
|
aligned_returns[ticker] = aligned
|
||||||
|
|
||||||
|
corr_data: dict[tuple[str, str], float] = {}
|
||||||
|
|
||||||
|
if _HAS_NUMPY:
|
||||||
|
import numpy as _np
|
||||||
|
|
||||||
|
for i in range(len(tickers)):
|
||||||
|
for j in range(i + 1, len(tickers)):
|
||||||
|
a_raw = aligned_returns[tickers[i]]
|
||||||
|
b_raw = aligned_returns[tickers[j]]
|
||||||
|
# Use only indices where both have valid returns
|
||||||
|
pairs = [
|
||||||
|
(a_raw[k], b_raw[k])
|
||||||
|
for k in range(len(a_raw))
|
||||||
|
if a_raw[k] is not None and b_raw[k] is not None
|
||||||
|
]
|
||||||
|
if len(pairs) < 5:
|
||||||
|
continue
|
||||||
|
a_arr = _np.array([p[0] for p in pairs])
|
||||||
|
b_arr = _np.array([p[1] for p in pairs])
|
||||||
|
corr_matrix = _np.corrcoef(a_arr, b_arr)
|
||||||
|
corr_val = float(corr_matrix[0, 1])
|
||||||
|
if not _np.isnan(corr_val):
|
||||||
|
corr_data[(tickers[i], tickers[j])] = corr_val
|
||||||
|
else:
|
||||||
|
# Manual Pearson correlation fallback
|
||||||
|
for i in range(len(tickers)):
|
||||||
|
for j in range(i + 1, len(tickers)):
|
||||||
|
a_raw = aligned_returns[tickers[i]]
|
||||||
|
b_raw = aligned_returns[tickers[j]]
|
||||||
|
pairs = [
|
||||||
|
(a_raw[k], b_raw[k])
|
||||||
|
for k in range(len(a_raw))
|
||||||
|
if a_raw[k] is not None and b_raw[k] is not None
|
||||||
|
]
|
||||||
|
if len(pairs) < 5:
|
||||||
|
continue
|
||||||
|
a_vals = [p[0] for p in pairs]
|
||||||
|
b_vals = [p[1] for p in pairs]
|
||||||
|
n = len(a_vals)
|
||||||
|
mean_a = sum(a_vals) / n
|
||||||
|
mean_b = sum(b_vals) / n
|
||||||
|
cov = sum((a_vals[k] - mean_a) * (b_vals[k] - mean_b) for k in range(n))
|
||||||
|
std_a = sum((v - mean_a) ** 2 for v in a_vals) ** 0.5
|
||||||
|
std_b = sum((v - mean_b) ** 2 for v in b_vals) ** 0.5
|
||||||
|
if std_a > 0 and std_b > 0:
|
||||||
|
corr_data[(tickers[i], tickers[j])] = cov / (std_a * std_b)
|
||||||
|
|
||||||
|
self.correlation_matrix.load(corr_data)
|
||||||
|
logger.info("Correlation matrix loaded: %d pairs", len(corr_data))
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Async helpers
|
# Async helpers
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|||||||
@@ -5,12 +5,12 @@ within the allowed trading window (9:45 AM – 3:45 PM ET on weekdays),
|
|||||||
whether the US market is open, and when the next trading window opens.
|
whether the US market is open, and when the next trading window opens.
|
||||||
|
|
||||||
Uses ``zoneinfo.ZoneInfo("America/New_York")`` for Eastern Time handling.
|
Uses ``zoneinfo.ZoneInfo("America/New_York")`` for Eastern Time handling.
|
||||||
Does not check market holidays (simplified).
|
Checks major US market holidays for 2026.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime, time, timedelta
|
from datetime import date, datetime, time, timedelta
|
||||||
from zoneinfo import ZoneInfo
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
# US Eastern timezone
|
# US Eastern timezone
|
||||||
@@ -28,16 +28,49 @@ MARKET_CLOSE = time(16, 0)
|
|||||||
_WEEKDAYS = range(0, 5)
|
_WEEKDAYS = range(0, 5)
|
||||||
|
|
||||||
|
|
||||||
|
def _us_market_holidays_2026() -> set[date]:
|
||||||
|
"""Return a set of US market holiday dates for 2026.
|
||||||
|
|
||||||
|
Major holidays observed by NYSE/NASDAQ:
|
||||||
|
- New Year's Day (Jan 1)
|
||||||
|
- MLK Day (3rd Monday of January)
|
||||||
|
- Presidents' Day (3rd Monday of February)
|
||||||
|
- Good Friday (April 3)
|
||||||
|
- Memorial Day (last Monday of May)
|
||||||
|
- Juneteenth (June 19)
|
||||||
|
- Independence Day (July 3 observed — July 4 is Saturday)
|
||||||
|
- Labor Day (1st Monday of September)
|
||||||
|
- Thanksgiving (4th Thursday of November)
|
||||||
|
- Christmas (Dec 25)
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
date(2026, 1, 1), # New Year's Day
|
||||||
|
date(2026, 1, 19), # MLK Day (3rd Monday)
|
||||||
|
date(2026, 2, 16), # Presidents' Day (3rd Monday)
|
||||||
|
date(2026, 4, 3), # Good Friday
|
||||||
|
date(2026, 5, 25), # Memorial Day (last Monday)
|
||||||
|
date(2026, 6, 19), # Juneteenth
|
||||||
|
date(2026, 7, 3), # Independence Day (observed)
|
||||||
|
date(2026, 9, 7), # Labor Day (1st Monday)
|
||||||
|
date(2026, 11, 26), # Thanksgiving (4th Thursday)
|
||||||
|
date(2026, 12, 25), # Christmas
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
_HOLIDAYS_2026 = _us_market_holidays_2026()
|
||||||
|
|
||||||
|
|
||||||
def is_within_trading_window(dt: datetime) -> bool:
|
def is_within_trading_window(dt: datetime) -> bool:
|
||||||
"""Return True if *dt* is between 9:45 AM ET and 3:45 PM ET on a weekday.
|
"""Return True if *dt* is between 9:45 AM ET and 3:45 PM ET on a weekday.
|
||||||
|
|
||||||
The timestamp is first converted to US/Eastern time. Weekends are
|
The timestamp is first converted to US/Eastern time. Weekends and
|
||||||
always outside the window. Market holidays are **not** checked
|
US market holidays (2026) are always outside the window.
|
||||||
(simplified implementation).
|
|
||||||
"""
|
"""
|
||||||
et_dt = dt.astimezone(ET)
|
et_dt = dt.astimezone(ET)
|
||||||
if et_dt.weekday() not in _WEEKDAYS:
|
if et_dt.weekday() not in _WEEKDAYS:
|
||||||
return False
|
return False
|
||||||
|
if et_dt.date() in _HOLIDAYS_2026:
|
||||||
|
return False
|
||||||
t = et_dt.time()
|
t = et_dt.time()
|
||||||
return WINDOW_OPEN <= t < WINDOW_CLOSE
|
return WINDOW_OPEN <= t < WINDOW_CLOSE
|
||||||
|
|
||||||
@@ -74,9 +107,14 @@ def next_window_open(dt: datetime) -> datetime:
|
|||||||
|
|
||||||
|
|
||||||
def is_market_open(dt: datetime) -> bool:
|
def is_market_open(dt: datetime) -> bool:
|
||||||
"""Return True if *dt* is during US market hours (9:30 AM – 4:00 PM ET) on a weekday."""
|
"""Return True if *dt* is during US market hours (9:30 AM – 4:00 PM ET) on a weekday.
|
||||||
|
|
||||||
|
Returns False on weekends and US market holidays (2026).
|
||||||
|
"""
|
||||||
et_dt = dt.astimezone(ET)
|
et_dt = dt.astimezone(ET)
|
||||||
if et_dt.weekday() not in _WEEKDAYS:
|
if et_dt.weekday() not in _WEEKDAYS:
|
||||||
return False
|
return False
|
||||||
|
if et_dt.date() in _HOLIDAYS_2026:
|
||||||
|
return False
|
||||||
t = et_dt.time()
|
t = et_dt.time()
|
||||||
return MARKET_OPEN <= t < MARKET_CLOSE
|
return MARKET_OPEN <= t < MARKET_CLOSE
|
||||||
|
|||||||
Reference in New Issue
Block a user