Files
trader-ml/scripts/compare_strategies.py
Tika ff8d58e1aa Phase 4c-bis/4d : CNN Image vectorisé + Agent RL PPO + HMM persistence + scripts
CNN Image (Phase 4c-bis) :
- chart_renderer.py : renderer numpy vectorisé (boucle 64 bougies, pas 12 000 fenêtres)
  → 1 068 img/s, GIL libéré entre itérations, API réactive pendant l'entraînement
- cnn_image_strategy_model.py : torch.set_num_threads(4) pour préserver l'event loop
- trading.py : asyncio.create_task() au lieu de background_tasks → hot-reloads non-bloquants

Agent RL PPO (Phase 4d) :
- src/ml/rl/ : TradingEnv (gymnasium), PPOModel (Actor-Critic MLP, GAE), RLStrategyModel
- src/strategies/rl_driven/ : RLDrivenStrategy (interface BaseStrategy complète)
- Routes API : POST /train-rl, GET /train-rl/{job_id}, GET /rl-models
- docs/RL_STRATEGY_GUIDE.md : documentation complète

HMM Persistence :
- regime_detector.py : save()/load()/needs_retrain()/is_trained (joblib + JSON meta)
- trading.py /ml/status : charge depuis disque si < 24h, re-entraîne + sauvegarde sinon
  → premier appel ~2s, appels suivants < 100ms

Scripts utilitaires :
- scripts/compare_strategies.py : backtest comparatif toutes stratégies (tabulate/JSON)
- scripts/quick_benchmark.py : comparaison wf_accuracy/precision des modèles ML sauvegardés
- reports/ : répertoire pour les rapports JSON générés

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-10 22:40:52 +00:00

366 lines
13 KiB
Python
Executable File

