#!/usr/bin/env python3 """ compare_strategies.py — Backtest comparatif automatisé des stratégies de trading. Lance POST /trading/backtest pour chaque stratégie (scalping, ml_driven, cnn_driven) sur le même dataset (EURUSD / 1h / 1y), poll jusqu'à completion, affiche un tableau comparatif et sauvegarde le rapport JSON dans reports/. Usage : python scripts/compare_strategies.py python scripts/compare_strategies.py --symbol GBPUSD --period 2y --capital 20000 """ import argparse import json import sys import time from datetime import datetime from pathlib import Path # --------------------------------------------------------------------------- # Vérification de la dépendance httpx (ou fallback vers requests) # --------------------------------------------------------------------------- try: import httpx _HTTP_BACKEND = "httpx" except ImportError: try: import requests as _requests_module _HTTP_BACKEND = "requests" except ImportError: print("ERREUR : Ni httpx ni requests ne sont installés. Exécutez : pip install httpx") sys.exit(1) # --------------------------------------------------------------------------- # Configuration # --------------------------------------------------------------------------- API_BASE_URL = "http://localhost:8100" POLL_INTERVAL_SEC = 5 # Intervalle de polling (secondes) TIMEOUT_SEC = 300 # Timeout maximum par job (5 minutes) REPORTS_DIR = Path(__file__).parent.parent / "reports" # Stratégies à comparer — dans l'ordre d'affichage. # NOTE : l'API (POST /trading/backtest) accepte : scalping | intraday | swing | ml_driven # La stratégie cnn_driven n'est pas encore intégrée dans le pipeline de backtest # (elle est disponible uniquement via paper trading). Elle sera incluse automatiquement # si l'API évolue pour la supporter. STRATEGIES = [ "scalping", "ml_driven", "cnn_driven", # Retournera une erreur 400 si non supporté par l'API — géré gracieusement ] # Colonnes métriques à collecter et afficher METRICS_COLS = [ ("total_return", "Retour total", "{:.2%}"), ("sharpe_ratio", "Sharpe ratio", "{:.3f}"), ("max_drawdown", "Max drawdown", "{:.2%}"), ("win_rate", "Win rate", "{:.2%}"), ("total_trades", "Trades totaux", "{:d}"), ("profit_factor", "Profit factor", "{:.2f}"), ] # Seuils de validation (identiques à ceux de l'API) SEUIL_SHARPE = 1.5 SEUIL_DRAWDOWN_MAX = 0.10 SEUIL_WIN_RATE = 0.55 # --------------------------------------------------------------------------- # Couche HTTP (httpx ou requests selon disponibilité) # --------------------------------------------------------------------------- def _post(url: str, payload: dict, timeout: int = 30) -> dict: """Effectue un POST JSON et retourne la réponse parsée.""" if _HTTP_BACKEND == "httpx": resp = httpx.post(url, json=payload, timeout=timeout) resp.raise_for_status() return resp.json() else: resp = _requests_module.post(url, json=payload, timeout=timeout) resp.raise_for_status() return resp.json() def _get(url: str, timeout: int = 15) -> dict: """Effectue un GET et retourne la réponse parsée.""" if _HTTP_BACKEND == "httpx": resp = httpx.get(url, timeout=timeout) resp.raise_for_status() return resp.json() else: resp = _requests_module.get(url, timeout=timeout) resp.raise_for_status() return resp.json() # --------------------------------------------------------------------------- # Fonctions backtest # --------------------------------------------------------------------------- def lancer_backtest(strategy: str, symbol: str, period: str, capital: float) -> str: """ Lance un job de backtest via POST /trading/backtest. Retourne le job_id ou lève une exception en cas d'erreur. """ payload = { "strategy": strategy, "symbol": symbol, "period": period, "initial_capital": capital, } url = f"{API_BASE_URL}/trading/backtest" data = _post(url, payload) return data["job_id"] def attendre_resultat(job_id: str, strategy: str) -> dict: """ Poll GET /trading/backtest/{job_id} toutes les POLL_INTERVAL_SEC secondes jusqu'à ce que le statut soit 'completed' ou 'failed'. Retourne le dict de résultat ou lève une exception si le job échoue / dépasse le timeout. """ url = f"{API_BASE_URL}/trading/backtest/{job_id}" debut = time.time() while True: elapsed = time.time() - debut if elapsed > TIMEOUT_SEC: raise TimeoutError( f"[{strategy}] Timeout atteint ({TIMEOUT_SEC}s). Job {job_id} toujours en cours." ) data = _get(url) statut = data.get("status", "?") print(f" [{strategy}] Statut : {statut} ({elapsed:.0f}s)") if statut == "completed": return data elif statut == "failed": erreur = data.get("error", "raison inconnue") raise RuntimeError(f"[{strategy}] Job échoué : {erreur}") # Toujours en cours — on attend time.sleep(POLL_INTERVAL_SEC) # --------------------------------------------------------------------------- # Affichage du tableau comparatif # --------------------------------------------------------------------------- def afficher_tableau(resultats: list[dict]) -> None: """ Affiche un tableau comparatif dans le terminal. Tente d'utiliser 'tabulate' pour un meilleur rendu ; fallback vers print simple. """ # Construction des données tabulaires headers = ["Stratégie"] + [label for _, label, _ in METRICS_COLS] + ["Valide ?"] rows = [] for res in resultats: nom = res["strategy"] ligne = [nom] for key, _, fmt in METRICS_COLS: val = res.get(key) if val is None: ligne.append("N/A") else: try: if key == "total_trades": ligne.append(fmt.format(int(val))) else: ligne.append(fmt.format(float(val))) except (ValueError, TypeError): ligne.append(str(val)) # Indicateur de validation valide = res.get("is_valid_for_paper") if valide is True: ligne.append("OUI") elif valide is False: ligne.append("NON") else: ligne.append("?") rows.append(ligne) print() print("=" * 70) print("RAPPORT COMPARATIF DES STRATÉGIES") print("=" * 70) try: from tabulate import tabulate print(tabulate(rows, headers=headers, tablefmt="rounded_outline")) except ImportError: # Fallback : affichage simple sans tabulate col_widths = [max(len(str(h)), max((len(str(r[i])) for r in rows), default=0)) for i, h in enumerate(headers)] sep = " ".join("-" * w for w in col_widths) header_line = " ".join(str(h).ljust(w) for h, w in zip(headers, col_widths)) print(header_line) print(sep) for row in rows: print(" ".join(str(c).ljust(w) for c, w in zip(row, col_widths))) print() # Identification de la meilleure stratégie (selon Sharpe ratio) valides = [r for r in resultats if r.get("sharpe_ratio") is not None] if valides: meilleure = max(valides, key=lambda r: r.get("sharpe_ratio") or -999) print(f"Meilleure stratégie (Sharpe) : {meilleure['strategy']} " f"(Sharpe={meilleure.get('sharpe_ratio', 'N/A'):.3f})") # Rappel des seuils print() print(f"Seuils de validation : Sharpe >= {SEUIL_SHARPE} | " f"Max drawdown <= {SEUIL_DRAWDOWN_MAX:.0%} | " f"Win rate >= {SEUIL_WIN_RATE:.0%}") print("=" * 70) # --------------------------------------------------------------------------- # Sauvegarde du rapport JSON # --------------------------------------------------------------------------- def sauvegarder_rapport(resultats: list[dict], symbol: str, period: str) -> Path: """ Sauvegarde le rapport de comparaison en JSON dans reports/. Retourne le chemin du fichier créé. """ REPORTS_DIR.mkdir(parents=True, exist_ok=True) horodatage = datetime.now().strftime("%Y%m%d_%H%M%S") nom_fichier = f"backtest_comparison_{horodatage}.json" chemin = REPORTS_DIR / nom_fichier rapport = { "generated_at": datetime.now().isoformat(), "parametres": { "symbol": symbol, "period": period, "api_base_url": API_BASE_URL, }, "seuils_validation": { "sharpe_ratio_min": SEUIL_SHARPE, "max_drawdown_max": SEUIL_DRAWDOWN_MAX, "win_rate_min": SEUIL_WIN_RATE, }, "resultats": resultats, } with open(chemin, "w", encoding="utf-8") as f: json.dump(rapport, f, indent=2, ensure_ascii=False, default=str) return chemin # --------------------------------------------------------------------------- # Point d'entrée principal # --------------------------------------------------------------------------- def parse_args() -> argparse.Namespace: """Parse les arguments en ligne de commande.""" parser = argparse.ArgumentParser( description="Compare les stratégies de trading via backtest API." ) parser.add_argument("--symbol", default="EURUSD", help="Paire de trading (défaut: EURUSD)") parser.add_argument("--period", default="1y", help="Période historique : 6m | 1y | 2y (défaut: 1y)") parser.add_argument("--capital", default=10000.0, type=float, help="Capital initial (défaut: 10000)") parser.add_argument( "--strategies", nargs="+", default=STRATEGIES, help="Liste des stratégies à comparer (défaut: scalping ml_driven cnn_driven)", ) return parser.parse_args() def main() -> None: args = parse_args() print() print("=" * 70) print("BACKTEST COMPARATIF DES STRATÉGIES") print(f"Symbol: {args.symbol} | Période: {args.period} | Capital: {args.capital:,.0f} EUR") print(f"Stratégies: {', '.join(args.strategies)}") print(f"API: {API_BASE_URL}") print("=" * 70) # Vérification de la disponibilité de l'API try: _get(f"{API_BASE_URL}/health") print("API disponible.") except Exception as e: print(f"ERREUR : Impossible de joindre l'API ({e})") print("Vérifiez que le container trading-api tourne sur le port 8100.") sys.exit(1) print() # ----------------------------------------------------------------------- # Phase 1 : lancement de tous les backtests # ----------------------------------------------------------------------- jobs: dict[str, str] = {} # {strategy: job_id} stratégies_échouées: list[str] = [] for strat in args.strategies: print(f"Lancement backtest [{strat}]...") try: job_id = lancer_backtest(strat, args.symbol, args.period, args.capital) jobs[strat] = job_id print(f" -> Job ID : {job_id}") except Exception as e: print(f" ERREUR lancement [{strat}] : {e}") stratégies_échouées.append(strat) if not jobs: print("Aucun job lancé. Arrêt.") sys.exit(1) print() # ----------------------------------------------------------------------- # Phase 2 : attente et collecte des résultats (séquentielle) # ----------------------------------------------------------------------- resultats: list[dict] = [] for strat, job_id in jobs.items(): print(f"Attente résultat [{strat}] (job={job_id})...") try: res = attendre_resultat(job_id, strat) resultats.append(res) print(f" -> Terminé. Sharpe={res.get('sharpe_ratio', 'N/A')}, " f"Retour={res.get('total_return', 'N/A')}") except Exception as e: print(f" ERREUR [{strat}] : {e}") # On ajoute quand même un résultat partiel pour la lisibilité du rapport resultats.append({ "strategy": strat, "symbol": args.symbol, "status": "failed", "error": str(e), }) print() # ----------------------------------------------------------------------- # Phase 3 : affichage et sauvegarde # ----------------------------------------------------------------------- afficher_tableau(resultats) chemin_rapport = sauvegarder_rapport(resultats, args.symbol, args.period) print(f"Rapport JSON sauvegardé : {chemin_rapport}") print() # Résumé des échecs éventuels if stratégies_échouées: print(f"Stratégies non lancées (erreur API) : {', '.join(stratégies_échouées)}") if __name__ == "__main__": main()