Files
trader-ml/scripts/quick_benchmark.py
Tika ff8d58e1aa Phase 4c-bis/4d : CNN Image vectorisé + Agent RL PPO + HMM persistence + scripts
CNN Image (Phase 4c-bis) :
- chart_renderer.py : renderer numpy vectorisé (boucle 64 bougies, pas 12 000 fenêtres)
  → 1 068 img/s, GIL libéré entre itérations, API réactive pendant l'entraînement
- cnn_image_strategy_model.py : torch.set_num_threads(4) pour préserver l'event loop
- trading.py : asyncio.create_task() au lieu de background_tasks → hot-reloads non-bloquants

Agent RL PPO (Phase 4d) :
- src/ml/rl/ : TradingEnv (gymnasium), PPOModel (Actor-Critic MLP, GAE), RLStrategyModel
- src/strategies/rl_driven/ : RLDrivenStrategy (interface BaseStrategy complète)
- Routes API : POST /train-rl, GET /train-rl/{job_id}, GET /rl-models
- docs/RL_STRATEGY_GUIDE.md : documentation complète

HMM Persistence :
- regime_detector.py : save()/load()/needs_retrain()/is_trained (joblib + JSON meta)
- trading.py /ml/status : charge depuis disque si < 24h, re-entraîne + sauvegarde sinon
  → premier appel ~2s, appels suivants < 100ms

Scripts utilitaires :
- scripts/compare_strategies.py : backtest comparatif toutes stratégies (tabulate/JSON)
- scripts/quick_benchmark.py : comparaison wf_accuracy/precision des modèles ML sauvegardés
- reports/ : répertoire pour les rapports JSON générés

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-10 22:40:52 +00:00

328 lines
12 KiB
Python
Executable File