#!/usr/bin/env python3
"""
compare_strategies.py — Backtest comparatif automatisé des stratégies de trading.
Lance POST /trading/backtest pour chaque stratégie (scalping, ml_driven, cnn_driven)
sur le même dataset (EURUSD / 1h / 1y), poll jusqu'à completion, affiche un tableau
comparatif et sauvegarde le rapport JSON dans reports/.
Usage :
python scripts/compare_strategies.py
python scripts/compare_strategies.py --symbol GBPUSD --period 2y --capital 20000
"""
import argparse
import json
import sys
import time
from datetime import datetime
from pathlib import Path
# ---------------------------------------------------------------------------
# Vérification de la dépendance httpx (ou fallback vers requests)
# ---------------------------------------------------------------------------
try:
import httpx
_HTTP_BACKEND = "httpx"
except ImportError:
try:
import requests as _requests_module
_HTTP_BACKEND = "requests"
except ImportError:
print("ERREUR : Ni httpx ni requests ne sont installés. Exécutez : pip install httpx")
sys.exit(1)
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
API_BASE_URL = "http://localhost:8100"
POLL_INTERVAL_SEC = 5 # Intervalle de polling (secondes)
TIMEOUT_SEC = 300 # Timeout maximum par job (5 minutes)
REPORTS_DIR = Path(__file__).parent.parent / "reports"
# Stratégies à comparer — dans l'ordre d'affichage.
# NOTE : l'API (POST /trading/backtest) accepte : scalping | intraday | swing | ml_driven
# La stratégie cnn_driven n'est pas encore intégrée dans le pipeline de backtest
# (elle est disponible uniquement via paper trading). Elle sera incluse automatiquement
# si l'API évolue pour la supporter.
STRATEGIES = [
"scalping",
"ml_driven",
"cnn_driven", # Retournera une erreur 400 si non supporté par l'API — géré gracieusement
]
# Colonnes métriques à collecter et afficher
METRICS_COLS = [
("total_return", "Retour total", "{:.2%}"),
("sharpe_ratio", "Sharpe ratio", "{:.3f}"),
("max_drawdown", "Max drawdown", "{:.2%}"),
("win_rate", "Win rate", "{:.2%}"),
("total_trades", "Trades totaux", "{:d}"),
("profit_factor", "Profit factor", "{:.2f}"),
]
# Seuils de validation (identiques à ceux de l'API)
SEUIL_SHARPE = 1.5
SEUIL_DRAWDOWN_MAX = 0.10
SEUIL_WIN_RATE = 0.55
# ---------------------------------------------------------------------------
# Couche HTTP (httpx ou requests selon disponibilité)
# ---------------------------------------------------------------------------
def _post(url: str, payload: dict, timeout: int = 30) -> dict:
"""Effectue un POST JSON et retourne la réponse parsée."""
if _HTTP_BACKEND == "httpx":
resp = httpx.post(url, json=payload, timeout=timeout)
resp.raise_for_status()
return resp.json()
else:
resp = _requests_module.post(url, json=payload, timeout=timeout)
resp.raise_for_status()
return resp.json()
def _get(url: str, timeout: int = 15) -> dict:
"""Effectue un GET et retourne la réponse parsée."""
if _HTTP_BACKEND == "httpx":
resp = httpx.get(url, timeout=timeout)
resp.raise_for_status()
return resp.json()
else:
resp = _requests_module.get(url, timeout=timeout)
resp.raise_for_status()
return resp.json()
# ---------------------------------------------------------------------------
# Fonctions backtest
# ---------------------------------------------------------------------------
def lancer_backtest(strategy: str, symbol: str, period: str, capital: float) -> str:
"""
Lance un job de backtest via POST /trading/backtest.
Retourne le job_id ou lève une exception en cas d'erreur.
"""
payload = {
"strategy": strategy,
"symbol": symbol,
"period": period,
"initial_capital": capital,
}
url = f"{API_BASE_URL}/trading/backtest"
data = _post(url, payload)
return data["job_id"]
def attendre_resultat(job_id: str, strategy: str) -> dict:
"""
Poll GET /trading/backtest/{job_id} toutes les POLL_INTERVAL_SEC secondes
jusqu'à ce que le statut soit 'completed' ou 'failed'.
Retourne le dict de résultat ou lève une exception si le job échoue / dépasse le timeout.
"""
url = f"{API_BASE_URL}/trading/backtest/{job_id}"
debut = time.time()
while True:
elapsed = time.time() - debut
if elapsed > TIMEOUT_SEC:
raise TimeoutError(
f"[{strategy}] Timeout atteint ({TIMEOUT_SEC}s). Job {job_id} toujours en cours."
)
data = _get(url)
statut = data.get("status", "?")
print(f" [{strategy}] Statut : {statut} ({elapsed:.0f}s)")
if statut == "completed":
return data
elif statut == "failed":
erreur = data.get("error", "raison inconnue")
raise RuntimeError(f"[{strategy}] Job échoué : {erreur}")
# Toujours en cours — on attend
time.sleep(POLL_INTERVAL_SEC)
# ---------------------------------------------------------------------------
# Affichage du tableau comparatif
# ---------------------------------------------------------------------------
def afficher_tableau(resultats: list[dict]) -> None:
"""
Affiche un tableau comparatif dans le terminal.
Tente d'utiliser 'tabulate' pour un meilleur rendu ; fallback vers print simple.
"""
# Construction des données tabulaires
headers = ["Stratégie"] + [label for _, label, _ in METRICS_COLS] + ["Valide ?"]
rows = []
for res in resultats:
nom = res["strategy"]
ligne = [nom]
for key, _, fmt in METRICS_COLS:
val = res.get(key)
if val is None:
ligne.append("N/A")
else:
try:
if key == "total_trades":
ligne.append(fmt.format(int(val)))
else:
ligne.append(fmt.format(float(val)))
except (ValueError, TypeError):
ligne.append(str(val))
# Indicateur de validation
valide = res.get("is_valid_for_paper")
if valide is True:
ligne.append("OUI")
elif valide is False:
ligne.append("NON")
else:
ligne.append("?")
rows.append(ligne)
print()
print("=" * 70)
print("RAPPORT COMPARATIF DES STRATÉGIES")
print("=" * 70)
try:
from tabulate import tabulate
print(tabulate(rows, headers=headers, tablefmt="rounded_outline"))
except ImportError:
# Fallback : affichage simple sans tabulate
col_widths = [max(len(str(h)), max((len(str(r[i])) for r in rows), default=0))
for i, h in enumerate(headers)]
sep = " ".join("-" * w for w in col_widths)
header_line = " ".join(str(h).ljust(w) for h, w in zip(headers, col_widths))
print(header_line)
print(sep)
for row in rows:
print(" ".join(str(c).ljust(w) for c, w in zip(row, col_widths)))
print()
# Identification de la meilleure stratégie (selon Sharpe ratio)
valides = [r for r in resultats if r.get("sharpe_ratio") is not None]
if valides:
meilleure = max(valides, key=lambda r: r.get("sharpe_ratio") or -999)
print(f"Meilleure stratégie (Sharpe) : {meilleure['strategy']} "
f"(Sharpe={meilleure.get('sharpe_ratio', 'N/A'):.3f})")
# Rappel des seuils
print()
print(f"Seuils de validation : Sharpe >= {SEUIL_SHARPE} | "
f"Max drawdown <= {SEUIL_DRAWDOWN_MAX:.0%} | "
f"Win rate >= {SEUIL_WIN_RATE:.0%}")
print("=" * 70)
# ---------------------------------------------------------------------------
# Sauvegarde du rapport JSON
# ---------------------------------------------------------------------------
def sauvegarder_rapport(resultats: list[dict], symbol: str, period: str) -> Path:
"""
Sauvegarde le rapport de comparaison en JSON dans reports/.
Retourne le chemin du fichier créé.
"""
REPORTS_DIR.mkdir(parents=True, exist_ok=True)
horodatage = datetime.now().strftime("%Y%m%d_%H%M%S")
nom_fichier = f"backtest_comparison_{horodatage}.json"
chemin = REPORTS_DIR / nom_fichier
rapport = {
"generated_at": datetime.now().isoformat(),
"parametres": {
"symbol": symbol,
"period": period,
"api_base_url": API_BASE_URL,
},
"seuils_validation": {
"sharpe_ratio_min": SEUIL_SHARPE,
"max_drawdown_max": SEUIL_DRAWDOWN_MAX,
"win_rate_min": SEUIL_WIN_RATE,
},
"resultats": resultats,
}
with open(chemin, "w", encoding="utf-8") as f:
json.dump(rapport, f, indent=2, ensure_ascii=False, default=str)
return chemin
# ---------------------------------------------------------------------------
# Point d'entrée principal
# ---------------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
"""Parse les arguments en ligne de commande."""
parser = argparse.ArgumentParser(
description="Compare les stratégies de trading via backtest API."
)
parser.add_argument("--symbol", default="EURUSD", help="Paire de trading (défaut: EURUSD)")
parser.add_argument("--period", default="1y", help="Période historique : 6m | 1y | 2y (défaut: 1y)")
parser.add_argument("--capital", default=10000.0, type=float, help="Capital initial (défaut: 10000)")
parser.add_argument(
"--strategies",
nargs="+",
default=STRATEGIES,
help="Liste des stratégies à comparer (défaut: scalping ml_driven cnn_driven)",
)
return parser.parse_args()
def main() -> None:
args = parse_args()
print()
print("=" * 70)
print("BACKTEST COMPARATIF DES STRATÉGIES")
print(f"Symbol: {args.symbol} | Période: {args.period} | Capital: {args.capital:,.0f} EUR")
print(f"Stratégies: {', '.join(args.strategies)}")
print(f"API: {API_BASE_URL}")
print("=" * 70)
# Vérification de la disponibilité de l'API
try:
_get(f"{API_BASE_URL}/health")
print("API disponible.")
except Exception as e:
print(f"ERREUR : Impossible de joindre l'API ({e})")
print("Vérifiez que le container trading-api tourne sur le port 8100.")
sys.exit(1)
print()
# -----------------------------------------------------------------------
# Phase 1 : lancement de tous les backtests
# -----------------------------------------------------------------------
jobs: dict[str, str] = {} # {strategy: job_id}
stratégies_échouées: list[str] = []
for strat in args.strategies:
print(f"Lancement backtest [{strat}]...")
try:
job_id = lancer_backtest(strat, args.symbol, args.period, args.capital)
jobs[strat] = job_id
print(f" -> Job ID : {job_id}")
except Exception as e:
print(f" ERREUR lancement [{strat}] : {e}")
stratégies_échouées.append(strat)
if not jobs:
print("Aucun job lancé. Arrêt.")
sys.exit(1)
print()
# -----------------------------------------------------------------------
# Phase 2 : attente et collecte des résultats (séquentielle)
# -----------------------------------------------------------------------
resultats: list[dict] = []
for strat, job_id in jobs.items():
print(f"Attente résultat [{strat}] (job={job_id})...")
try:
res = attendre_resultat(job_id, strat)
resultats.append(res)
print(f" -> Terminé. Sharpe={res.get('sharpe_ratio', 'N/A')}, "
f"Retour={res.get('total_return', 'N/A')}")
except Exception as e:
print(f" ERREUR [{strat}] : {e}")
# On ajoute quand même un résultat partiel pour la lisibilité du rapport
resultats.append({
"strategy": strat,
"symbol": args.symbol,
"status": "failed",
"error": str(e),
})
print()
# -----------------------------------------------------------------------
# Phase 3 : affichage et sauvegarde
# -----------------------------------------------------------------------
afficher_tableau(resultats)
chemin_rapport = sauvegarder_rapport(resultats, args.symbol, args.period)
print(f"Rapport JSON sauvegardé : {chemin_rapport}")
print()
# Résumé des échecs éventuels
if stratégies_échouées:
print(f"Stratégies non lancées (erreur API) : {', '.join(stratégies_échouées)}")
if __name__ == "__main__":
main()