Phase 4c-bis/4d : CNN Image vectorisé + Agent RL PPO + HMM persistence + scripts
CNN Image (Phase 4c-bis) :
- chart_renderer.py : renderer numpy vectorisé (boucle 64 bougies, pas 12 000 fenêtres)
→ 1 068 img/s, GIL libéré entre itérations, API réactive pendant l'entraînement
- cnn_image_strategy_model.py : torch.set_num_threads(4) pour préserver l'event loop
- trading.py : asyncio.create_task() au lieu de background_tasks → hot-reloads non-bloquants
Agent RL PPO (Phase 4d) :
- src/ml/rl/ : TradingEnv (gymnasium), PPOModel (Actor-Critic MLP, GAE), RLStrategyModel
- src/strategies/rl_driven/ : RLDrivenStrategy (interface BaseStrategy complète)
- Routes API : POST /train-rl, GET /train-rl/{job_id}, GET /rl-models
- docs/RL_STRATEGY_GUIDE.md : documentation complète
HMM Persistence :
- regime_detector.py : save()/load()/needs_retrain()/is_trained (joblib + JSON meta)
- trading.py /ml/status : charge depuis disque si < 24h, re-entraîne + sauvegarde sinon
→ premier appel ~2s, appels suivants < 100ms
Scripts utilitaires :
- scripts/compare_strategies.py : backtest comparatif toutes stratégies (tabulate/JSON)
- scripts/quick_benchmark.py : comparaison wf_accuracy/precision des modèles ML sauvegardés
- reports/ : répertoire pour les rapports JSON générés
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
365
scripts/compare_strategies.py
Executable file
365
scripts/compare_strategies.py
Executable file
@@ -0,0 +1,365 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
compare_strategies.py — Backtest comparatif automatisé des stratégies de trading.
|
||||
|
||||
Lance POST /trading/backtest pour chaque stratégie (scalping, ml_driven, cnn_driven)
|
||||
sur le même dataset (EURUSD / 1h / 1y), poll jusqu'à completion, affiche un tableau
|
||||
comparatif et sauvegarde le rapport JSON dans reports/.
|
||||
|
||||
Usage :
|
||||
python scripts/compare_strategies.py
|
||||
python scripts/compare_strategies.py --symbol GBPUSD --period 2y --capital 20000
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Vérification de la dépendance httpx (ou fallback vers requests)
|
||||
# ---------------------------------------------------------------------------
|
||||
try:
|
||||
import httpx
|
||||
_HTTP_BACKEND = "httpx"
|
||||
except ImportError:
|
||||
try:
|
||||
import requests as _requests_module
|
||||
_HTTP_BACKEND = "requests"
|
||||
except ImportError:
|
||||
print("ERREUR : Ni httpx ni requests ne sont installés. Exécutez : pip install httpx")
|
||||
sys.exit(1)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
API_BASE_URL = "http://localhost:8100"
|
||||
POLL_INTERVAL_SEC = 5 # Intervalle de polling (secondes)
|
||||
TIMEOUT_SEC = 300 # Timeout maximum par job (5 minutes)
|
||||
REPORTS_DIR = Path(__file__).parent.parent / "reports"
|
||||
|
||||
# Stratégies à comparer — dans l'ordre d'affichage.
|
||||
# NOTE : l'API (POST /trading/backtest) accepte : scalping | intraday | swing | ml_driven
|
||||
# La stratégie cnn_driven n'est pas encore intégrée dans le pipeline de backtest
|
||||
# (elle est disponible uniquement via paper trading). Elle sera incluse automatiquement
|
||||
# si l'API évolue pour la supporter.
|
||||
STRATEGIES = [
|
||||
"scalping",
|
||||
"ml_driven",
|
||||
"cnn_driven", # Retournera une erreur 400 si non supporté par l'API — géré gracieusement
|
||||
]
|
||||
|
||||
# Colonnes métriques à collecter et afficher
|
||||
METRICS_COLS = [
|
||||
("total_return", "Retour total", "{:.2%}"),
|
||||
("sharpe_ratio", "Sharpe ratio", "{:.3f}"),
|
||||
("max_drawdown", "Max drawdown", "{:.2%}"),
|
||||
("win_rate", "Win rate", "{:.2%}"),
|
||||
("total_trades", "Trades totaux", "{:d}"),
|
||||
("profit_factor", "Profit factor", "{:.2f}"),
|
||||
]
|
||||
|
||||
# Seuils de validation (identiques à ceux de l'API)
|
||||
SEUIL_SHARPE = 1.5
|
||||
SEUIL_DRAWDOWN_MAX = 0.10
|
||||
SEUIL_WIN_RATE = 0.55
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Couche HTTP (httpx ou requests selon disponibilité)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _post(url: str, payload: dict, timeout: int = 30) -> dict:
|
||||
"""Effectue un POST JSON et retourne la réponse parsée."""
|
||||
if _HTTP_BACKEND == "httpx":
|
||||
resp = httpx.post(url, json=payload, timeout=timeout)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
else:
|
||||
resp = _requests_module.post(url, json=payload, timeout=timeout)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
def _get(url: str, timeout: int = 15) -> dict:
|
||||
"""Effectue un GET et retourne la réponse parsée."""
|
||||
if _HTTP_BACKEND == "httpx":
|
||||
resp = httpx.get(url, timeout=timeout)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
else:
|
||||
resp = _requests_module.get(url, timeout=timeout)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fonctions backtest
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def lancer_backtest(strategy: str, symbol: str, period: str, capital: float) -> str:
|
||||
"""
|
||||
Lance un job de backtest via POST /trading/backtest.
|
||||
|
||||
Retourne le job_id ou lève une exception en cas d'erreur.
|
||||
"""
|
||||
payload = {
|
||||
"strategy": strategy,
|
||||
"symbol": symbol,
|
||||
"period": period,
|
||||
"initial_capital": capital,
|
||||
}
|
||||
url = f"{API_BASE_URL}/trading/backtest"
|
||||
data = _post(url, payload)
|
||||
return data["job_id"]
|
||||
|
||||
|
||||
def attendre_resultat(job_id: str, strategy: str) -> dict:
|
||||
"""
|
||||
Poll GET /trading/backtest/{job_id} toutes les POLL_INTERVAL_SEC secondes
|
||||
jusqu'à ce que le statut soit 'completed' ou 'failed'.
|
||||
|
||||
Retourne le dict de résultat ou lève une exception si le job échoue / dépasse le timeout.
|
||||
"""
|
||||
url = f"{API_BASE_URL}/trading/backtest/{job_id}"
|
||||
debut = time.time()
|
||||
|
||||
while True:
|
||||
elapsed = time.time() - debut
|
||||
if elapsed > TIMEOUT_SEC:
|
||||
raise TimeoutError(
|
||||
f"[{strategy}] Timeout atteint ({TIMEOUT_SEC}s). Job {job_id} toujours en cours."
|
||||
)
|
||||
|
||||
data = _get(url)
|
||||
statut = data.get("status", "?")
|
||||
|
||||
print(f" [{strategy}] Statut : {statut} ({elapsed:.0f}s)")
|
||||
|
||||
if statut == "completed":
|
||||
return data
|
||||
elif statut == "failed":
|
||||
erreur = data.get("error", "raison inconnue")
|
||||
raise RuntimeError(f"[{strategy}] Job échoué : {erreur}")
|
||||
|
||||
# Toujours en cours — on attend
|
||||
time.sleep(POLL_INTERVAL_SEC)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Affichage du tableau comparatif
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def afficher_tableau(resultats: list[dict]) -> None:
|
||||
"""
|
||||
Affiche un tableau comparatif dans le terminal.
|
||||
Tente d'utiliser 'tabulate' pour un meilleur rendu ; fallback vers print simple.
|
||||
"""
|
||||
# Construction des données tabulaires
|
||||
headers = ["Stratégie"] + [label for _, label, _ in METRICS_COLS] + ["Valide ?"]
|
||||
rows = []
|
||||
|
||||
for res in resultats:
|
||||
nom = res["strategy"]
|
||||
ligne = [nom]
|
||||
for key, _, fmt in METRICS_COLS:
|
||||
val = res.get(key)
|
||||
if val is None:
|
||||
ligne.append("N/A")
|
||||
else:
|
||||
try:
|
||||
if key == "total_trades":
|
||||
ligne.append(fmt.format(int(val)))
|
||||
else:
|
||||
ligne.append(fmt.format(float(val)))
|
||||
except (ValueError, TypeError):
|
||||
ligne.append(str(val))
|
||||
# Indicateur de validation
|
||||
valide = res.get("is_valid_for_paper")
|
||||
if valide is True:
|
||||
ligne.append("OUI")
|
||||
elif valide is False:
|
||||
ligne.append("NON")
|
||||
else:
|
||||
ligne.append("?")
|
||||
rows.append(ligne)
|
||||
|
||||
print()
|
||||
print("=" * 70)
|
||||
print("RAPPORT COMPARATIF DES STRATÉGIES")
|
||||
print("=" * 70)
|
||||
|
||||
try:
|
||||
from tabulate import tabulate
|
||||
print(tabulate(rows, headers=headers, tablefmt="rounded_outline"))
|
||||
except ImportError:
|
||||
# Fallback : affichage simple sans tabulate
|
||||
col_widths = [max(len(str(h)), max((len(str(r[i])) for r in rows), default=0))
|
||||
for i, h in enumerate(headers)]
|
||||
sep = " ".join("-" * w for w in col_widths)
|
||||
header_line = " ".join(str(h).ljust(w) for h, w in zip(headers, col_widths))
|
||||
print(header_line)
|
||||
print(sep)
|
||||
for row in rows:
|
||||
print(" ".join(str(c).ljust(w) for c, w in zip(row, col_widths)))
|
||||
|
||||
print()
|
||||
|
||||
# Identification de la meilleure stratégie (selon Sharpe ratio)
|
||||
valides = [r for r in resultats if r.get("sharpe_ratio") is not None]
|
||||
if valides:
|
||||
meilleure = max(valides, key=lambda r: r.get("sharpe_ratio") or -999)
|
||||
print(f"Meilleure stratégie (Sharpe) : {meilleure['strategy']} "
|
||||
f"(Sharpe={meilleure.get('sharpe_ratio', 'N/A'):.3f})")
|
||||
|
||||
# Rappel des seuils
|
||||
print()
|
||||
print(f"Seuils de validation : Sharpe >= {SEUIL_SHARPE} | "
|
||||
f"Max drawdown <= {SEUIL_DRAWDOWN_MAX:.0%} | "
|
||||
f"Win rate >= {SEUIL_WIN_RATE:.0%}")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sauvegarde du rapport JSON
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def sauvegarder_rapport(resultats: list[dict], symbol: str, period: str) -> Path:
|
||||
"""
|
||||
Sauvegarde le rapport de comparaison en JSON dans reports/.
|
||||
|
||||
Retourne le chemin du fichier créé.
|
||||
"""
|
||||
REPORTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
horodatage = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
nom_fichier = f"backtest_comparison_{horodatage}.json"
|
||||
chemin = REPORTS_DIR / nom_fichier
|
||||
|
||||
rapport = {
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
"parametres": {
|
||||
"symbol": symbol,
|
||||
"period": period,
|
||||
"api_base_url": API_BASE_URL,
|
||||
},
|
||||
"seuils_validation": {
|
||||
"sharpe_ratio_min": SEUIL_SHARPE,
|
||||
"max_drawdown_max": SEUIL_DRAWDOWN_MAX,
|
||||
"win_rate_min": SEUIL_WIN_RATE,
|
||||
},
|
||||
"resultats": resultats,
|
||||
}
|
||||
|
||||
with open(chemin, "w", encoding="utf-8") as f:
|
||||
json.dump(rapport, f, indent=2, ensure_ascii=False, default=str)
|
||||
|
||||
return chemin
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Point d'entrée principal
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""Parse les arguments en ligne de commande."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Compare les stratégies de trading via backtest API."
|
||||
)
|
||||
parser.add_argument("--symbol", default="EURUSD", help="Paire de trading (défaut: EURUSD)")
|
||||
parser.add_argument("--period", default="1y", help="Période historique : 6m | 1y | 2y (défaut: 1y)")
|
||||
parser.add_argument("--capital", default=10000.0, type=float, help="Capital initial (défaut: 10000)")
|
||||
parser.add_argument(
|
||||
"--strategies",
|
||||
nargs="+",
|
||||
default=STRATEGIES,
|
||||
help="Liste des stratégies à comparer (défaut: scalping ml_driven cnn_driven)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
print()
|
||||
print("=" * 70)
|
||||
print("BACKTEST COMPARATIF DES STRATÉGIES")
|
||||
print(f"Symbol: {args.symbol} | Période: {args.period} | Capital: {args.capital:,.0f} EUR")
|
||||
print(f"Stratégies: {', '.join(args.strategies)}")
|
||||
print(f"API: {API_BASE_URL}")
|
||||
print("=" * 70)
|
||||
|
||||
# Vérification de la disponibilité de l'API
|
||||
try:
|
||||
_get(f"{API_BASE_URL}/health")
|
||||
print("API disponible.")
|
||||
except Exception as e:
|
||||
print(f"ERREUR : Impossible de joindre l'API ({e})")
|
||||
print("Vérifiez que le container trading-api tourne sur le port 8100.")
|
||||
sys.exit(1)
|
||||
|
||||
print()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Phase 1 : lancement de tous les backtests
|
||||
# -----------------------------------------------------------------------
|
||||
jobs: dict[str, str] = {} # {strategy: job_id}
|
||||
stratégies_échouées: list[str] = []
|
||||
|
||||
for strat in args.strategies:
|
||||
print(f"Lancement backtest [{strat}]...")
|
||||
try:
|
||||
job_id = lancer_backtest(strat, args.symbol, args.period, args.capital)
|
||||
jobs[strat] = job_id
|
||||
print(f" -> Job ID : {job_id}")
|
||||
except Exception as e:
|
||||
print(f" ERREUR lancement [{strat}] : {e}")
|
||||
stratégies_échouées.append(strat)
|
||||
|
||||
if not jobs:
|
||||
print("Aucun job lancé. Arrêt.")
|
||||
sys.exit(1)
|
||||
|
||||
print()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Phase 2 : attente et collecte des résultats (séquentielle)
|
||||
# -----------------------------------------------------------------------
|
||||
resultats: list[dict] = []
|
||||
|
||||
for strat, job_id in jobs.items():
|
||||
print(f"Attente résultat [{strat}] (job={job_id})...")
|
||||
try:
|
||||
res = attendre_resultat(job_id, strat)
|
||||
resultats.append(res)
|
||||
print(f" -> Terminé. Sharpe={res.get('sharpe_ratio', 'N/A')}, "
|
||||
f"Retour={res.get('total_return', 'N/A')}")
|
||||
except Exception as e:
|
||||
print(f" ERREUR [{strat}] : {e}")
|
||||
# On ajoute quand même un résultat partiel pour la lisibilité du rapport
|
||||
resultats.append({
|
||||
"strategy": strat,
|
||||
"symbol": args.symbol,
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
})
|
||||
|
||||
print()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Phase 3 : affichage et sauvegarde
|
||||
# -----------------------------------------------------------------------
|
||||
afficher_tableau(resultats)
|
||||
|
||||
chemin_rapport = sauvegarder_rapport(resultats, args.symbol, args.period)
|
||||
print(f"Rapport JSON sauvegardé : {chemin_rapport}")
|
||||
print()
|
||||
|
||||
# Résumé des échecs éventuels
|
||||
if stratégies_échouées:
|
||||
print(f"Stratégies non lancées (erreur API) : {', '.join(stratégies_échouées)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
327
scripts/quick_benchmark.py
Executable file
327
scripts/quick_benchmark.py
Executable file
@@ -0,0 +1,327 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
quick_benchmark.py — Benchmark rapide des modèles ML entraînés.
|
||||
|
||||
Lit les fichiers *_meta.json dans les répertoires de modèles :
|
||||
- models/ml_strategy/ (XGBoost / LightGBM / RandomForest)
|
||||
- models/cnn_strategy/ (CNN 1D — CandlestickEncoder)
|
||||
- models/cnn_image_strategy/ (CNN Image — chandeliers en image)
|
||||
|
||||
Affiche un tableau comparatif : wf_accuracy, wf_precision, n_samples, etc.
|
||||
Indique quel modèle obtient la meilleure précision walk-forward.
|
||||
|
||||
Usage :
|
||||
python scripts/quick_benchmark.py
|
||||
python scripts/quick_benchmark.py --models-root /chemin/vers/models
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Répertoires de modèles (relatifs à la racine du projet)
|
||||
# ---------------------------------------------------------------------------
|
||||
PROJECT_ROOT = Path(__file__).parent.parent
|
||||
|
||||
# Mapping : nom affiché -> chemin relatif au projet
|
||||
MODEL_DIRS = {
|
||||
"ML-Strategy (XGBoost/LightGBM)": PROJECT_ROOT / "models" / "ml_strategy",
|
||||
"CNN-Strategy (1D)": PROJECT_ROOT / "models" / "cnn_strategy",
|
||||
"CNN-Image-Strategy": PROJECT_ROOT / "models" / "cnn_image_strategy",
|
||||
}
|
||||
|
||||
# Colonnes de métriques affichées dans le tableau
|
||||
COLONNES = [
|
||||
("type", "Type", "{:<28}"),
|
||||
("symbol", "Symbol", "{:<8}"),
|
||||
("timeframe", "TF", "{:<5}"),
|
||||
("model_type", "Modèle", "{:<12}"),
|
||||
("n_samples", "Échantillons", "{:>12}"),
|
||||
("wf_accuracy", "WF Accuracy", "{:>12}"),
|
||||
("wf_precision", "WF Precision", "{:>13}"),
|
||||
("trained_at", "Entraîné le", "{:<20}"),
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lecture des méta-fichiers JSON
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def charger_modeles(model_dirs: dict[str, Path]) -> list[dict]:
|
||||
"""
|
||||
Parcourt les répertoires de modèles et charge les métadonnées JSON.
|
||||
|
||||
Retourne une liste de dicts prêts pour l'affichage.
|
||||
"""
|
||||
tous = []
|
||||
|
||||
for type_label, dossier in model_dirs.items():
|
||||
if not dossier.exists():
|
||||
# Répertoire absent — pas encore de modèles entraînés pour ce type
|
||||
continue
|
||||
|
||||
fichiers_meta = sorted(dossier.glob("*_meta.json"))
|
||||
if not fichiers_meta:
|
||||
continue
|
||||
|
||||
for f in fichiers_meta:
|
||||
try:
|
||||
with open(f, encoding="utf-8") as fp:
|
||||
meta = json.load(fp)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
print(f"[WARN] Impossible de lire {f} : {e}", file=sys.stderr)
|
||||
continue
|
||||
|
||||
wf = meta.get("wf_metrics", {})
|
||||
|
||||
# Extraction des champs — avec valeurs par défaut
|
||||
# Le nom du fichier (ex : EURUSD_1h_xgboost_meta.json) est utilisé
|
||||
# comme fallback si les champs ne sont pas dans le JSON.
|
||||
stem_parts = f.stem.replace("_meta", "").split("_")
|
||||
symbol = meta.get("symbol", stem_parts[0] if len(stem_parts) > 0 else "?")
|
||||
timeframe = meta.get("timeframe", stem_parts[1] if len(stem_parts) > 1 else "?")
|
||||
model_type = meta.get(
|
||||
"model_type",
|
||||
stem_parts[2] if len(stem_parts) > 2 else dossier.name
|
||||
)
|
||||
|
||||
wf_accuracy = wf.get("avg_accuracy", None)
|
||||
wf_precision = wf.get("avg_precision", None)
|
||||
|
||||
trained_at = meta.get("trained_at", "?")
|
||||
# Tronque la date ISO à 19 caractères pour la lisibilité
|
||||
if isinstance(trained_at, str) and len(trained_at) > 19:
|
||||
trained_at = trained_at[:19]
|
||||
|
||||
tous.append({
|
||||
"type": type_label,
|
||||
"symbol": symbol,
|
||||
"timeframe": timeframe,
|
||||
"model_type": model_type,
|
||||
"n_samples": meta.get("n_samples", 0),
|
||||
"wf_accuracy": wf_accuracy,
|
||||
"wf_precision": wf_precision,
|
||||
"trained_at": trained_at,
|
||||
"_meta_path": str(f), # pour le rapport détaillé éventuel
|
||||
"_raw_meta": meta,
|
||||
})
|
||||
|
||||
return tous
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Affichage du tableau
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def formater_val(val, fmt: str) -> str:
|
||||
"""Formate une valeur numérique ou retourne 'N/A'."""
|
||||
if val is None:
|
||||
return "N/A"
|
||||
try:
|
||||
if "%" in fmt or "f" in fmt:
|
||||
return f"{float(val):.2%}" if "%" in fmt else fmt.format(float(val))
|
||||
return fmt.format(val)
|
||||
except (ValueError, TypeError):
|
||||
return str(val)
|
||||
|
||||
|
||||
def afficher_tableau(modeles: list[dict]) -> None:
|
||||
"""Affiche le tableau comparatif des modèles dans le terminal."""
|
||||
if not modeles:
|
||||
print("Aucun modèle entraîné trouvé.")
|
||||
print("Entraînez d'abord un modèle via POST /trading/train (ou /train-cnn).")
|
||||
return
|
||||
|
||||
# Entêtes
|
||||
headers = [label for _, label, _ in COLONNES]
|
||||
col_fmts = [fmt for _, _, fmt in COLONNES]
|
||||
|
||||
# Construction des lignes
|
||||
rows = []
|
||||
for m in modeles:
|
||||
ligne = []
|
||||
for key, _, fmt in COLONNES:
|
||||
val = m.get(key)
|
||||
if val is None:
|
||||
ligne.append("N/A")
|
||||
elif key in ("wf_accuracy", "wf_precision"):
|
||||
# Affichage en pourcentage
|
||||
try:
|
||||
ligne.append(f"{float(val):.2%}")
|
||||
except (ValueError, TypeError):
|
||||
ligne.append("N/A")
|
||||
elif key == "n_samples":
|
||||
try:
|
||||
ligne.append(f"{int(val):,}")
|
||||
except (ValueError, TypeError):
|
||||
ligne.append(str(val))
|
||||
else:
|
||||
ligne.append(str(val))
|
||||
rows.append(ligne)
|
||||
|
||||
print()
|
||||
print("=" * 80)
|
||||
print("BENCHMARK DES MODÈLES ML ENTRAÎNÉS")
|
||||
print("=" * 80)
|
||||
|
||||
try:
|
||||
from tabulate import tabulate
|
||||
print(tabulate(rows, headers=headers, tablefmt="rounded_outline"))
|
||||
except ImportError:
|
||||
# Fallback : affichage simple
|
||||
col_widths = [
|
||||
max(len(str(h)), max((len(str(r[i])) for r in rows), default=0))
|
||||
for i, h in enumerate(headers)
|
||||
]
|
||||
sep = " ".join("-" * w for w in col_widths)
|
||||
header_line = " ".join(str(h).ljust(w) for h, w in zip(headers, col_widths))
|
||||
print(header_line)
|
||||
print(sep)
|
||||
for row in rows:
|
||||
print(" ".join(str(c).ljust(w) for c, w in zip(row, col_widths)))
|
||||
|
||||
print()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Identification du meilleur modèle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def identifier_meilleur(modeles: list[dict]) -> None:
|
||||
"""Indique quel modèle est le meilleur selon wf_accuracy et wf_precision."""
|
||||
candidats = [m for m in modeles if m.get("wf_accuracy") is not None]
|
||||
if not candidats:
|
||||
print("Impossible de déterminer le meilleur modèle (aucune métrique wf_accuracy).")
|
||||
return
|
||||
|
||||
# Critère principal : wf_accuracy ; critère secondaire : wf_precision
|
||||
meilleur = max(
|
||||
candidats,
|
||||
key=lambda m: (
|
||||
m.get("wf_accuracy") or 0.0,
|
||||
m.get("wf_precision") or 0.0,
|
||||
),
|
||||
)
|
||||
|
||||
print(f"Meilleur modèle (WF Accuracy) :")
|
||||
print(f" Type : {meilleur['type']}")
|
||||
print(f" Symbol : {meilleur['symbol']} / {meilleur['timeframe']}")
|
||||
print(f" Modèle : {meilleur['model_type']}")
|
||||
|
||||
acc = meilleur.get("wf_accuracy")
|
||||
prec = meilleur.get("wf_precision")
|
||||
print(f" WF Accuracy : {acc:.2%}" if acc is not None else " WF Accuracy : N/A")
|
||||
print(f" WF Precision : {prec:.2%}" if prec is not None else " WF Precision : N/A")
|
||||
print(f" Entraîné le : {meilleur.get('trained_at', '?')}")
|
||||
print(f" Fichier meta : {meilleur['_meta_path']}")
|
||||
|
||||
# Avertissement si la précision est insuffisante pour le trading
|
||||
if acc is not None and acc < 0.40:
|
||||
print()
|
||||
print("[AVIS] WF Accuracy < 40% — ce modèle est insuffisant pour trader seul.")
|
||||
print(" Envisagez un re-entraînement avec plus de données ou de features.")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Statistiques globales par type de modèle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def afficher_stats_globales(modeles: list[dict]) -> None:
|
||||
"""Affiche des statistiques agrégées par type de modèle."""
|
||||
if not modeles:
|
||||
return
|
||||
|
||||
# Regroupement par type
|
||||
par_type: dict[str, list] = {}
|
||||
for m in modeles:
|
||||
par_type.setdefault(m["type"], []).append(m)
|
||||
|
||||
print("Résumé par type de modèle :")
|
||||
print("-" * 50)
|
||||
for type_label, groupe in par_type.items():
|
||||
accs = [m["wf_accuracy"] for m in groupe if m.get("wf_accuracy") is not None]
|
||||
precs = [m["wf_precision"] for m in groupe if m.get("wf_precision") is not None]
|
||||
n_tot = sum(m.get("n_samples", 0) for m in groupe)
|
||||
|
||||
acc_moy = sum(accs) / len(accs) if accs else None
|
||||
prec_moy = sum(precs) / len(precs) if precs else None
|
||||
|
||||
acc_str = f"{acc_moy:.2%}" if acc_moy is not None else "N/A"
|
||||
prec_str = f"{prec_moy:.2%}" if prec_moy is not None else "N/A"
|
||||
|
||||
print(f" {type_label}")
|
||||
print(f" Modèles entraînés : {len(groupe)}")
|
||||
print(f" WF Accuracy moy. : {acc_str}")
|
||||
print(f" WF Precision moy. : {prec_str}")
|
||||
print(f" Échantillons tot. : {n_tot:,}")
|
||||
print()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Point d'entrée principal
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""Parse les arguments en ligne de commande."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark rapide des modèles ML entraînés (lecture des méta-fichiers JSON)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models-root",
|
||||
type=Path,
|
||||
default=PROJECT_ROOT / "models",
|
||||
help=f"Répertoire racine des modèles (défaut: {PROJECT_ROOT / 'models'})",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
# Mise à jour des chemins si --models-root est spécifié
|
||||
dirs_effectifs = {
|
||||
label: args.models_root / dossier.name
|
||||
for label, dossier in MODEL_DIRS.items()
|
||||
}
|
||||
|
||||
print()
|
||||
print("Lecture des modèles dans :")
|
||||
for label, chemin in dirs_effectifs.items():
|
||||
existe = "OK" if chemin.exists() else "absent"
|
||||
print(f" [{existe}] {chemin}")
|
||||
print()
|
||||
|
||||
# Chargement de tous les modèles
|
||||
modeles = charger_modeles(dirs_effectifs)
|
||||
|
||||
if not modeles:
|
||||
print("Aucun modèle trouvé dans les répertoires indiqués.")
|
||||
print()
|
||||
print("Pour entraîner un modèle XGBoost/LightGBM :")
|
||||
print(" POST http://localhost:8100/trading/train")
|
||||
print(" {\"symbol\": \"EURUSD\", \"timeframe\": \"1h\", \"model_type\": \"xgboost\"}")
|
||||
print()
|
||||
print("Pour entraîner un modèle CNN 1D :")
|
||||
print(" POST http://localhost:8100/trading/train-cnn")
|
||||
print(" {\"symbol\": \"EURUSD\", \"timeframe\": \"1h\"}")
|
||||
sys.exit(0)
|
||||
|
||||
# Affichage du tableau
|
||||
afficher_tableau(modeles)
|
||||
|
||||
# Statistiques par type
|
||||
afficher_stats_globales(modeles)
|
||||
|
||||
# Meilleur modèle
|
||||
identifier_meilleur(modeles)
|
||||
|
||||
print(f"Total : {len(modeles)} modèle(s) trouvé(s).")
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user