#!/usr/bin/env python3 """ quick_benchmark.py — Benchmark rapide des modèles ML entraînés. Lit les fichiers *_meta.json dans les répertoires de modèles : - models/ml_strategy/ (XGBoost / LightGBM / RandomForest) - models/cnn_strategy/ (CNN 1D — CandlestickEncoder) - models/cnn_image_strategy/ (CNN Image — chandeliers en image) Affiche un tableau comparatif : wf_accuracy, wf_precision, n_samples, etc. Indique quel modèle obtient la meilleure précision walk-forward. Usage : python scripts/quick_benchmark.py python scripts/quick_benchmark.py --models-root /chemin/vers/models """ import argparse import json import sys from pathlib import Path # --------------------------------------------------------------------------- # Répertoires de modèles (relatifs à la racine du projet) # --------------------------------------------------------------------------- PROJECT_ROOT = Path(__file__).parent.parent # Mapping : nom affiché -> chemin relatif au projet MODEL_DIRS = { "ML-Strategy (XGBoost/LightGBM)": PROJECT_ROOT / "models" / "ml_strategy", "CNN-Strategy (1D)": PROJECT_ROOT / "models" / "cnn_strategy", "CNN-Image-Strategy": PROJECT_ROOT / "models" / "cnn_image_strategy", } # Colonnes de métriques affichées dans le tableau COLONNES = [ ("type", "Type", "{:<28}"), ("symbol", "Symbol", "{:<8}"), ("timeframe", "TF", "{:<5}"), ("model_type", "Modèle", "{:<12}"), ("n_samples", "Échantillons", "{:>12}"), ("wf_accuracy", "WF Accuracy", "{:>12}"), ("wf_precision", "WF Precision", "{:>13}"), ("trained_at", "Entraîné le", "{:<20}"), ] # --------------------------------------------------------------------------- # Lecture des méta-fichiers JSON # --------------------------------------------------------------------------- def charger_modeles(model_dirs: dict[str, Path]) -> list[dict]: """ Parcourt les répertoires de modèles et charge les métadonnées JSON. Retourne une liste de dicts prêts pour l'affichage. """ tous = [] for type_label, dossier in model_dirs.items(): if not dossier.exists(): # Répertoire absent — pas encore de modèles entraînés pour ce type continue fichiers_meta = sorted(dossier.glob("*_meta.json")) if not fichiers_meta: continue for f in fichiers_meta: try: with open(f, encoding="utf-8") as fp: meta = json.load(fp) except (json.JSONDecodeError, OSError) as e: print(f"[WARN] Impossible de lire {f} : {e}", file=sys.stderr) continue wf = meta.get("wf_metrics", {}) # Extraction des champs — avec valeurs par défaut # Le nom du fichier (ex : EURUSD_1h_xgboost_meta.json) est utilisé # comme fallback si les champs ne sont pas dans le JSON. stem_parts = f.stem.replace("_meta", "").split("_") symbol = meta.get("symbol", stem_parts[0] if len(stem_parts) > 0 else "?") timeframe = meta.get("timeframe", stem_parts[1] if len(stem_parts) > 1 else "?") model_type = meta.get( "model_type", stem_parts[2] if len(stem_parts) > 2 else dossier.name ) wf_accuracy = wf.get("avg_accuracy", None) wf_precision = wf.get("avg_precision", None) trained_at = meta.get("trained_at", "?") # Tronque la date ISO à 19 caractères pour la lisibilité if isinstance(trained_at, str) and len(trained_at) > 19: trained_at = trained_at[:19] tous.append({ "type": type_label, "symbol": symbol, "timeframe": timeframe, "model_type": model_type, "n_samples": meta.get("n_samples", 0), "wf_accuracy": wf_accuracy, "wf_precision": wf_precision, "trained_at": trained_at, "_meta_path": str(f), # pour le rapport détaillé éventuel "_raw_meta": meta, }) return tous # --------------------------------------------------------------------------- # Affichage du tableau # --------------------------------------------------------------------------- def formater_val(val, fmt: str) -> str: """Formate une valeur numérique ou retourne 'N/A'.""" if val is None: return "N/A" try: if "%" in fmt or "f" in fmt: return f"{float(val):.2%}" if "%" in fmt else fmt.format(float(val)) return fmt.format(val) except (ValueError, TypeError): return str(val) def afficher_tableau(modeles: list[dict]) -> None: """Affiche le tableau comparatif des modèles dans le terminal.""" if not modeles: print("Aucun modèle entraîné trouvé.") print("Entraînez d'abord un modèle via POST /trading/train (ou /train-cnn).") return # Entêtes headers = [label for _, label, _ in COLONNES] col_fmts = [fmt for _, _, fmt in COLONNES] # Construction des lignes rows = [] for m in modeles: ligne = [] for key, _, fmt in COLONNES: val = m.get(key) if val is None: ligne.append("N/A") elif key in ("wf_accuracy", "wf_precision"): # Affichage en pourcentage try: ligne.append(f"{float(val):.2%}") except (ValueError, TypeError): ligne.append("N/A") elif key == "n_samples": try: ligne.append(f"{int(val):,}") except (ValueError, TypeError): ligne.append(str(val)) else: ligne.append(str(val)) rows.append(ligne) print() print("=" * 80) print("BENCHMARK DES MODÈLES ML ENTRAÎNÉS") print("=" * 80) try: from tabulate import tabulate print(tabulate(rows, headers=headers, tablefmt="rounded_outline")) except ImportError: # Fallback : affichage simple col_widths = [ max(len(str(h)), max((len(str(r[i])) for r in rows), default=0)) for i, h in enumerate(headers) ] sep = " ".join("-" * w for w in col_widths) header_line = " ".join(str(h).ljust(w) for h, w in zip(headers, col_widths)) print(header_line) print(sep) for row in rows: print(" ".join(str(c).ljust(w) for c, w in zip(row, col_widths))) print() # --------------------------------------------------------------------------- # Identification du meilleur modèle # --------------------------------------------------------------------------- def identifier_meilleur(modeles: list[dict]) -> None: """Indique quel modèle est le meilleur selon wf_accuracy et wf_precision.""" candidats = [m for m in modeles if m.get("wf_accuracy") is not None] if not candidats: print("Impossible de déterminer le meilleur modèle (aucune métrique wf_accuracy).") return # Critère principal : wf_accuracy ; critère secondaire : wf_precision meilleur = max( candidats, key=lambda m: ( m.get("wf_accuracy") or 0.0, m.get("wf_precision") or 0.0, ), ) print(f"Meilleur modèle (WF Accuracy) :") print(f" Type : {meilleur['type']}") print(f" Symbol : {meilleur['symbol']} / {meilleur['timeframe']}") print(f" Modèle : {meilleur['model_type']}") acc = meilleur.get("wf_accuracy") prec = meilleur.get("wf_precision") print(f" WF Accuracy : {acc:.2%}" if acc is not None else " WF Accuracy : N/A") print(f" WF Precision : {prec:.2%}" if prec is not None else " WF Precision : N/A") print(f" Entraîné le : {meilleur.get('trained_at', '?')}") print(f" Fichier meta : {meilleur['_meta_path']}") # Avertissement si la précision est insuffisante pour le trading if acc is not None and acc < 0.40: print() print("[AVIS] WF Accuracy < 40% — ce modèle est insuffisant pour trader seul.") print(" Envisagez un re-entraînement avec plus de données ou de features.") print() # --------------------------------------------------------------------------- # Statistiques globales par type de modèle # --------------------------------------------------------------------------- def afficher_stats_globales(modeles: list[dict]) -> None: """Affiche des statistiques agrégées par type de modèle.""" if not modeles: return # Regroupement par type par_type: dict[str, list] = {} for m in modeles: par_type.setdefault(m["type"], []).append(m) print("Résumé par type de modèle :") print("-" * 50) for type_label, groupe in par_type.items(): accs = [m["wf_accuracy"] for m in groupe if m.get("wf_accuracy") is not None] precs = [m["wf_precision"] for m in groupe if m.get("wf_precision") is not None] n_tot = sum(m.get("n_samples", 0) for m in groupe) acc_moy = sum(accs) / len(accs) if accs else None prec_moy = sum(precs) / len(precs) if precs else None acc_str = f"{acc_moy:.2%}" if acc_moy is not None else "N/A" prec_str = f"{prec_moy:.2%}" if prec_moy is not None else "N/A" print(f" {type_label}") print(f" Modèles entraînés : {len(groupe)}") print(f" WF Accuracy moy. : {acc_str}") print(f" WF Precision moy. : {prec_str}") print(f" Échantillons tot. : {n_tot:,}") print() # --------------------------------------------------------------------------- # Point d'entrée principal # --------------------------------------------------------------------------- def parse_args() -> argparse.Namespace: """Parse les arguments en ligne de commande.""" parser = argparse.ArgumentParser( description="Benchmark rapide des modèles ML entraînés (lecture des méta-fichiers JSON)." ) parser.add_argument( "--models-root", type=Path, default=PROJECT_ROOT / "models", help=f"Répertoire racine des modèles (défaut: {PROJECT_ROOT / 'models'})", ) return parser.parse_args() def main() -> None: args = parse_args() # Mise à jour des chemins si --models-root est spécifié dirs_effectifs = { label: args.models_root / dossier.name for label, dossier in MODEL_DIRS.items() } print() print("Lecture des modèles dans :") for label, chemin in dirs_effectifs.items(): existe = "OK" if chemin.exists() else "absent" print(f" [{existe}] {chemin}") print() # Chargement de tous les modèles modeles = charger_modeles(dirs_effectifs) if not modeles: print("Aucun modèle trouvé dans les répertoires indiqués.") print() print("Pour entraîner un modèle XGBoost/LightGBM :") print(" POST http://localhost:8100/trading/train") print(" {\"symbol\": \"EURUSD\", \"timeframe\": \"1h\", \"model_type\": \"xgboost\"}") print() print("Pour entraîner un modèle CNN 1D :") print(" POST http://localhost:8100/trading/train-cnn") print(" {\"symbol\": \"EURUSD\", \"timeframe\": \"1h\"}") sys.exit(0) # Affichage du tableau afficher_tableau(modeles) # Statistiques par type afficher_stats_globales(modeles) # Meilleur modèle identifier_meilleur(modeles) print(f"Total : {len(modeles)} modèle(s) trouvé(s).") print() if __name__ == "__main__": main()