#!/usr/bin/env python3
"""
quick_benchmark.py — Benchmark rapide des modèles ML entraînés.
Lit les fichiers *_meta.json dans les répertoires de modèles :
- models/ml_strategy/ (XGBoost / LightGBM / RandomForest)
- models/cnn_strategy/ (CNN 1D — CandlestickEncoder)
- models/cnn_image_strategy/ (CNN Image — chandeliers en image)
Affiche un tableau comparatif : wf_accuracy, wf_precision, n_samples, etc.
Indique quel modèle obtient la meilleure précision walk-forward.
Usage :
python scripts/quick_benchmark.py
python scripts/quick_benchmark.py --models-root /chemin/vers/models
"""
import argparse
import json
import sys
from pathlib import Path
# ---------------------------------------------------------------------------
# Répertoires de modèles (relatifs à la racine du projet)
# ---------------------------------------------------------------------------
PROJECT_ROOT = Path(__file__).parent.parent
# Mapping : nom affiché -> chemin relatif au projet
MODEL_DIRS = {
"ML-Strategy (XGBoost/LightGBM)": PROJECT_ROOT / "models" / "ml_strategy",
"CNN-Strategy (1D)": PROJECT_ROOT / "models" / "cnn_strategy",
"CNN-Image-Strategy": PROJECT_ROOT / "models" / "cnn_image_strategy",
}
# Colonnes de métriques affichées dans le tableau
COLONNES = [
("type", "Type", "{:<28}"),
("symbol", "Symbol", "{:<8}"),
("timeframe", "TF", "{:<5}"),
("model_type", "Modèle", "{:<12}"),
("n_samples", "Échantillons", "{:>12}"),
("wf_accuracy", "WF Accuracy", "{:>12}"),
("wf_precision", "WF Precision", "{:>13}"),
("trained_at", "Entraîné le", "{:<20}"),
]
# ---------------------------------------------------------------------------
# Lecture des méta-fichiers JSON
# ---------------------------------------------------------------------------
def charger_modeles(model_dirs: dict[str, Path]) -> list[dict]:
"""
Parcourt les répertoires de modèles et charge les métadonnées JSON.
Retourne une liste de dicts prêts pour l'affichage.
"""
tous = []
for type_label, dossier in model_dirs.items():
if not dossier.exists():
# Répertoire absent — pas encore de modèles entraînés pour ce type
continue
fichiers_meta = sorted(dossier.glob("*_meta.json"))
if not fichiers_meta:
continue
for f in fichiers_meta:
try:
with open(f, encoding="utf-8") as fp:
meta = json.load(fp)
except (json.JSONDecodeError, OSError) as e:
print(f"[WARN] Impossible de lire {f} : {e}", file=sys.stderr)
continue
wf = meta.get("wf_metrics", {})
# Extraction des champs — avec valeurs par défaut
# Le nom du fichier (ex : EURUSD_1h_xgboost_meta.json) est utilisé
# comme fallback si les champs ne sont pas dans le JSON.
stem_parts = f.stem.replace("_meta", "").split("_")
symbol = meta.get("symbol", stem_parts[0] if len(stem_parts) > 0 else "?")
timeframe = meta.get("timeframe", stem_parts[1] if len(stem_parts) > 1 else "?")
model_type = meta.get(
"model_type",
stem_parts[2] if len(stem_parts) > 2 else dossier.name
)
wf_accuracy = wf.get("avg_accuracy", None)
wf_precision = wf.get("avg_precision", None)
trained_at = meta.get("trained_at", "?")
# Tronque la date ISO à 19 caractères pour la lisibilité
if isinstance(trained_at, str) and len(trained_at) > 19:
trained_at = trained_at[:19]
tous.append({
"type": type_label,
"symbol": symbol,
"timeframe": timeframe,
"model_type": model_type,
"n_samples": meta.get("n_samples", 0),
"wf_accuracy": wf_accuracy,
"wf_precision": wf_precision,
"trained_at": trained_at,
"_meta_path": str(f), # pour le rapport détaillé éventuel
"_raw_meta": meta,
})
return tous
# ---------------------------------------------------------------------------
# Affichage du tableau
# ---------------------------------------------------------------------------
def formater_val(val, fmt: str) -> str:
"""Formate une valeur numérique ou retourne 'N/A'."""
if val is None:
return "N/A"
try:
if "%" in fmt or "f" in fmt:
return f"{float(val):.2%}" if "%" in fmt else fmt.format(float(val))
return fmt.format(val)
except (ValueError, TypeError):
return str(val)
def afficher_tableau(modeles: list[dict]) -> None:
"""Affiche le tableau comparatif des modèles dans le terminal."""
if not modeles:
print("Aucun modèle entraîné trouvé.")
print("Entraînez d'abord un modèle via POST /trading/train (ou /train-cnn).")
return
# Entêtes
headers = [label for _, label, _ in COLONNES]
col_fmts = [fmt for _, _, fmt in COLONNES]
# Construction des lignes
rows = []
for m in modeles:
ligne = []
for key, _, fmt in COLONNES:
val = m.get(key)
if val is None:
ligne.append("N/A")
elif key in ("wf_accuracy", "wf_precision"):
# Affichage en pourcentage
try:
ligne.append(f"{float(val):.2%}")
except (ValueError, TypeError):
ligne.append("N/A")
elif key == "n_samples":
try:
ligne.append(f"{int(val):,}")
except (ValueError, TypeError):
ligne.append(str(val))
else:
ligne.append(str(val))
rows.append(ligne)
print()
print("=" * 80)
print("BENCHMARK DES MODÈLES ML ENTRAÎNÉS")
print("=" * 80)
try:
from tabulate import tabulate
print(tabulate(rows, headers=headers, tablefmt="rounded_outline"))
except ImportError:
# Fallback : affichage simple
col_widths = [
max(len(str(h)), max((len(str(r[i])) for r in rows), default=0))
for i, h in enumerate(headers)
]
sep = " ".join("-" * w for w in col_widths)
header_line = " ".join(str(h).ljust(w) for h, w in zip(headers, col_widths))
print(header_line)
print(sep)
for row in rows:
print(" ".join(str(c).ljust(w) for c, w in zip(row, col_widths)))
print()
# ---------------------------------------------------------------------------
# Identification du meilleur modèle
# ---------------------------------------------------------------------------
def identifier_meilleur(modeles: list[dict]) -> None:
"""Indique quel modèle est le meilleur selon wf_accuracy et wf_precision."""
candidats = [m for m in modeles if m.get("wf_accuracy") is not None]
if not candidats:
print("Impossible de déterminer le meilleur modèle (aucune métrique wf_accuracy).")
return
# Critère principal : wf_accuracy ; critère secondaire : wf_precision
meilleur = max(
candidats,
key=lambda m: (
m.get("wf_accuracy") or 0.0,
m.get("wf_precision") or 0.0,
),
)
print(f"Meilleur modèle (WF Accuracy) :")
print(f" Type : {meilleur['type']}")
print(f" Symbol : {meilleur['symbol']} / {meilleur['timeframe']}")
print(f" Modèle : {meilleur['model_type']}")
acc = meilleur.get("wf_accuracy")
prec = meilleur.get("wf_precision")
print(f" WF Accuracy : {acc:.2%}" if acc is not None else " WF Accuracy : N/A")
print(f" WF Precision : {prec:.2%}" if prec is not None else " WF Precision : N/A")
print(f" Entraîné le : {meilleur.get('trained_at', '?')}")
print(f" Fichier meta : {meilleur['_meta_path']}")
# Avertissement si la précision est insuffisante pour le trading
if acc is not None and acc < 0.40:
print()
print("[AVIS] WF Accuracy < 40% — ce modèle est insuffisant pour trader seul.")
print(" Envisagez un re-entraînement avec plus de données ou de features.")
print()
# ---------------------------------------------------------------------------
# Statistiques globales par type de modèle
# ---------------------------------------------------------------------------
def afficher_stats_globales(modeles: list[dict]) -> None:
"""Affiche des statistiques agrégées par type de modèle."""
if not modeles:
return
# Regroupement par type
par_type: dict[str, list] = {}
for m in modeles:
par_type.setdefault(m["type"], []).append(m)
print("Résumé par type de modèle :")
print("-" * 50)
for type_label, groupe in par_type.items():
accs = [m["wf_accuracy"] for m in groupe if m.get("wf_accuracy") is not None]
precs = [m["wf_precision"] for m in groupe if m.get("wf_precision") is not None]
n_tot = sum(m.get("n_samples", 0) for m in groupe)
acc_moy = sum(accs) / len(accs) if accs else None
prec_moy = sum(precs) / len(precs) if precs else None
acc_str = f"{acc_moy:.2%}" if acc_moy is not None else "N/A"
prec_str = f"{prec_moy:.2%}" if prec_moy is not None else "N/A"
print(f" {type_label}")
print(f" Modèles entraînés : {len(groupe)}")
print(f" WF Accuracy moy. : {acc_str}")
print(f" WF Precision moy. : {prec_str}")
print(f" Échantillons tot. : {n_tot:,}")
print()
# ---------------------------------------------------------------------------
# Point d'entrée principal
# ---------------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
"""Parse les arguments en ligne de commande."""
parser = argparse.ArgumentParser(
description="Benchmark rapide des modèles ML entraînés (lecture des méta-fichiers JSON)."
)
parser.add_argument(
"--models-root",
type=Path,
default=PROJECT_ROOT / "models",
help=f"Répertoire racine des modèles (défaut: {PROJECT_ROOT / 'models'})",
)
return parser.parse_args()
def main() -> None:
args = parse_args()
# Mise à jour des chemins si --models-root est spécifié
dirs_effectifs = {
label: args.models_root / dossier.name
for label, dossier in MODEL_DIRS.items()
}
print()
print("Lecture des modèles dans :")
for label, chemin in dirs_effectifs.items():
existe = "OK" if chemin.exists() else "absent"
print(f" [{existe}] {chemin}")
print()
# Chargement de tous les modèles
modeles = charger_modeles(dirs_effectifs)
if not modeles:
print("Aucun modèle trouvé dans les répertoires indiqués.")
print()
print("Pour entraîner un modèle XGBoost/LightGBM :")
print(" POST http://localhost:8100/trading/train")
print(" {\"symbol\": \"EURUSD\", \"timeframe\": \"1h\", \"model_type\": \"xgboost\"}")
print()
print("Pour entraîner un modèle CNN 1D :")
print(" POST http://localhost:8100/trading/train-cnn")
print(" {\"symbol\": \"EURUSD\", \"timeframe\": \"1h\"}")
sys.exit(0)
# Affichage du tableau
afficher_tableau(modeles)
# Statistiques par type
afficher_stats_globales(modeles)
# Meilleur modèle
identifier_meilleur(modeles)
print(f"Total : {len(modeles)} modèle(s) trouvé(s).")
print()
if __name__ == "__main__":
main()