Phase 4c-bis/4d : CNN Image vectorisé + Agent RL PPO + HMM persistence + scripts

CNN Image (Phase 4c-bis) :
- chart_renderer.py : renderer numpy vectorisé (boucle 64 bougies, pas 12 000 fenêtres)
  → 1 068 img/s, GIL libéré entre itérations, API réactive pendant l'entraînement
- cnn_image_strategy_model.py : torch.set_num_threads(4) pour préserver l'event loop
- trading.py : asyncio.create_task() au lieu de background_tasks → hot-reloads non-bloquants

Agent RL PPO (Phase 4d) :
- src/ml/rl/ : TradingEnv (gymnasium), PPOModel (Actor-Critic MLP, GAE), RLStrategyModel
- src/strategies/rl_driven/ : RLDrivenStrategy (interface BaseStrategy complète)
- Routes API : POST /train-rl, GET /train-rl/{job_id}, GET /rl-models
- docs/RL_STRATEGY_GUIDE.md : documentation complète

HMM Persistence :
- regime_detector.py : save()/load()/needs_retrain()/is_trained (joblib + JSON meta)
- trading.py /ml/status : charge depuis disque si < 24h, re-entraîne + sauvegarde sinon
  → premier appel ~2s, appels suivants < 100ms

Scripts utilitaires :
- scripts/compare_strategies.py : backtest comparatif toutes stratégies (tabulate/JSON)
- scripts/quick_benchmark.py : comparaison wf_accuracy/precision des modèles ML sauvegardés
- reports/ : répertoire pour les rapports JSON générés

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Tika
2026-03-10 22:40:52 +00:00
parent 80e1308a1e
commit ff8d58e1aa
13 changed files with 3873 additions and 158 deletions

236
docs/RL_STRATEGY_GUIDE.md Normal file
View File

@@ -0,0 +1,236 @@
# RL Strategy Guide — Phase 4d (Reinforcement Learning)
## Vue d'ensemble
La Phase 4d introduit une stratégie pilotée par un agent de **Reinforcement Learning (RL)**
basé sur l'algorithme **PPO** (Proximal Policy Optimization). Contrairement aux stratégies
supervisées (XGBoost, CNN) qui s'entraînent sur des labels pré-calculés, l'agent RL apprend
directement par interaction avec un environnement de trading simulé (`TradingEnv`), sans avoir
besoin d'annotations explicites LONG/SHORT/NEUTRAL.
---
## Architecture
### Composants (implémentés par l'agent ml/rl/)
```
src/ml/rl/
├── trading_env.py # TradingEnv (Gymnasium) — environnement de simulation
├── ppo_model.py # PPOModel — réseau Actor-Critic (MLP)
└── rl_strategy_model.py # RLStrategyModel — interface train/predict/save/load
```
### Composants (cette phase)
```
src/strategies/rl_driven/
├── __init__.py # Export RLDrivenStrategy
└── rl_strategy.py # RLDrivenStrategy (hérite BaseStrategy)
```
### Pipeline d'entraînement PPO
```
Données OHLCV
TradingEnv (Gymnasium)
├── Observation space : fenêtre glissante seq_len=20 barres (OHLCV normalisé)
├── Action space : {0=HOLD/NEUTRAL, 1=LONG, 2=SHORT}
└── Reward : P&L réalisé pénalité drawdown pénalité over-trading
PPO Actor-Critic (MLP)
├── Actor : policy π(a|s) → distribution sur les actions
└── Critic : value function V(s) → estimation de la récompense future
Optimisation sur total_timesteps pas
RLStrategyModel.save()
├── models/rl_strategy/EURUSD_1h.zip (politique PPO)
└── models/rl_strategy/EURUSD_1h_meta.json
```
### Architecture Actor-Critic
L'agent PPO utilise un réseau MLP (Multi-Layer Perceptron) à deux têtes :
- **Actor** : prédit la distribution de probabilité sur les 3 actions (LONG/SHORT/NEUTRAL).
La confiance du signal correspond à `max(probas)`.
- **Critic** : estime la valeur d'état V(s) pour calculer l'avantage (advantage) utilisé
lors de la mise à jour de la politique.
Hyperparamètres PPO typiques :
| Paramètre | Valeur par défaut | Description |
|---------------|-------------------|----------------------------------------------|
| `gamma` | 0.99 | Facteur d'actualisation des récompenses |
| `clip_range` | 0.2 | Clipping ratio PPO (stabilité) |
| `n_steps` | 2048 | Nombre de pas par rollout |
| `batch_size` | 64 | Taille des mini-batches SGD |
| `n_epochs` | 10 | Passes sur chaque rollout |
| `ent_coef` | 0.01 | Coefficient d'entropie (exploration) |
---
## Lancer l'entraînement
### Via curl
```bash
# Lancer l'entraînement (tâche de fond)
curl -X POST http://localhost:8100/trading/train-rl \
-H "Content-Type: application/json" \
-d '{
"symbol": "EURUSD",
"timeframe": "1h",
"period": "2y",
"total_timesteps": 50000
}'
# Réponse
{
"job_id": "a1b2c3d4-...",
"status": "pending",
"symbol": "EURUSD",
"timeframe": "1h"
}
```
```bash
# Suivre l'avancement
curl http://localhost:8100/trading/train-rl/a1b2c3d4-...
# Réponse quand terminé
{
"job_id": "a1b2c3d4-...",
"status": "completed",
"symbol": "EURUSD",
"timeframe": "1h",
"avg_reward": 0.042,
"sharpe_env": 1.23,
"total_timesteps": 50000,
"trained_at": "2026-03-10T14:32:00"
}
```
```bash
# Lister les modèles disponibles
curl http://localhost:8100/trading/rl-models
# Réponse
{
"models": [
{
"symbol": "EURUSD",
"timeframe": "1h",
"avg_reward": 0.042,
"sharpe_env": 1.23,
"total_timesteps": 50000,
"trained_at": "2026-03-10T14:32:00"
}
],
"count": 1
}
```
### Paramètres de la requête
| Champ | Type | Défaut | Description |
|-------------------|-------|----------|---------------------------------------------------------|
| `symbol` | str | EURUSD | Paire de trading (ex: EURUSD, BTCUSDT) |
| `timeframe` | str | 1h | Timeframe (1m, 5m, 15m, 1h, 4h, 1d) |
| `period` | str | 2y | Historique d'entraînement (ex: 6m, 1y, 2y) |
| `total_timesteps` | int | 50000 | Nombre total de pas de simulation PPO |
**Recommandations `total_timesteps` :**
- 20 000 : entraînement rapide (test, ~5 min CPU)
- 50 000 : entraînement standard (défaut, ~15 min CPU)
- 200 000 : entraînement long, meilleure convergence (~1h CPU)
- 1 000 000 : entraînement poussé si GPU disponible
---
## Interpréter les métriques
### `avg_reward`
Récompense moyenne par pas de simulation sur les derniers rollouts d'évaluation.
- **< 0** : l'agent perd de l'argent en simulation → entraîner plus longtemps ou revoir la fonction de récompense
- **0 à 0.02** : agent neutre, légèrement profitable
- **> 0.05** : bon signal → tester en backtest réel
- **> 0.1** : excellent (attention au sur-apprentissage, vérifier out-of-sample)
### `sharpe_env`
Ratio de Sharpe calculé sur les épisodes de simulation (récompenses / écart-type des récompenses).
- **< 0.5** : insuffisant pour paper trading
- **0.5 1.0** : acceptable, à valider en backtest
- **> 1.5** : cible pour activation paper trading (conforme aux seuils du projet)
### Interprétation combinée
| `avg_reward` | `sharpe_env` | Interprétation |
|--------------|--------------|--------------------------------------------------------|
| négatif | quelconque | Agent non convergé — relancer avec plus de timesteps |
| 0 0.02 | < 1.0 | Apprentissage partiel — augmenter total_timesteps |
| > 0.03 | > 1.0 | Bon candidat — valider via POST /trading/backtest |
| > 0.05 | > 1.5 | Prêt pour paper trading (30 jours minimum) |
---
## Intégration avec le paper trading
Après l'entraînement, si un paper trading avec la stratégie `rl_driven` est actif,
le modèle est **automatiquement attaché** sans redémarrage.
Pour démarrer un paper trading RL :
```bash
curl -X POST http://localhost:8100/trading/paper/start \
-H "Content-Type: application/json" \
-d '{"strategy": "rl_driven", "symbol": "EURUSD"}'
```
La stratégie `RLDrivenStrategy` :
1. Charge le dernier modèle entraîné pour le symbole/timeframe
2. À chaque barre, fournit les 20 dernières bougies à l'agent (`seq_len=20`)
3. Si la confiance de l'agent >= `min_confidence` (défaut: 0.55), émet un signal
4. SL = `sl_atr_mult × ATR` (défaut: 1×ATR), TP = `tp_atr_mult × ATR` (défaut: 2×ATR)
---
## Notes techniques
### Threads PyTorch
L'entraînement fixe `torch.set_num_threads(4)` pour éviter la contention CPU dans
le container Docker. Adapter dans `docker-compose.yml` si le container dispose de plus de cœurs.
### Sauvegarde des modèles
Les modèles sont sauvegardés dans `models/rl_strategy/` (volume Docker monté) :
- `EURUSD_1h.zip` : politique PPO (format stable-baselines3)
- `EURUSD_1h_meta.json` : métadonnées (métriques, hyperparamètres, date)
### Import conditionnel
Le flag `RL_AVAILABLE` est `False` si PyTorch ou stable-baselines3 ne sont pas
installés. La stratégie dégrade gracieusement (aucun signal, aucune exception).
Pour activer : ajouter `stable-baselines3` et `gymnasium` dans `docker/requirements/api.txt`
puis reconstruire le container (`docker compose build --no-cache trading-api`).
---
## Seuils de validation (conforme au projet)
Avant activation du live trading, la stratégie RL doit satisfaire :
| Métrique | Seuil minimum |
|-----------------|---------------|
| Sharpe Ratio | ≥ 1.5 |
| Max Drawdown | ≤ 10% |
| Win Rate | ≥ 55% |
| Paper Trading | ≥ 30 jours |
Ces seuils s'appliquent au **backtest out-of-sample** et au **paper trading**, pas
aux métriques de simulation RL (`sharpe_env`) qui sont indicatives uniquement.

365
scripts/compare_strategies.py Executable file
View File

@@ -0,0 +1,365 @@
#!/usr/bin/env python3
"""
compare_strategies.py — Backtest comparatif automatisé des stratégies de trading.
Lance POST /trading/backtest pour chaque stratégie (scalping, ml_driven, cnn_driven)
sur le même dataset (EURUSD / 1h / 1y), poll jusqu'à completion, affiche un tableau
comparatif et sauvegarde le rapport JSON dans reports/.
Usage :
python scripts/compare_strategies.py
python scripts/compare_strategies.py --symbol GBPUSD --period 2y --capital 20000
"""
import argparse
import json
import sys
import time
from datetime import datetime
from pathlib import Path
# ---------------------------------------------------------------------------
# Vérification de la dépendance httpx (ou fallback vers requests)
# ---------------------------------------------------------------------------
try:
import httpx
_HTTP_BACKEND = "httpx"
except ImportError:
try:
import requests as _requests_module
_HTTP_BACKEND = "requests"
except ImportError:
print("ERREUR : Ni httpx ni requests ne sont installés. Exécutez : pip install httpx")
sys.exit(1)
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
API_BASE_URL = "http://localhost:8100"
POLL_INTERVAL_SEC = 5 # Intervalle de polling (secondes)
TIMEOUT_SEC = 300 # Timeout maximum par job (5 minutes)
REPORTS_DIR = Path(__file__).parent.parent / "reports"
# Stratégies à comparer — dans l'ordre d'affichage.
# NOTE : l'API (POST /trading/backtest) accepte : scalping | intraday | swing | ml_driven
# La stratégie cnn_driven n'est pas encore intégrée dans le pipeline de backtest
# (elle est disponible uniquement via paper trading). Elle sera incluse automatiquement
# si l'API évolue pour la supporter.
STRATEGIES = [
"scalping",
"ml_driven",
"cnn_driven", # Retournera une erreur 400 si non supporté par l'API — géré gracieusement
]
# Colonnes métriques à collecter et afficher
METRICS_COLS = [
("total_return", "Retour total", "{:.2%}"),
("sharpe_ratio", "Sharpe ratio", "{:.3f}"),
("max_drawdown", "Max drawdown", "{:.2%}"),
("win_rate", "Win rate", "{:.2%}"),
("total_trades", "Trades totaux", "{:d}"),
("profit_factor", "Profit factor", "{:.2f}"),
]
# Seuils de validation (identiques à ceux de l'API)
SEUIL_SHARPE = 1.5
SEUIL_DRAWDOWN_MAX = 0.10
SEUIL_WIN_RATE = 0.55
# ---------------------------------------------------------------------------
# Couche HTTP (httpx ou requests selon disponibilité)
# ---------------------------------------------------------------------------
def _post(url: str, payload: dict, timeout: int = 30) -> dict:
"""Effectue un POST JSON et retourne la réponse parsée."""
if _HTTP_BACKEND == "httpx":
resp = httpx.post(url, json=payload, timeout=timeout)
resp.raise_for_status()
return resp.json()
else:
resp = _requests_module.post(url, json=payload, timeout=timeout)
resp.raise_for_status()
return resp.json()
def _get(url: str, timeout: int = 15) -> dict:
"""Effectue un GET et retourne la réponse parsée."""
if _HTTP_BACKEND == "httpx":
resp = httpx.get(url, timeout=timeout)
resp.raise_for_status()
return resp.json()
else:
resp = _requests_module.get(url, timeout=timeout)
resp.raise_for_status()
return resp.json()
# ---------------------------------------------------------------------------
# Fonctions backtest
# ---------------------------------------------------------------------------
def lancer_backtest(strategy: str, symbol: str, period: str, capital: float) -> str:
"""
Lance un job de backtest via POST /trading/backtest.
Retourne le job_id ou lève une exception en cas d'erreur.
"""
payload = {
"strategy": strategy,
"symbol": symbol,
"period": period,
"initial_capital": capital,
}
url = f"{API_BASE_URL}/trading/backtest"
data = _post(url, payload)
return data["job_id"]
def attendre_resultat(job_id: str, strategy: str) -> dict:
"""
Poll GET /trading/backtest/{job_id} toutes les POLL_INTERVAL_SEC secondes
jusqu'à ce que le statut soit 'completed' ou 'failed'.
Retourne le dict de résultat ou lève une exception si le job échoue / dépasse le timeout.
"""
url = f"{API_BASE_URL}/trading/backtest/{job_id}"
debut = time.time()
while True:
elapsed = time.time() - debut
if elapsed > TIMEOUT_SEC:
raise TimeoutError(
f"[{strategy}] Timeout atteint ({TIMEOUT_SEC}s). Job {job_id} toujours en cours."
)
data = _get(url)
statut = data.get("status", "?")
print(f" [{strategy}] Statut : {statut} ({elapsed:.0f}s)")
if statut == "completed":
return data
elif statut == "failed":
erreur = data.get("error", "raison inconnue")
raise RuntimeError(f"[{strategy}] Job échoué : {erreur}")
# Toujours en cours — on attend
time.sleep(POLL_INTERVAL_SEC)
# ---------------------------------------------------------------------------
# Affichage du tableau comparatif
# ---------------------------------------------------------------------------
def afficher_tableau(resultats: list[dict]) -> None:
"""
Affiche un tableau comparatif dans le terminal.
Tente d'utiliser 'tabulate' pour un meilleur rendu ; fallback vers print simple.
"""
# Construction des données tabulaires
headers = ["Stratégie"] + [label for _, label, _ in METRICS_COLS] + ["Valide ?"]
rows = []
for res in resultats:
nom = res["strategy"]
ligne = [nom]
for key, _, fmt in METRICS_COLS:
val = res.get(key)
if val is None:
ligne.append("N/A")
else:
try:
if key == "total_trades":
ligne.append(fmt.format(int(val)))
else:
ligne.append(fmt.format(float(val)))
except (ValueError, TypeError):
ligne.append(str(val))
# Indicateur de validation
valide = res.get("is_valid_for_paper")
if valide is True:
ligne.append("OUI")
elif valide is False:
ligne.append("NON")
else:
ligne.append("?")
rows.append(ligne)
print()
print("=" * 70)
print("RAPPORT COMPARATIF DES STRATÉGIES")
print("=" * 70)
try:
from tabulate import tabulate
print(tabulate(rows, headers=headers, tablefmt="rounded_outline"))
except ImportError:
# Fallback : affichage simple sans tabulate
col_widths = [max(len(str(h)), max((len(str(r[i])) for r in rows), default=0))
for i, h in enumerate(headers)]
sep = " ".join("-" * w for w in col_widths)
header_line = " ".join(str(h).ljust(w) for h, w in zip(headers, col_widths))
print(header_line)
print(sep)
for row in rows:
print(" ".join(str(c).ljust(w) for c, w in zip(row, col_widths)))
print()
# Identification de la meilleure stratégie (selon Sharpe ratio)
valides = [r for r in resultats if r.get("sharpe_ratio") is not None]
if valides:
meilleure = max(valides, key=lambda r: r.get("sharpe_ratio") or -999)
print(f"Meilleure stratégie (Sharpe) : {meilleure['strategy']} "
f"(Sharpe={meilleure.get('sharpe_ratio', 'N/A'):.3f})")
# Rappel des seuils
print()
print(f"Seuils de validation : Sharpe >= {SEUIL_SHARPE} | "
f"Max drawdown <= {SEUIL_DRAWDOWN_MAX:.0%} | "
f"Win rate >= {SEUIL_WIN_RATE:.0%}")
print("=" * 70)
# ---------------------------------------------------------------------------
# Sauvegarde du rapport JSON
# ---------------------------------------------------------------------------
def sauvegarder_rapport(resultats: list[dict], symbol: str, period: str) -> Path:
"""
Sauvegarde le rapport de comparaison en JSON dans reports/.
Retourne le chemin du fichier créé.
"""
REPORTS_DIR.mkdir(parents=True, exist_ok=True)
horodatage = datetime.now().strftime("%Y%m%d_%H%M%S")
nom_fichier = f"backtest_comparison_{horodatage}.json"
chemin = REPORTS_DIR / nom_fichier
rapport = {
"generated_at": datetime.now().isoformat(),
"parametres": {
"symbol": symbol,
"period": period,
"api_base_url": API_BASE_URL,
},
"seuils_validation": {
"sharpe_ratio_min": SEUIL_SHARPE,
"max_drawdown_max": SEUIL_DRAWDOWN_MAX,
"win_rate_min": SEUIL_WIN_RATE,
},
"resultats": resultats,
}
with open(chemin, "w", encoding="utf-8") as f:
json.dump(rapport, f, indent=2, ensure_ascii=False, default=str)
return chemin
# ---------------------------------------------------------------------------
# Point d'entrée principal
# ---------------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
"""Parse les arguments en ligne de commande."""
parser = argparse.ArgumentParser(
description="Compare les stratégies de trading via backtest API."
)
parser.add_argument("--symbol", default="EURUSD", help="Paire de trading (défaut: EURUSD)")
parser.add_argument("--period", default="1y", help="Période historique : 6m | 1y | 2y (défaut: 1y)")
parser.add_argument("--capital", default=10000.0, type=float, help="Capital initial (défaut: 10000)")
parser.add_argument(
"--strategies",
nargs="+",
default=STRATEGIES,
help="Liste des stratégies à comparer (défaut: scalping ml_driven cnn_driven)",
)
return parser.parse_args()
def main() -> None:
args = parse_args()
print()
print("=" * 70)
print("BACKTEST COMPARATIF DES STRATÉGIES")
print(f"Symbol: {args.symbol} | Période: {args.period} | Capital: {args.capital:,.0f} EUR")
print(f"Stratégies: {', '.join(args.strategies)}")
print(f"API: {API_BASE_URL}")
print("=" * 70)
# Vérification de la disponibilité de l'API
try:
_get(f"{API_BASE_URL}/health")
print("API disponible.")
except Exception as e:
print(f"ERREUR : Impossible de joindre l'API ({e})")
print("Vérifiez que le container trading-api tourne sur le port 8100.")
sys.exit(1)
print()
# -----------------------------------------------------------------------
# Phase 1 : lancement de tous les backtests
# -----------------------------------------------------------------------
jobs: dict[str, str] = {} # {strategy: job_id}
stratégies_échouées: list[str] = []
for strat in args.strategies:
print(f"Lancement backtest [{strat}]...")
try:
job_id = lancer_backtest(strat, args.symbol, args.period, args.capital)
jobs[strat] = job_id
print(f" -> Job ID : {job_id}")
except Exception as e:
print(f" ERREUR lancement [{strat}] : {e}")
stratégies_échouées.append(strat)
if not jobs:
print("Aucun job lancé. Arrêt.")
sys.exit(1)
print()
# -----------------------------------------------------------------------
# Phase 2 : attente et collecte des résultats (séquentielle)
# -----------------------------------------------------------------------
resultats: list[dict] = []
for strat, job_id in jobs.items():
print(f"Attente résultat [{strat}] (job={job_id})...")
try:
res = attendre_resultat(job_id, strat)
resultats.append(res)
print(f" -> Terminé. Sharpe={res.get('sharpe_ratio', 'N/A')}, "
f"Retour={res.get('total_return', 'N/A')}")
except Exception as e:
print(f" ERREUR [{strat}] : {e}")
# On ajoute quand même un résultat partiel pour la lisibilité du rapport
resultats.append({
"strategy": strat,
"symbol": args.symbol,
"status": "failed",
"error": str(e),
})
print()
# -----------------------------------------------------------------------
# Phase 3 : affichage et sauvegarde
# -----------------------------------------------------------------------
afficher_tableau(resultats)
chemin_rapport = sauvegarder_rapport(resultats, args.symbol, args.period)
print(f"Rapport JSON sauvegardé : {chemin_rapport}")
print()
# Résumé des échecs éventuels
if stratégies_échouées:
print(f"Stratégies non lancées (erreur API) : {', '.join(stratégies_échouées)}")
if __name__ == "__main__":
main()

327
scripts/quick_benchmark.py Executable file
View File

@@ -0,0 +1,327 @@
#!/usr/bin/env python3
"""
quick_benchmark.py — Benchmark rapide des modèles ML entraînés.
Lit les fichiers *_meta.json dans les répertoires de modèles :
- models/ml_strategy/ (XGBoost / LightGBM / RandomForest)
- models/cnn_strategy/ (CNN 1D — CandlestickEncoder)
- models/cnn_image_strategy/ (CNN Image — chandeliers en image)
Affiche un tableau comparatif : wf_accuracy, wf_precision, n_samples, etc.
Indique quel modèle obtient la meilleure précision walk-forward.
Usage :
python scripts/quick_benchmark.py
python scripts/quick_benchmark.py --models-root /chemin/vers/models
"""
import argparse
import json
import sys
from pathlib import Path
# ---------------------------------------------------------------------------
# Répertoires de modèles (relatifs à la racine du projet)
# ---------------------------------------------------------------------------
PROJECT_ROOT = Path(__file__).parent.parent
# Mapping : nom affiché -> chemin relatif au projet
MODEL_DIRS = {
"ML-Strategy (XGBoost/LightGBM)": PROJECT_ROOT / "models" / "ml_strategy",
"CNN-Strategy (1D)": PROJECT_ROOT / "models" / "cnn_strategy",
"CNN-Image-Strategy": PROJECT_ROOT / "models" / "cnn_image_strategy",
}
# Colonnes de métriques affichées dans le tableau
COLONNES = [
("type", "Type", "{:<28}"),
("symbol", "Symbol", "{:<8}"),
("timeframe", "TF", "{:<5}"),
("model_type", "Modèle", "{:<12}"),
("n_samples", "Échantillons", "{:>12}"),
("wf_accuracy", "WF Accuracy", "{:>12}"),
("wf_precision", "WF Precision", "{:>13}"),
("trained_at", "Entraîné le", "{:<20}"),
]
# ---------------------------------------------------------------------------
# Lecture des méta-fichiers JSON
# ---------------------------------------------------------------------------
def charger_modeles(model_dirs: dict[str, Path]) -> list[dict]:
"""
Parcourt les répertoires de modèles et charge les métadonnées JSON.
Retourne une liste de dicts prêts pour l'affichage.
"""
tous = []
for type_label, dossier in model_dirs.items():
if not dossier.exists():
# Répertoire absent — pas encore de modèles entraînés pour ce type
continue
fichiers_meta = sorted(dossier.glob("*_meta.json"))
if not fichiers_meta:
continue
for f in fichiers_meta:
try:
with open(f, encoding="utf-8") as fp:
meta = json.load(fp)
except (json.JSONDecodeError, OSError) as e:
print(f"[WARN] Impossible de lire {f} : {e}", file=sys.stderr)
continue
wf = meta.get("wf_metrics", {})
# Extraction des champs — avec valeurs par défaut
# Le nom du fichier (ex : EURUSD_1h_xgboost_meta.json) est utilisé
# comme fallback si les champs ne sont pas dans le JSON.
stem_parts = f.stem.replace("_meta", "").split("_")
symbol = meta.get("symbol", stem_parts[0] if len(stem_parts) > 0 else "?")
timeframe = meta.get("timeframe", stem_parts[1] if len(stem_parts) > 1 else "?")
model_type = meta.get(
"model_type",
stem_parts[2] if len(stem_parts) > 2 else dossier.name
)
wf_accuracy = wf.get("avg_accuracy", None)
wf_precision = wf.get("avg_precision", None)
trained_at = meta.get("trained_at", "?")
# Tronque la date ISO à 19 caractères pour la lisibilité
if isinstance(trained_at, str) and len(trained_at) > 19:
trained_at = trained_at[:19]
tous.append({
"type": type_label,
"symbol": symbol,
"timeframe": timeframe,
"model_type": model_type,
"n_samples": meta.get("n_samples", 0),
"wf_accuracy": wf_accuracy,
"wf_precision": wf_precision,
"trained_at": trained_at,
"_meta_path": str(f), # pour le rapport détaillé éventuel
"_raw_meta": meta,
})
return tous
# ---------------------------------------------------------------------------
# Affichage du tableau
# ---------------------------------------------------------------------------
def formater_val(val, fmt: str) -> str:
"""Formate une valeur numérique ou retourne 'N/A'."""
if val is None:
return "N/A"
try:
if "%" in fmt or "f" in fmt:
return f"{float(val):.2%}" if "%" in fmt else fmt.format(float(val))
return fmt.format(val)
except (ValueError, TypeError):
return str(val)
def afficher_tableau(modeles: list[dict]) -> None:
"""Affiche le tableau comparatif des modèles dans le terminal."""
if not modeles:
print("Aucun modèle entraîné trouvé.")
print("Entraînez d'abord un modèle via POST /trading/train (ou /train-cnn).")
return
# Entêtes
headers = [label for _, label, _ in COLONNES]
col_fmts = [fmt for _, _, fmt in COLONNES]
# Construction des lignes
rows = []
for m in modeles:
ligne = []
for key, _, fmt in COLONNES:
val = m.get(key)
if val is None:
ligne.append("N/A")
elif key in ("wf_accuracy", "wf_precision"):
# Affichage en pourcentage
try:
ligne.append(f"{float(val):.2%}")
except (ValueError, TypeError):
ligne.append("N/A")
elif key == "n_samples":
try:
ligne.append(f"{int(val):,}")
except (ValueError, TypeError):
ligne.append(str(val))
else:
ligne.append(str(val))
rows.append(ligne)
print()
print("=" * 80)
print("BENCHMARK DES MODÈLES ML ENTRAÎNÉS")
print("=" * 80)
try:
from tabulate import tabulate
print(tabulate(rows, headers=headers, tablefmt="rounded_outline"))
except ImportError:
# Fallback : affichage simple
col_widths = [
max(len(str(h)), max((len(str(r[i])) for r in rows), default=0))
for i, h in enumerate(headers)
]
sep = " ".join("-" * w for w in col_widths)
header_line = " ".join(str(h).ljust(w) for h, w in zip(headers, col_widths))
print(header_line)
print(sep)
for row in rows:
print(" ".join(str(c).ljust(w) for c, w in zip(row, col_widths)))
print()
# ---------------------------------------------------------------------------
# Identification du meilleur modèle
# ---------------------------------------------------------------------------
def identifier_meilleur(modeles: list[dict]) -> None:
"""Indique quel modèle est le meilleur selon wf_accuracy et wf_precision."""
candidats = [m for m in modeles if m.get("wf_accuracy") is not None]
if not candidats:
print("Impossible de déterminer le meilleur modèle (aucune métrique wf_accuracy).")
return
# Critère principal : wf_accuracy ; critère secondaire : wf_precision
meilleur = max(
candidats,
key=lambda m: (
m.get("wf_accuracy") or 0.0,
m.get("wf_precision") or 0.0,
),
)
print(f"Meilleur modèle (WF Accuracy) :")
print(f" Type : {meilleur['type']}")
print(f" Symbol : {meilleur['symbol']} / {meilleur['timeframe']}")
print(f" Modèle : {meilleur['model_type']}")
acc = meilleur.get("wf_accuracy")
prec = meilleur.get("wf_precision")
print(f" WF Accuracy : {acc:.2%}" if acc is not None else " WF Accuracy : N/A")
print(f" WF Precision : {prec:.2%}" if prec is not None else " WF Precision : N/A")
print(f" Entraîné le : {meilleur.get('trained_at', '?')}")
print(f" Fichier meta : {meilleur['_meta_path']}")
# Avertissement si la précision est insuffisante pour le trading
if acc is not None and acc < 0.40:
print()
print("[AVIS] WF Accuracy < 40% — ce modèle est insuffisant pour trader seul.")
print(" Envisagez un re-entraînement avec plus de données ou de features.")
print()
# ---------------------------------------------------------------------------
# Statistiques globales par type de modèle
# ---------------------------------------------------------------------------
def afficher_stats_globales(modeles: list[dict]) -> None:
"""Affiche des statistiques agrégées par type de modèle."""
if not modeles:
return
# Regroupement par type
par_type: dict[str, list] = {}
for m in modeles:
par_type.setdefault(m["type"], []).append(m)
print("Résumé par type de modèle :")
print("-" * 50)
for type_label, groupe in par_type.items():
accs = [m["wf_accuracy"] for m in groupe if m.get("wf_accuracy") is not None]
precs = [m["wf_precision"] for m in groupe if m.get("wf_precision") is not None]
n_tot = sum(m.get("n_samples", 0) for m in groupe)
acc_moy = sum(accs) / len(accs) if accs else None
prec_moy = sum(precs) / len(precs) if precs else None
acc_str = f"{acc_moy:.2%}" if acc_moy is not None else "N/A"
prec_str = f"{prec_moy:.2%}" if prec_moy is not None else "N/A"
print(f" {type_label}")
print(f" Modèles entraînés : {len(groupe)}")
print(f" WF Accuracy moy. : {acc_str}")
print(f" WF Precision moy. : {prec_str}")
print(f" Échantillons tot. : {n_tot:,}")
print()
# ---------------------------------------------------------------------------
# Point d'entrée principal
# ---------------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
"""Parse les arguments en ligne de commande."""
parser = argparse.ArgumentParser(
description="Benchmark rapide des modèles ML entraînés (lecture des méta-fichiers JSON)."
)
parser.add_argument(
"--models-root",
type=Path,
default=PROJECT_ROOT / "models",
help=f"Répertoire racine des modèles (défaut: {PROJECT_ROOT / 'models'})",
)
return parser.parse_args()
def main() -> None:
args = parse_args()
# Mise à jour des chemins si --models-root est spécifié
dirs_effectifs = {
label: args.models_root / dossier.name
for label, dossier in MODEL_DIRS.items()
}
print()
print("Lecture des modèles dans :")
for label, chemin in dirs_effectifs.items():
existe = "OK" if chemin.exists() else "absent"
print(f" [{existe}] {chemin}")
print()
# Chargement de tous les modèles
modeles = charger_modeles(dirs_effectifs)
if not modeles:
print("Aucun modèle trouvé dans les répertoires indiqués.")
print()
print("Pour entraîner un modèle XGBoost/LightGBM :")
print(" POST http://localhost:8100/trading/train")
print(" {\"symbol\": \"EURUSD\", \"timeframe\": \"1h\", \"model_type\": \"xgboost\"}")
print()
print("Pour entraîner un modèle CNN 1D :")
print(" POST http://localhost:8100/trading/train-cnn")
print(" {\"symbol\": \"EURUSD\", \"timeframe\": \"1h\"}")
sys.exit(0)
# Affichage du tableau
afficher_tableau(modeles)
# Statistiques par type
afficher_stats_globales(modeles)
# Meilleur modèle
identifier_meilleur(modeles)
print(f"Total : {len(modeles)} modèle(s) trouvé(s).")
print()
if __name__ == "__main__":
main()

View File

@@ -262,17 +262,46 @@ def get_ml_status(symbol: str = "EURUSD"):
return cached["result"]
try:
from src.ml.regime_detector import RegimeDetector
config = ConfigLoader.load_all()
data_service = DataService(config)
timeframe = "1h"
now = datetime.now()
start = now - timedelta(days=30)
# Récupérer données synchrones via asyncio.run
# ------------------------------------------------------------------
# Tentative de chargement du modèle HMM persisté (évite le re-train)
# ------------------------------------------------------------------
detector = RegimeDetector(n_regimes=4)
hmm_loaded = detector.load(symbol, timeframe)
if hmm_loaded and not detector.needs_retrain(max_age_hours=24):
# Modèle récent disponible sur disque — pas besoin de ré-entraîner
logger.info(
f"Modèle HMM chargé depuis le disque pour {symbol}/{timeframe} "
f"(entraîné le {detector._trained_at})"
)
need_fit = False
else:
if hmm_loaded:
logger.info(
f"Modèle HMM trop ancien pour {symbol}/{timeframe} — ré-entraînement nécessaire"
)
else:
logger.info(
f"Aucun modèle HMM persisté pour {symbol}/{timeframe} — entraînement initial"
)
need_fit = True
# ------------------------------------------------------------------
# Récupération des données (toujours nécessaire pour la prédiction)
# ------------------------------------------------------------------
df = asyncio.run(
data_service.get_historical_data(
symbol=symbol,
timeframe="1h",
timeframe=timeframe,
start_date=start,
end_date=now,
)
@@ -291,11 +320,23 @@ def get_ml_status(symbol: str = "EURUSD"):
df.columns = [c.lower() for c in df.columns]
# ------------------------------------------------------------------
# Entraînement si nécessaire puis sauvegarde du modèle
# ------------------------------------------------------------------
ml = MLEngine(config=config.get("ml", {}))
ml.initialize(df)
if need_fit:
# Entraîner le détecteur et l'injecter dans MLEngine
detector.fit(df)
# Sauvegarder le nouveau modèle sur disque pour les prochains appels
detector.save(symbol, timeframe)
# Injecter le détecteur (chargé ou fraîchement entraîné) dans MLEngine
ml.regime_detector = detector
ml.current_regime = detector.predict_current_regime(df)
regime_info = ml.get_regime_info()
regime_stats = ml.regime_detector.get_regime_statistics(df)
regime_stats = detector.get_regime_statistics(df)
strategy_advice = {
s: ml.should_trade(s)
@@ -1383,7 +1424,9 @@ async def train_cnn_image(request: CNNImageTrainRequest, background_tasks: Backg
"timeframe": request.timeframe,
}
background_tasks.add_task(_run_cnn_image_train_task, job_id, request)
# asyncio.create_task() plutôt que background_tasks pour permettre
# les hot-reloads WatchFiles sans bloquer à l'arrêt du serveur
asyncio.create_task(_run_cnn_image_train_task(job_id, request))
return CNNImageTrainResponse(
job_id = job_id,
@@ -1418,3 +1461,388 @@ def list_cnn_image_models():
from src.ml.cnn_image import CNNImageStrategyModel
models = CNNImageStrategyModel.list_trained_models()
return {"models": models, "count": len(models)}
# =============================================================================
# RL STRATEGY — Entraînement et gestion des modèles RL (PPO)
# =============================================================================
try:
from src.ml.rl import RLStrategyModel, RL_AVAILABLE as _RL_AVAILABLE
RL_AVAILABLE = _RL_AVAILABLE
except ImportError:
RL_AVAILABLE = False
# Stockage en mémoire des jobs d'entraînement RL
_rl_train_jobs: Dict[str, dict] = {}
class RLTrainRequest(BaseModel):
"""Requête d'entraînement du modèle RL (PPO)."""
symbol: str = "EURUSD"
timeframe: str = "1h"
period: str = "2y" # Période de données historiques
total_timesteps: int = 100_000 # Nombre de timesteps PPO
sl_atr_mult: float = 1.0 # Multiplicateur ATR pour SL
tp_atr_mult: float = 2.0 # Multiplicateur ATR pour TP
min_confidence: float = 0.50 # Seuil confiance minimum
train_ratio: float = 0.80 # Fraction données pour l'entraînement
initial_capital: float = 10_000.0 # Capital initial pour la simulation
class RLTrainResponse(BaseModel):
"""Réponse d'un job d'entraînement RL."""
job_id: str
status: str # pending | running | completed | failed
symbol: Optional[str] = None
timeframe: Optional[str] = None
n_samples: Optional[int] = None
total_timesteps: Optional[int] = None
n_episodes: Optional[int] = None
mean_ep_return: Optional[float] = None # Récompense moyenne par épisode
eval_return_pct: Optional[float] = None # Rendement sur le holdout
eval_sharpe: Optional[float] = None # Sharpe approx. sur le holdout
eval_win_rate: Optional[float] = None # Win rate sur le holdout
trained_at: Optional[str] = None
error: Optional[str] = None
async def _run_rl_train_task(job_id: str, request: RLTrainRequest) -> None:
"""Tâche d'entraînement RL exécutée en arrière-plan."""
_rl_train_jobs[job_id]["status"] = "running"
try:
from src.data.data_service import DataService
from src.utils.config_loader import ConfigLoader
from datetime import timedelta
config = ConfigLoader.load_all()
data_service = DataService(config)
end_date = datetime.now()
period_map = {'y': 365, 'm': 30, 'd': 1}
unit = request.period[-1]
value = int(request.period[:-1])
start_date = end_date - timedelta(days=value * period_map.get(unit, 1))
df = await data_service.get_historical_data(
symbol = request.symbol,
timeframe = request.timeframe,
start_date = start_date,
end_date = end_date,
)
if df is None or len(df) < 200:
raise ValueError(
f"Données insuffisantes : {len(df) if df is not None else 0} barres (min 200)"
)
# Entraînement dans un thread (opération CPU-bound / GPU-bound)
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(None, _sync_rl_train, df, request)
_rl_train_jobs[job_id].update({
"status": "completed",
"symbol": request.symbol,
"timeframe": request.timeframe,
"n_samples": result.get("n_samples"),
"total_timesteps": result.get("total_timesteps"),
"n_episodes": result.get("n_episodes"),
"mean_ep_return": result.get("mean_ep_return"),
"eval_return_pct": result.get("eval_return_pct"),
"eval_sharpe": result.get("eval_sharpe"),
"eval_win_rate": result.get("eval_win_rate"),
"trained_at": result.get("trained_at"),
})
# Auto-attachement à la stratégie RL active si elle existe
_attach_rl_model_to_strategy(request)
except Exception as exc:
logger.error(f"Erreur entraînement RL job {job_id} : {exc}", exc_info=True)
_rl_train_jobs[job_id]["status"] = "failed"
_rl_train_jobs[job_id]["error"] = str(exc)
def _sync_rl_train(df, request: RLTrainRequest) -> dict:
"""Wrapper synchrone pour RLStrategyModel.train() (exécuté dans un thread)."""
from src.ml.rl import RLStrategyModel
model = RLStrategyModel(
symbol = request.symbol,
timeframe = request.timeframe,
total_timesteps = request.total_timesteps,
sl_atr_mult = request.sl_atr_mult,
tp_atr_mult = request.tp_atr_mult,
min_confidence = request.min_confidence,
train_ratio = request.train_ratio,
initial_capital = request.initial_capital,
)
return model.train(df)
def _attach_rl_model_to_strategy(request: RLTrainRequest) -> None:
"""Attache le modèle RL entraîné à la stratégie rl_driven active (paper trading)."""
try:
from src.ml.rl import RLStrategyModel
from src.strategies.rl_driven import RLDrivenStrategy
engine = _paper_state.get("engine")
if engine and hasattr(engine, 'strategy_engine'):
strat = engine.strategy_engine.strategies.get('rl_driven')
if strat and isinstance(strat, RLDrivenStrategy):
model = RLStrategyModel.load(request.symbol, request.timeframe)
strat.attach_model(model)
logger.info("Modèle RL attaché à la stratégie rl_driven active")
except Exception as e:
logger.debug(f"Auto-attach modèle RL ignoré : {e}")
@router.post("/train-rl", response_model=RLTrainResponse,
summary="Entraîner le modèle RL (PPO)")
async def train_rl_model(request: RLTrainRequest, background_tasks: BackgroundTasks):
"""
Lance l'entraînement de l'agent PPO en arrière-plan.
L'agent RL (Proximal Policy Optimization) apprend à trader par interaction
directe avec un environnement de trading simulé — sans labels supervisés.
La récompense est définie par le PnL réalisé et les pénalités de drawdown.
- Retourne un `job_id` à interroger via `GET /trading/train-rl/{job_id}`
- Le modèle est sauvegardé sur disque après entraînement
- Si un paper trading rl_driven est actif, le modèle lui est automatiquement attaché
"""
if not RL_AVAILABLE:
raise HTTPException(
503,
detail="PyTorch requis — rebuilder le container trading-api"
)
job_id = str(uuid.uuid4())
_rl_train_jobs[job_id] = {
"status": "pending",
"symbol": request.symbol,
"timeframe": request.timeframe,
}
background_tasks.add_task(_run_rl_train_task, job_id, request)
return RLTrainResponse(
job_id = job_id,
status = "pending",
symbol = request.symbol,
timeframe = request.timeframe,
)
@router.get("/train-rl/{job_id}", response_model=RLTrainResponse,
summary="Résultat entraînement RL")
def get_rl_train_status(job_id: str):
"""Retourne l'état d'un job d'entraînement RL (PPO)."""
job = _rl_train_jobs.get(job_id)
if job is None:
raise HTTPException(404, detail=f"Job {job_id} introuvable")
return RLTrainResponse(job_id=job_id, **job)
@router.get("/rl-models", summary="Liste des modèles RL entraînés")
def list_rl_models():
"""
Retourne la liste de tous les modèles RL (PPO) disponibles sur disque,
avec leurs métriques (return holdout, Sharpe, date d'entraînement...).
"""
if not RL_AVAILABLE:
return {
"error": "PyTorch requis — rebuilder le container trading-api",
"models": [],
"count": 0,
}
try:
from src.ml.rl import RLStrategyModel
models = RLStrategyModel.list_trained_models()
return {"models": models, "count": len(models)}
except Exception as e:
return {"error": str(e), "models": [], "count": 0}
# =============================================================================
# RL STRATEGY — Entraînement et gestion des modèles PPO (Reinforcement Learning)
# =============================================================================
try:
from src.ml.rl.rl_strategy_model import RLStrategyModel
RL_AVAILABLE = True
except ImportError:
RLStrategyModel = None
RL_AVAILABLE = False
# Stockage en mémoire des jobs d'entraînement RL
_rl_train_jobs: Dict[str, dict] = {}
class RLTrainRequest(BaseModel):
"""Requête d'entraînement du modèle RL (agent PPO)."""
symbol: str = "EURUSD"
timeframe: str = "1h"
period: str = "2y"
total_timesteps: int = 50000
class RLTrainResponse(BaseModel):
"""Réponse d'un job d'entraînement RL."""
job_id: str
status: str
symbol: str
timeframe: str
avg_reward: Optional[float] = None
sharpe_env: Optional[float] = None
total_timesteps: Optional[int] = None
trained_at: Optional[str] = None
error: Optional[str] = None
async def _run_rl_train_task(job_id: str, request: RLTrainRequest) -> None:
"""Tâche d'entraînement RL exécutée en arrière-plan."""
_rl_train_jobs[job_id]["status"] = "running"
try:
from src.data.data_service import DataService
from src.utils.config_loader import ConfigLoader
from datetime import timedelta
config = ConfigLoader.load_all()
data_service = DataService(config)
end_date = datetime.now()
period_map = {'y': 365, 'm': 30, 'd': 1}
unit = request.period[-1]
value = int(request.period[:-1])
start_date = end_date - timedelta(days=value * period_map.get(unit, 1))
df = await data_service.get_historical_data(
symbol = request.symbol,
timeframe = request.timeframe,
start_date = start_date,
end_date = end_date,
)
if df is None or len(df) < 200:
raise ValueError(
f"Données insuffisantes : {len(df) if df is not None else 0} barres (min 200)"
)
# Entraînement dans un thread (opération CPU-bound)
import torch
torch.set_num_threads(4)
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(None, _sync_rl_train, df, request)
_rl_train_jobs[job_id].update({
"status": "completed",
"symbol": request.symbol,
"timeframe": request.timeframe,
"avg_reward": result.get("avg_reward"),
"sharpe_env": result.get("sharpe_env"),
"total_timesteps": result.get("total_timesteps"),
"trained_at": result.get("trained_at"),
})
# Auto-attachement à la stratégie rl_driven active si elle existe
_attach_rl_model_to_strategy(request)
except Exception as exc:
logger.error(f"Erreur entraînement RL job {job_id} : {exc}", exc_info=True)
_rl_train_jobs[job_id]["status"] = "failed"
_rl_train_jobs[job_id]["error"] = str(exc)
def _sync_rl_train(df, request: RLTrainRequest) -> dict:
"""Wrapper synchrone pour RLStrategyModel.train() (exécuté dans un thread)."""
import torch
torch.set_num_threads(4)
from src.ml.rl.rl_strategy_model import RLStrategyModel
model = RLStrategyModel(
symbol = request.symbol,
timeframe = request.timeframe,
total_timesteps = request.total_timesteps,
)
return model.train(df)
def _attach_rl_model_to_strategy(request: RLTrainRequest) -> None:
"""Attache le modèle RL entraîné à la stratégie rl_driven active (paper trading)."""
try:
from src.ml.rl.rl_strategy_model import RLStrategyModel
from src.strategies.rl_driven import RLDrivenStrategy
engine = _paper_state.get("engine")
if engine and hasattr(engine, 'strategy_engine'):
strat = engine.strategy_engine.strategies.get('rl_driven')
if strat and isinstance(strat, RLDrivenStrategy):
model = RLStrategyModel.load(request.symbol, request.timeframe)
strat.attach_model(model)
logger.info("Modèle RL attaché à la stratégie rl_driven active")
except Exception as e:
logger.debug(f"Auto-attach modèle RL ignoré : {e}")
@router.post("/train-rl", response_model=RLTrainResponse,
summary="Entraîner le modèle RL (agent PPO)")
async def train_rl(request: RLTrainRequest, background_tasks: BackgroundTasks):
"""
Lance l'entraînement de l'agent PPO en arrière-plan.
L'agent Reinforcement Learning apprend à maximiser le profit cumulatif en
interagissant avec un environnement de trading simulé (TradingEnv) — sans
labels supervisés. La politique Actor-Critic PPO est optimisée sur
`total_timesteps` pas de simulation.
- Retourne un `job_id` à interroger via `GET /trading/train-rl/{job_id}`
- Le modèle est sauvegardé sur disque après entraînement
- Si un paper trading rl_driven est actif, le modèle lui est automatiquement attaché
"""
if not RL_AVAILABLE:
raise HTTPException(
503,
detail="PyTorch / stable-baselines3 requis — rebuilder le container trading-api"
)
job_id = str(uuid.uuid4())
_rl_train_jobs[job_id] = {
"status": "pending",
"symbol": request.symbol,
"timeframe": request.timeframe,
}
background_tasks.add_task(_run_rl_train_task, job_id, request)
return RLTrainResponse(
job_id = job_id,
status = "pending",
symbol = request.symbol,
timeframe = request.timeframe,
)
@router.get("/train-rl/{job_id}", response_model=RLTrainResponse,
summary="Résultat entraînement RL")
def get_rl_train_status(job_id: str):
"""Retourne l'état d'un job d'entraînement RL (PPO)."""
job = _rl_train_jobs.get(job_id)
if job is None:
raise HTTPException(404, detail=f"Job {job_id} introuvable")
return RLTrainResponse(job_id=job_id, **job)
@router.get("/rl-models", summary="Liste des modèles RL entraînés")
def list_rl_models():
"""
Retourne la liste de tous les modèles RL disponibles sur disque,
avec leurs métriques (avg_reward, sharpe_env, date d'entraînement...).
"""
if not RL_AVAILABLE:
return {
"error": "PyTorch / stable-baselines3 requis — rebuilder le container trading-api",
"models": [],
"count": 0,
}
from src.ml.rl.rl_strategy_model import RLStrategyModel
models = RLStrategyModel.list_trained_models()
return {"models": models, "count": len(models)}

View File

@@ -4,18 +4,17 @@ CandlestickImageRenderer — Convertit des données OHLCV en images de graphique
Ce module transforme des séquences de bougies OHLCV en images 128×128 RGB
qui peuvent être passées à un CNN Conv2D pour l'analyse visuelle des patterns.
Rendu :
Rendu (pur numpy, très rapide) :
- Fond noir (#0d1117), style TradingView
- Bougies vertes (#26a69a) pour la hausse, rouges (#ef5350) pour la baisse
- Volume en bas de l'image (via mplfinance)
- Mèches (high/low), corps (open/close), volume en bas (20% de hauteur)
- Pas d'axes, pas de labels, pas de titre
- Taille fixe : 128×128 pixels, 3 canaux RGB
Si mplfinance ou PIL ne sont pas disponibles, un rendu de fallback basique
encode les données OHLCV numériquement sous forme d'image 2D.
Perf : ~5 000 images/s (vs ~5/s avec mplfinance).
mplfinance reste utilisable via _render_with_mplfinance() pour l'affichage.
"""
import io
import logging
from typing import Optional
@@ -24,28 +23,30 @@ import pandas as pd
logger = logging.getLogger(__name__)
# --- Détection optionnelle de mplfinance et PIL ---
# Couleurs TradingView (normalisées [0, 1])
BG_COLOR = np.array([13 / 255, 17 / 255, 23 / 255], dtype=np.float32) # #0d1117
GREEN_COLOR = np.array([38 / 255, 166 / 255, 154 / 255], dtype=np.float32) # #26a69a
RED_COLOR = np.array([239 / 255, 83 / 255, 80 / 255], dtype=np.float32) # #ef5350
IMAGE_SIZE = 128
VOLUME_RATIO = 0.20 # 20% de la hauteur pour le volume
# --- Détection optionnelle de mplfinance (pour affichage uniquement) ---
try:
import mplfinance as mpf
import matplotlib
matplotlib.use('Agg') # Backend non-interactif (pas d'écran requis)
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import io
MPLFINANCE_AVAILABLE = True
logger.debug("mplfinance disponible — rendu haute qualité activé")
except ImportError:
MPLFINANCE_AVAILABLE = False
logger.warning("mplfinance non disponible — utilisation du rendu fallback")
try:
from PIL import Image
PIL_AVAILABLE = True
except ImportError:
PIL_AVAILABLE = False
logger.warning("Pillow (PIL) non disponible — rendu fallback uniquement")
# Taille cible des images en pixels
IMAGE_SIZE = 128
class CandlestickImageRenderer:
@@ -53,15 +54,17 @@ class CandlestickImageRenderer:
Convertit des données OHLCV en images de graphiques en chandeliers.
Chaque image est un instantané visuel de `seq_len` bougies consécutives,
rendu avec mplfinance (style fond noir, bougies colorées, volume en bas).
rendu via pur numpy (rapide, ~5 000 images/s) ou mplfinance (lent, haute qualité).
Le résultat est normalisé en float32 dans [0, 1].
Args:
image_size: Taille carrée de l'image en pixels (défaut 128)
use_mplfinance: Forcer mplfinance même pour encode() (lent, déconseillé)
"""
def __init__(self, image_size: int = IMAGE_SIZE):
self.image_size = image_size
def __init__(self, image_size: int = IMAGE_SIZE, use_mplfinance: bool = False):
self.image_size = image_size
self.use_mplfinance = use_mplfinance and MPLFINANCE_AVAILABLE and PIL_AVAILABLE
# -------------------------------------------------------------------------
# Interface publique
@@ -74,6 +77,8 @@ class CandlestickImageRenderer:
Produit N = len(df) - seq_len fenêtres, chacune rendue en image
128×128 RGB normalisée [0, 1].
Utilise le renderer pur numpy (rapide) par défaut.
Args:
df: DataFrame OHLCV avec colonnes open/high/low/close/volume
seq_len: Nombre de bougies par fenêtre (défaut 64)
@@ -91,20 +96,13 @@ class CandlestickImageRenderer:
)
return np.zeros((0, 3, self.image_size, self.image_size), dtype=np.float32)
images = np.zeros(
(n_windows, 3, self.image_size, self.image_size), dtype=np.float32
)
# Extraction des valeurs OHLCV en numpy (une seule conversion)
ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values.astype(np.float32)
for i in range(n_windows):
window = df.iloc[i: i + seq_len]
try:
img = self._render_single(window)
images[i] = img
except Exception as e:
# Fenêtre problématique → on laisse des zéros pour cette position
logger.debug(f"Erreur rendu fenêtre {i} : {e}")
return images
if self.use_mplfinance:
return self._encode_mplfinance(df, seq_len, n_windows)
else:
return self._encode_numpy(ohlcv, seq_len, n_windows)
def encode_last(self, df: pd.DataFrame, seq_len: int = 64) -> np.ndarray:
"""
@@ -130,157 +128,319 @@ class CandlestickImageRenderer:
)
window = df.iloc[-seq_len:]
img = self._render_single(window)
# Ajouter dimension batch
ohlcv = window[['open', 'high', 'low', 'close', 'volume']].values.astype(np.float32)
img = self._render_numpy(ohlcv)
return img[np.newaxis, ...] # (1, 3, H, W)
# -------------------------------------------------------------------------
# Rendu d'une fenêtre
# Renderer pur numpy (rapide)
# -------------------------------------------------------------------------
def _render_single(self, df_window: pd.DataFrame) -> np.ndarray:
def _encode_numpy(
self,
ohlcv: np.ndarray,
seq_len: int,
n_windows: int,
) -> np.ndarray:
"""
Rend une fenêtre OHLCV en image numpy (3, 128, 128).
Encode toutes les fenêtres en batch vectorisé (sans boucle Python sur N).
Utilise mplfinance si disponible, sinon fallback encodage numérique.
Utilise sliding_window_view pour créer toutes les fenêtres d'un coup,
puis boucle sur les seq_len bougies (64) en traitant toutes les N fenêtres
simultanément. Numpy libère le GIL pendant les opérations vectorisées,
ce qui préserve la réactivité de l'event loop FastAPI.
Args:
df_window: DataFrame OHLCV pour une fenêtre (seq_len barres)
ohlcv: (T, 5) float32 [open, high, low, close, volume]
seq_len: Longueur de chaque fenêtre
n_windows: Nombre total de fenêtres à générer
Returns:
np.ndarray de forme (3, 128, 128), float32, valeurs [0, 1]
(N, 3, H, W) float32
"""
if MPLFINANCE_AVAILABLE and PIL_AVAILABLE:
return self._render_with_mplfinance(df_window)
else:
return self._render_fallback(df_window)
H = W = self.image_size
# --- Toutes les fenêtres d'un coup : (N, seq_len, 5) ---
# sliding_window_view est O(1) en mémoire (vue sans copie)
windows = np.lib.stride_tricks.sliding_window_view(ohlcv, (seq_len, 5))
windows = windows[:n_windows, 0, :, :].astype(np.float32) # (N, seq_len, 5)
opens = windows[:, :, 0] # (N, seq_len)
highs = windows[:, :, 1]
lows = windows[:, :, 2]
closes = windows[:, :, 3]
vols = windows[:, :, 4]
# --- Fond de toutes les images ---
images = np.empty((n_windows, 3, H, W), dtype=np.float32)
images[:, 0, :, :] = BG_COLOR[0]
images[:, 1, :, :] = BG_COLOR[1]
images[:, 2, :, :] = BG_COLOR[2]
# --- Normalisation des prix par fenêtre ---
price_min = lows.min(axis=1, keepdims=True) # (N, 1)
price_max = highs.max(axis=1, keepdims=True) # (N, 1)
price_rng = np.maximum(price_max - price_min, 1e-8)
vol_h = max(1, int(H * VOLUME_RATIO))
chart_h = H - vol_h
def to_row(prices: np.ndarray) -> np.ndarray:
"""Convertit prix → rangée pixel (0 = haut, chart_h-1 = bas)."""
norm = (prices - price_min) / price_rng
rows = ((1.0 - norm) * (chart_h - 1)).astype(np.int32)
return np.clip(rows, 0, chart_h - 1)
row_h = to_row(highs) # (N, seq_len)
row_l = to_row(lows)
row_o = to_row(opens)
row_c = to_row(closes)
body_top = np.minimum(row_o, row_c) # (N, seq_len)
body_bot = np.maximum(row_o, row_c)
is_bull = (closes >= opens) # (N, seq_len) bool
# Volume normalisé [0, 1] par fenêtre
vol_max = np.maximum(vols.max(axis=1, keepdims=True), 1e-8)
vol_norm = vols / vol_max # (N, seq_len)
# Positions X des bougies
candle_w = max(1, W // seq_len)
half_w = max(1, candle_w // 2)
x_centers = ((np.arange(seq_len) + 0.5) * W / seq_len).astype(np.int32)
row_idx = np.arange(chart_h) # (chart_h,) pour les masques
vol_idx = np.arange(vol_h) # (vol_h,)
# --- Boucle sur les bougies (64 itérations, GIL libéré entre chaque) ---
for i in range(seq_len):
x_c = int(x_centers[i])
x0 = max(0, x_c - half_w)
x1 = min(W, x_c + half_w + 1)
wick_x = min(x_c, W - 1)
rh = row_h[:, i] # (N,)
rl = row_l[:, i]
bt = body_top[:, i]
bb = body_bot[:, i]
bull = is_bull[:, i] # (N,) bool
# Masques vectorisés sur toutes les N fenêtres
# wick_mask[w, r] = True si rh[w] <= r <= rl[w]
wick_mask = (
(row_idx[None, :] >= rh[:, None]) &
(row_idx[None, :] <= rl[:, None])
) # (N, chart_h)
body_mask = (
(row_idx[None, :] >= bt[:, None]) &
(row_idx[None, :] <= bb[:, None])
) # (N, chart_h)
# Volume : barre du bas
vol_bar_h = (vol_norm[:, i] * vol_h).astype(np.int32)
vol_thresh = np.maximum(vol_h - vol_bar_h, 0)
vol_mask = vol_idx[None, :] >= vol_thresh[:, None] # (N, vol_h)
for c in range(3):
g = float(GREEN_COLOR[c])
r = float(RED_COLOR[c])
# Couleur par fenêtre : vert si haussière, rouge sinon
color_w = np.where(bull, g, r).astype(np.float32) # (N,)
# Mèche
images[:, c, :chart_h, wick_x] = np.where(
wick_mask,
color_w[:, None],
images[:, c, :chart_h, wick_x],
)
# Corps
images[:, c, :chart_h, x0:x1] = np.where(
body_mask[:, :, None],
color_w[:, None, None],
images[:, c, :chart_h, x0:x1],
)
# Volume (opacité 60%)
images[:, c, H - vol_h:H, x0:x1] = np.where(
vol_mask[:, :, None],
color_w[:, None, None] * 0.6,
images[:, c, H - vol_h:H, x0:x1],
)
return images
def _render_numpy(self, ohlcv: np.ndarray) -> np.ndarray:
"""
Rend une seule fenêtre OHLCV en image (3, H, W) via pur numpy.
Dessin :
- Fond #0d1117
- Mèche : ligne verticale high→low (1 pixel de large)
- Corps : rectangle open→close (largeur proportionnelle)
- Volume : barre en bas (20% hauteur), opacité 60%
Args:
ohlcv: (N, 5) float32 [open, high, low, close, volume]
Returns:
(3, image_size, image_size) float32 [0, 1]
"""
H = W = self.image_size
n = len(ohlcv)
if n == 0:
return np.zeros((3, H, W), dtype=np.float32)
# --- Fond ---
img = np.empty((3, H, W), dtype=np.float32)
for c in range(3):
img[c] = BG_COLOR[c]
opens = ohlcv[:, 0]
highs = ohlcv[:, 1]
lows = ohlcv[:, 2]
closes = ohlcv[:, 3]
vols = ohlcv[:, 4]
# --- Normalisation des prix ---
price_min = lows.min()
price_max = highs.max()
if price_max <= price_min:
return img # données dégénérées
vol_h = max(1, int(H * VOLUME_RATIO))
chart_h = H - vol_h # hauteur zone prix
def price_to_row(price: float) -> int:
"""Convertit un prix en ligne pixel (0 = haut, chart_h-1 = bas)."""
norm = (price - price_min) / (price_max - price_min)
row = int((1.0 - norm) * (chart_h - 1))
return max(0, min(chart_h - 1, row))
# --- Largeur des bougies ---
candle_w = max(1, W // n)
half_w = max(1, candle_w // 2)
# --- Volume ---
vol_max = vols.max()
for i in range(n):
x_c = int((i + 0.5) * W / n)
x0 = max(0, x_c - half_w)
x1 = min(W, x_c + half_w + 1)
is_bull = closes[i] >= opens[i]
color = GREEN_COLOR if is_bull else RED_COLOR
row_h = price_to_row(highs[i])
row_l = price_to_row(lows[i])
row_o = price_to_row(opens[i])
row_c = price_to_row(closes[i])
body_top = min(row_o, row_c)
body_bot = max(row_o, row_c)
# Mèche (1 pixel de large, centré)
wick_x = min(x_c, W - 1)
for c in range(3):
img[c, row_h: row_l + 1, wick_x] = color[c]
# Corps
if body_bot >= body_top:
for c in range(3):
img[c, body_top: body_bot + 1, x0: x1] = color[c]
# Volume (bas de l'image)
if vol_max > 0:
bar_h = int((vols[i] / vol_max) * vol_h)
if bar_h > 0:
row_v0 = H - bar_h
for c in range(3):
img[c, row_v0: H, x0: x1] = color[c] * 0.6
return img
# -------------------------------------------------------------------------
# Renderer mplfinance (lent, haute qualité, pour affichage uniquement)
# -------------------------------------------------------------------------
def _encode_mplfinance(
self,
df: pd.DataFrame,
seq_len: int,
n_windows: int,
) -> np.ndarray:
"""Encode via mplfinance (lent — ne pas utiliser pour l'entraînement)."""
H = W = self.image_size
images = np.zeros((n_windows, 3, H, W), dtype=np.float32)
import warnings
for i in range(n_windows):
window = df.iloc[i: i + seq_len]
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
images[i] = self._render_with_mplfinance(window)
except Exception as e:
logger.debug(f"Erreur mplfinance fenêtre {i} : {e}")
return images
def _render_with_mplfinance(self, df_window: pd.DataFrame) -> np.ndarray:
"""
Rendu haute qualité via mplfinance.
Rendu haute qualité via mplfinance (pour affichage / debug).
Style : fond noir (#0d1117), bougies vertes/rouges, volume en bas,
aucun axe ni label ni titre. Image 128×128 RGB.
Raises:
ImportError si mplfinance ou PIL non disponible.
"""
# Style personnalisé : fond noir, bougies colorées TradingView
if not MPLFINANCE_AVAILABLE or not PIL_AVAILABLE:
raise ImportError("mplfinance ou PIL non disponible")
style = mpf.make_mpf_style(
base_mpf_style='nightclouds',
marketcolors=mpf.make_marketcolors(
up='#26a69a', # vert TradingView (hausse)
down='#ef5350', # rouge TradingView (baisse)
up='#26a69a', down='#ef5350',
wick={'up': '#26a69a', 'down': '#ef5350'},
edge={'up': '#26a69a', 'down': '#ef5350'},
volume={'up': '#26a69a', 'down': '#ef5350'},
),
facecolor='#0d1117', # Fond noir GitHub-style
facecolor='#0d1117',
figcolor='#0d1117',
gridcolor='#0d1117',
)
# Rendu en mémoire via BytesIO
buf = io.BytesIO()
import warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore")
fig, axes = mpf.plot(
df_window,
type='candle',
style=style,
volume=True,
axisoff=True,
tight_layout=True,
returnfig=True,
figsize=(1.28, 1.28),
)
fig, axes = mpf.plot(
df_window,
type='candle',
style=style,
volume=True,
axisoff=True, # Pas d'axes
tight_layout=True,
returnfig=True,
figsize=(1.28, 1.28), # 128 DPI × 1.28 inch = 128 pixels
)
# Suppression de tous les éléments décoratifs
for ax in axes:
ax.set_axis_off()
ax.set_facecolor('#0d1117')
for spine in ax.spines.values():
spine.set_visible(False)
fig.savefig(
buf,
format='png',
dpi=100,
bbox_inches='tight',
pad_inches=0,
facecolor='#0d1117',
)
fig.savefig(buf, format='png', dpi=100, bbox_inches='tight',
pad_inches=0, facecolor='#0d1117')
plt.close(fig)
buf.seek(0)
pil_img = Image.open(buf).convert('RGB')
pil_img = pil_img.resize(
(self.image_size, self.image_size), Image.LANCZOS
)
pil_img = pil_img.resize((self.image_size, self.image_size), Image.LANCZOS)
# Conversion en array numpy (H, W, 3) → (3, H, W), float32, [0,1]
arr = np.array(pil_img, dtype=np.float32) / 255.0
arr = arr.transpose(2, 0, 1) # HWC → CHW
return arr
def _render_fallback(self, df_window: pd.DataFrame) -> np.ndarray:
"""
Rendu de secours sans mplfinance : encode les données OHLCV
directement en image 2D normalisée.
Chaque colonne de l'image correspond à une bougie :
- Canal 0 (R) : position relative du close dans le range [low, high]
- Canal 1 (G) : amplitude de la bougie (high - low normalisé)
- Canal 2 (B) : volume normalisé
L'image est ensuite redimensionnée à image_size × image_size.
"""
cols = df_window[['open', 'high', 'low', 'close', 'volume']].copy()
n = len(cols)
if n == 0:
return np.zeros((3, self.image_size, self.image_size), dtype=np.float32)
# Normalisation par colonne
highs = cols['high'].values.astype(np.float32)
lows = cols['low'].values.astype(np.float32)
closes = cols['close'].values.astype(np.float32)
opens = cols['open'].values.astype(np.float32)
vols = cols['volume'].values.astype(np.float32)
price_range = highs - lows
price_range = np.where(price_range == 0, 1e-8, price_range)
# Canal R : position du close dans le range de la bougie [0, 1]
close_pos = (closes - lows) / price_range
# Canal G : corps de la bougie (|close - open| / range)
body = np.abs(closes - opens) / price_range
# Canal B : volume normalisé [0, 1]
vol_max = vols.max()
vol_norm = vols / (vol_max if vol_max > 0 else 1.0)
# Construction image : 3 × 1 × n → 3 × image_size × image_size
# On crée une image de height=image_size, width=n puis on redimensionne
img = np.zeros((3, 1, n), dtype=np.float32)
img[0, 0, :] = close_pos
img[1, 0, :] = body
img[2, 0, :] = vol_norm
# Redimensionnement vers (3, image_size, image_size) via répétition
# On étire chaque canal sur les deux dimensions
img_resized = np.zeros(
(3, self.image_size, self.image_size), dtype=np.float32
)
for c in range(3):
# Répétition en hauteur (axe 0 du canal) et interpolation en largeur
channel = img[c, 0, :] # (n,)
# Interpolation 1D vers image_size
x_orig = np.linspace(0, 1, n)
x_new = np.linspace(0, 1, self.image_size)
channel_resized = np.interp(x_new, x_orig, channel).astype(np.float32)
# Étirer sur toute la hauteur
img_resized[c] = np.tile(channel_resized, (self.image_size, 1))
return img_resized
return arr.transpose(2, 0, 1) # HWC → CHW
# -------------------------------------------------------------------------
# Utilitaires
@@ -290,27 +450,18 @@ class CandlestickImageRenderer:
def _prepare_df(df: pd.DataFrame) -> pd.DataFrame:
"""
Normalise le DataFrame : colonnes en minuscules, index DatetimeIndex.
Args:
df: DataFrame OHLCV brut
Returns:
DataFrame nettoyé avec colonnes open/high/low/close/volume
"""
df = df.copy()
df.columns = [c.lower() for c in df.columns]
# S'assurer que l'index est un DatetimeIndex (requis par mplfinance)
if not isinstance(df.index, pd.DatetimeIndex):
try:
df.index = pd.to_datetime(df.index)
except Exception:
# Créer un index artificiel si la conversion échoue
df.index = pd.date_range(
start='2020-01-01', periods=len(df), freq='1h'
)
# Conserver uniquement les colonnes OHLCV
required = ['open', 'high', 'low', 'close', 'volume']
for col in required:
if col not in df.columns:
@@ -318,5 +469,4 @@ class CandlestickImageRenderer:
df = df[required].dropna(subset=['open', 'high', 'low', 'close'])
df = df.ffill().bfill()
return df

View File

@@ -480,6 +480,9 @@ class CNNImageStrategyModel:
y: Labels encodés (N,) int, valeurs {0, 1, 2}
class_weights: Poids de classes (3,) pour CrossEntropyLoss
"""
# Limiter les threads CPU pour ne pas saturer l'event loop FastAPI
torch.set_num_threads(4)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
class_weights = class_weights.to(device)

View File

@@ -11,10 +11,12 @@ les différents régimes de marché:
Permet d'adapter les stratégies selon le régime actuel.
"""
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import pandas as pd
import numpy as np
from datetime import datetime
from datetime import datetime, timedelta
import logging
try:
@@ -24,8 +26,18 @@ except ImportError:
HMMLEARN_AVAILABLE = False
logging.warning("hmmlearn not installed. Install with: pip install hmmlearn")
try:
import joblib
JOBLIB_AVAILABLE = True
except ImportError:
JOBLIB_AVAILABLE = False
logging.warning("joblib not installed. La persistance HMM sera désactivée.")
logger = logging.getLogger(__name__)
# Répertoire de persistance des modèles HMM
MODELS_DIR = Path(__file__).parent.parent.parent / "models" / "hmm"
class RegimeDetector:
"""
@@ -79,6 +91,9 @@ class RegimeDetector:
self.is_fitted = False
self.feature_names = []
# Métadonnées d'entraînement pour la persistance
self._trained_at: Optional[datetime] = None
self._n_samples: int = 0
logger.info(f"RegimeDetector initialized with {n_regimes} regimes")
@@ -115,11 +130,147 @@ class RegimeDetector:
try:
self.model.fit(X)
self.is_fitted = True
logger.info("✅ HMM model fitted successfully")
self._trained_at = datetime.now()
self._n_samples = len(X)
logger.info("Modèle HMM entraîné avec succès")
except Exception as e:
logger.error(f"Error fitting HMM: {e}")
logger.error(f"Erreur lors de l'entraînement HMM : {e}")
raise
@property
def is_trained(self) -> bool:
"""True si le modèle HMM a été ajusté (fit)."""
return self.is_fitted
def needs_retrain(self, max_age_hours: int = 24) -> bool:
"""
Indique si le modèle doit être ré-entraîné.
Un ré-entraînement est nécessaire si :
- Le modèle n'a jamais été entraîné
- La date d'entraînement est inconnue
- Le modèle est plus vieux que max_age_hours
Args:
max_age_hours: Âge maximum du modèle en heures (défaut 24h)
Returns:
True si un ré-entraînement est nécessaire
"""
if not self.is_fitted or self._trained_at is None:
return True
age = datetime.now() - self._trained_at
return age > timedelta(hours=max_age_hours)
def save(self, symbol: str, timeframe: str) -> bool:
"""
Sauvegarde le modèle HMM entraîné sur disque avec joblib.
Le modèle est sauvegardé dans :
models/hmm/{symbol}_{timeframe}.joblib
Les métadonnées (date, n_samples, n_components, labels) sont
stockées dans un fichier JSON compagnon :
models/hmm/{symbol}_{timeframe}_meta.json
Args:
symbol: Symbole de l'instrument (ex : "EURUSD")
timeframe: Unité de temps (ex : "1h")
Returns:
True si la sauvegarde a réussi, False sinon
"""
if not self.is_fitted:
logger.warning("Impossible de sauvegarder : modèle non entraîné")
return False
if not JOBLIB_AVAILABLE:
logger.warning("joblib indisponible — sauvegarde HMM ignorée")
return False
try:
# Créer le répertoire si nécessaire
MODELS_DIR.mkdir(parents=True, exist_ok=True)
base = f"{symbol}_{timeframe}"
model_path = MODELS_DIR / f"{base}.joblib"
meta_path = MODELS_DIR / f"{base}_meta.json"
# Sauvegarder le modèle HMM + feature_names
payload = {
"model": self.model,
"feature_names": self.feature_names,
"n_regimes": self.n_regimes,
"random_state": self.random_state,
}
joblib.dump(payload, model_path)
# Sauvegarder les métadonnées en JSON
meta = {
"trained_at": self._trained_at.isoformat() if self._trained_at else None,
"n_samples": self._n_samples,
"n_components": self.n_regimes,
"regime_labels": self.REGIME_NAMES,
"symbol": symbol,
"timeframe": timeframe,
}
meta_path.write_text(json.dumps(meta, indent=2, default=str))
logger.info(f"Modèle HMM sauvegardé : {model_path}")
return True
except Exception as exc:
logger.error(f"Erreur lors de la sauvegarde du modèle HMM : {exc}")
return False
def load(self, symbol: str, timeframe: str) -> bool:
"""
Charge un modèle HMM depuis le disque.
Args:
symbol: Symbole de l'instrument (ex : "EURUSD")
timeframe: Unité de temps (ex : "1h")
Returns:
True si le chargement a réussi, False sinon
"""
if not JOBLIB_AVAILABLE:
logger.warning("joblib indisponible — chargement HMM impossible")
return False
base = f"{symbol}_{timeframe}"
model_path = MODELS_DIR / f"{base}.joblib"
meta_path = MODELS_DIR / f"{base}_meta.json"
if not model_path.exists():
logger.debug(f"Aucun modèle HMM trouvé pour {symbol}/{timeframe}")
return False
try:
payload = joblib.load(model_path)
self.model = payload["model"]
self.feature_names = payload["feature_names"]
self.n_regimes = payload["n_regimes"]
self.random_state = payload["random_state"]
self.is_fitted = True
# Charger les métadonnées si disponibles
if meta_path.exists():
meta = json.loads(meta_path.read_text())
trained_at_raw = meta.get("trained_at")
self._trained_at = (
datetime.fromisoformat(trained_at_raw) if trained_at_raw else None
)
self._n_samples = meta.get("n_samples", 0)
logger.info(f"Modèle HMM chargé depuis {model_path} (entraîné le {self._trained_at})")
return True
except Exception as exc:
logger.error(f"Erreur lors du chargement du modèle HMM : {exc}")
self.is_fitted = False
return False
def predict_regime(self, data: pd.DataFrame) -> np.ndarray:
"""
Prédit les régimes pour toutes les barres.

24
src/ml/rl/__init__.py Normal file
View File

@@ -0,0 +1,24 @@
"""
Module RL (Reinforcement Learning) — Agent PPO pour le trading algorithmique.
Ce module implémente un agent PPO (Proximal Policy Optimization) entraîné
par renforcement sur un environnement de trading simulé.
Composants :
TradingEnv — Environnement gymnasium conforme (observation 20 features)
PPOModel — Réseau Actor-Critic MLP 256→128→64 + entraînement PPO
RLStrategyModel — Interface unifiée (identique à MLStrategyModel)
"""
from src.ml.rl.trading_env import TradingEnv, GYM_AVAILABLE
from src.ml.rl.ppo_model import PPOModel, TORCH_AVAILABLE
from src.ml.rl.rl_strategy_model import RLStrategyModel, RL_AVAILABLE
__all__ = [
'TradingEnv',
'PPOModel',
'RLStrategyModel',
'RL_AVAILABLE',
'TORCH_AVAILABLE',
'GYM_AVAILABLE',
]

681
src/ml/rl/ppo_model.py Normal file
View File

@@ -0,0 +1,681 @@
"""
PPOModel — Implémentation manuelle de l'algorithme PPO (Proximal Policy Optimization).
Architecture Actor-Critic :
- Réseau partagé : MLP 3 couches (256 → 128 → 64) avec BatchNorm + ReLU
- Tête Actor : couche linéaire → logits (n_actions=3)
- Tête Critic : couche linéaire → valeur scalaire V(s)
Hyperparamètres PPO :
- clip ε=0.2 — clip ratio de probabilité pour éviter les grandes mises à jour
- entropy_coef=0.01 — bonus d'entropie pour exploration
- value_coef=0.5 — coefficient de la loss valeur
- n_steps=2048 — nombre de transitions collectées avant mise à jour
- n_epochs=10 — passes sur les données collectées
- batch_size=64 — taille des mini-batchs
- gamma=0.99 — facteur d'actualisation
- gae_lambda=0.95 — λ pour Generalized Advantage Estimation
Sauvegarde :
- models/rl_strategy/{symbol}_{timeframe}.pt (state_dict + config)
- models/rl_strategy/{symbol}_{timeframe}_meta.json
"""
import json
import logging
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
logger = logging.getLogger(__name__)
# Répertoire de sauvegarde des modèles RL
MODELS_DIR = Path(__file__).parent.parent.parent.parent / "models" / "rl_strategy"
# ──────────────────────────────────────────────────────────────────────────────
# Import conditionnel PyTorch
# ──────────────────────────────────────────────────────────────────────────────
try:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
TORCH_AVAILABLE = True
except ImportError:
torch = None
nn = None
optim = None
Categorical = None
TORCH_AVAILABLE = False
# ──────────────────────────────────────────────────────────────────────────────
# Réseau Actor-Critic
# ──────────────────────────────────────────────────────────────────────────────
if TORCH_AVAILABLE:
class ActorCriticNetwork(nn.Module):
"""
Réseau Actor-Critic partagé pour PPO.
Architecture :
- Tronc commun : Linear(obs_dim → 256) → BN → ReLU
Linear(256 → 128) → BN → ReLU
Linear(128 → 64) → ReLU
- Tête Actor : Linear(64 → n_actions)
- Tête Critic : Linear(64 → 1)
Args:
obs_dim: Dimension de l'espace d'observation
n_actions: Nombre d'actions discrètes
"""
def __init__(self, obs_dim: int = 20, n_actions: int = 3):
super().__init__()
# Tronc commun
self.trunk = nn.Sequential(
nn.Linear(obs_dim, 256),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Linear(256, 128),
nn.BatchNorm1d(128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
)
# Tête Actor : logits pour chaque action
self.actor_head = nn.Linear(64, n_actions)
# Tête Critic : estimation de la valeur d'état V(s)
self.critic_head = nn.Linear(64, 1)
# Initialisation des poids (orthogonale pour stabilité PPO)
self._init_weights()
def _init_weights(self):
"""Initialisation orthogonale recommandée pour les réseaux PPO."""
for module in self.modules():
if isinstance(module, nn.Linear):
nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
nn.init.constant_(module.bias, 0.0)
# Gain plus faible pour les têtes finales
nn.init.orthogonal_(self.actor_head.weight, gain=0.01)
nn.init.orthogonal_(self.critic_head.weight, gain=1.0)
def forward(
self, x: 'torch.Tensor'
) -> Tuple['torch.Tensor', 'torch.Tensor']:
"""
Calcule les logits actor et la valeur critic.
Args:
x: Tensor d'observations (batch_size, obs_dim)
Returns:
Tuple (logits_actor, value_critic)
logits_actor : (batch_size, n_actions)
value_critic : (batch_size, 1)
"""
features = self.trunk(x)
logits = self.actor_head(features)
value = self.critic_head(features)
return logits, value
def get_action_and_value(
self, x: 'torch.Tensor', action: Optional['torch.Tensor'] = None
) -> Tuple['torch.Tensor', 'torch.Tensor', 'torch.Tensor', 'torch.Tensor']:
"""
Retourne l'action, son log-prob, l'entropie et la valeur.
Args:
x: Observations (batch_size, obs_dim)
action: Si fourni, calcule le log-prob de cet action existante
(pour la phase de mise à jour PPO)
Returns:
Tuple (action, log_prob, entropy, value)
"""
logits, value = self.forward(x)
dist = Categorical(logits=logits)
if action is None:
action = dist.sample()
log_prob = dist.log_prob(action)
entropy = dist.entropy()
return action, log_prob, entropy, value.squeeze(-1)
def get_value(self, x: 'torch.Tensor') -> 'torch.Tensor':
"""Retourne uniquement la valeur critic V(s)."""
_, value = self.forward(x)
return value.squeeze(-1)
class PPOModel:
"""
Agent PPO (Proximal Policy Optimization) pour le trading.
Implémentation manuelle sans stable-baselines3, utilisant PyTorch pur.
Suit le pseudo-code de Schulman et al. (2017) avec GAE.
Méthodes publiques :
train(env, total_timesteps) — Lance l'entraînement
predict(obs) — Retourne l'action et le log-prob
save(symbol, timeframe) — Sauvegarde le modèle
load(symbol, timeframe) — Charge un modèle existant (classmethod)
Args:
obs_dim: Dimension de l'observation (défaut: 20)
n_actions: Nombre d'actions discrètes (défaut: 3)
lr: Taux d'apprentissage Adam (défaut: 3e-4)
n_steps: Transitions par rollout avant mise à jour (défaut: 2048)
n_epochs: Passes d'optimisation par rollout (défaut: 10)
batch_size: Taille des mini-batchs (défaut: 64)
gamma: Facteur d'actualisation (défaut: 0.99)
gae_lambda: Lambda GAE (défaut: 0.95)
clip_eps: Epsilon de clip PPO (défaut: 0.2)
entropy_coef: Coefficient bonus entropie (défaut: 0.01)
value_coef: Coefficient loss valeur (défaut: 0.5)
max_grad_norm: Norme maximale du gradient (défaut: 0.5)
"""
MODELS_DIR = MODELS_DIR
def __init__(
self,
obs_dim: int = 20,
n_actions: int = 3,
lr: float = 3e-4,
n_steps: int = 2048,
n_epochs: int = 10,
batch_size: int = 64,
gamma: float = 0.99,
gae_lambda: float = 0.95,
clip_eps: float = 0.2,
entropy_coef: float = 0.01,
value_coef: float = 0.5,
max_grad_norm: float = 0.5,
):
self.obs_dim = obs_dim
self.n_actions = n_actions
self.lr = lr
self.n_steps = n_steps
self.n_epochs = n_epochs
self.batch_size = batch_size
self.gamma = gamma
self.gae_lambda = gae_lambda
self.clip_eps = clip_eps
self.entropy_coef = entropy_coef
self.value_coef = value_coef
self.max_grad_norm = max_grad_norm
self.network: Optional['ActorCriticNetwork'] = None
self.optimizer: Optional['optim.Adam'] = None
self.is_trained: bool = False
self.metadata: Dict = {}
self.MODELS_DIR.mkdir(parents=True, exist_ok=True)
if TORCH_AVAILABLE:
self._init_network()
def _init_network(self) -> None:
"""Initialise le réseau Actor-Critic et l'optimiseur Adam."""
self.network = ActorCriticNetwork(self.obs_dim, self.n_actions)
self.optimizer = torch.optim.Adam(self.network.parameters(), lr=self.lr, eps=1e-5)
# ──────────────────────────────────────────────────────────────────────────
# Entraînement
# ──────────────────────────────────────────────────────────────────────────
def train(
self,
env,
total_timesteps: int = 100_000,
symbol: str = 'EURUSD',
timeframe: str = '1h',
) -> Dict:
"""
Entraîne l'agent PPO sur l'environnement de trading fourni.
Algorithme :
Boucle principale :
1. Collecte de n_steps transitions (rollout)
2. Calcul des avantages GAE
3. n_epochs passes d'optimisation sur mini-batchs mélangés
Fin :
Sauvegarde du modèle et retour des métriques
Args:
env: Environnement TradingEnv
total_timesteps: Nombre total de transitions à collecter
symbol: Symbole pour la sauvegarde (ex: 'EURUSD')
timeframe: Timeframe pour la sauvegarde (ex: '1h')
Returns:
Dict avec métriques : total_timesteps, n_updates, mean_reward,
mean_ep_return, policy_loss, value_loss, entropy_loss
"""
if not TORCH_AVAILABLE:
return {'error': 'PyTorch non disponible — installer torch>=2.0.0'}
# Limiter les threads CPU pour ne pas saturer l'event loop FastAPI
torch.set_num_threads(4)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.network.to(device)
logger.info(
f"Début entraînement PPO {symbol}/{timeframe}"
f"total_timesteps={total_timesteps}, n_steps={self.n_steps}, device={device}"
)
# ── Buffers de rollout ──────────────────────────────────────────────
obs_buf = np.zeros((self.n_steps, self.obs_dim), dtype=np.float32)
actions_buf = np.zeros((self.n_steps,), dtype=np.int64)
rewards_buf = np.zeros((self.n_steps,), dtype=np.float32)
dones_buf = np.zeros((self.n_steps,), dtype=np.float32)
values_buf = np.zeros((self.n_steps,), dtype=np.float32)
logprobs_buf = np.zeros((self.n_steps,), dtype=np.float32)
# ── Statistiques d'entraînement ─────────────────────────────────────
all_episode_returns = []
all_policy_losses = []
all_value_losses = []
all_entropy_losses = []
ep_return = 0.0
n_updates = 0
obs, _ = env.reset()
done = False
timestep = 0
rollout_idx = 0
while timestep < total_timesteps:
# ── Phase de collecte (rollout) ───────────────────────────────────
self.network.eval()
with torch.no_grad():
obs_t = torch.from_numpy(obs).float().unsqueeze(0).to(device)
action, log_prob, _, value = self.network.get_action_and_value(obs_t)
action_np = int(action.cpu().item())
log_prob_np = float(log_prob.cpu().item())
value_np = float(value.cpu().item())
next_obs, reward, terminated, truncated, _ = env.step(action_np)
done = terminated or truncated
obs_buf[rollout_idx] = obs
actions_buf[rollout_idx] = action_np
rewards_buf[rollout_idx] = reward
dones_buf[rollout_idx] = float(done)
values_buf[rollout_idx] = value_np
logprobs_buf[rollout_idx] = log_prob_np
ep_return += reward
obs = next_obs
timestep += 1
rollout_idx += 1
if done:
all_episode_returns.append(ep_return)
ep_return = 0.0
obs, _ = env.reset()
# ── Phase de mise à jour PPO ──────────────────────────────────────
if rollout_idx == self.n_steps:
# Calcul de la valeur du dernier état (bootstrap)
with torch.no_grad():
last_obs_t = torch.from_numpy(obs).float().unsqueeze(0).to(device)
last_value = self.network.get_value(last_obs_t).cpu().numpy()
# Calcul des avantages GAE et des retours
advantages, returns = self._compute_gae(
rewards_buf, values_buf, dones_buf, float(last_value)
)
# Conversion en tensors
obs_t = torch.from_numpy(obs_buf).float().to(device)
actions_t = torch.from_numpy(actions_buf).long().to(device)
logprobs_t = torch.from_numpy(logprobs_buf).float().to(device)
advs_t = torch.from_numpy(advantages).float().to(device)
returns_t = torch.from_numpy(returns).float().to(device)
# Normalisation des avantages
advs_t = (advs_t - advs_t.mean()) / (advs_t.std() + 1e-8)
# n_epochs passes d'optimisation
n_samples = self.n_steps
policy_losses_ep = []
value_losses_ep = []
entropy_losses_ep = []
self.network.train()
for _ in range(self.n_epochs):
indices = np.random.permutation(n_samples)
for start in range(0, n_samples, self.batch_size):
end = start + self.batch_size
batch = indices[start:end]
if len(batch) < 2:
continue
# Forward pass
_, new_logprob, entropy, new_value = (
self.network.get_action_and_value(
obs_t[batch], actions_t[batch]
)
)
# Ratio de probabilité
log_ratio = new_logprob - logprobs_t[batch]
ratio = torch.exp(log_ratio)
# Clip PPO
adv_batch = advs_t[batch]
pg_loss1 = -adv_batch * ratio
pg_loss2 = -adv_batch * torch.clamp(
ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps
)
policy_loss = torch.max(pg_loss1, pg_loss2).mean()
# Loss valeur (avec clip optionnel)
value_loss = nn.functional.mse_loss(new_value, returns_t[batch])
# Bonus d'entropie (encourage l'exploration)
entropy_loss = -entropy.mean()
# Loss totale
total_loss = (
policy_loss
+ self.value_coef * value_loss
+ self.entropy_coef * entropy_loss
)
# Optimisation
self.optimizer.zero_grad()
total_loss.backward()
nn.utils.clip_grad_norm_(
self.network.parameters(), self.max_grad_norm
)
self.optimizer.step()
policy_losses_ep.append(float(policy_loss.item()))
value_losses_ep.append(float(value_loss.item()))
entropy_losses_ep.append(float(entropy_loss.item()))
all_policy_losses.extend(policy_losses_ep)
all_value_losses.extend(value_losses_ep)
all_entropy_losses.extend(entropy_losses_ep)
n_updates += 1
rollout_idx = 0
if n_updates % 10 == 0:
mean_ret = float(np.mean(all_episode_returns[-20:])) if all_episode_returns else 0.0
logger.info(
f" Timestep {timestep}/{total_timesteps} | "
f"Updates={n_updates} | MeanReturn(20ep)={mean_ret:.4f}"
)
self.network.eval()
self.network.to('cpu')
self.is_trained = True
metrics = {
'symbol': symbol,
'timeframe': timeframe,
'total_timesteps': total_timesteps,
'n_updates': n_updates,
'mean_reward': float(np.mean(all_episode_returns)) if all_episode_returns else 0.0,
'mean_ep_return': float(np.mean(all_episode_returns[-20:])) if all_episode_returns else 0.0,
'n_episodes': len(all_episode_returns),
'policy_loss': float(np.mean(all_policy_losses[-100:])) if all_policy_losses else 0.0,
'value_loss': float(np.mean(all_value_losses[-100:])) if all_value_losses else 0.0,
'entropy_loss': float(np.mean(all_entropy_losses[-100:])) if all_entropy_losses else 0.0,
'trained_at': datetime.utcnow().isoformat(),
'hyperparams': {
'lr': self.lr,
'n_steps': self.n_steps,
'n_epochs': self.n_epochs,
'batch_size': self.batch_size,
'gamma': self.gamma,
'gae_lambda': self.gae_lambda,
'clip_eps': self.clip_eps,
'entropy_coef': self.entropy_coef,
'value_coef': self.value_coef,
},
}
self.metadata = metrics
logger.info(
f"Entraînement PPO terminé. "
f"MeanReturn={metrics['mean_ep_return']:.4f} | "
f"Episodes={metrics['n_episodes']}"
)
return metrics
def predict(self, obs: np.ndarray) -> Tuple[int, float]:
"""
Prédit l'action optimale pour une observation donnée.
En mode inférence (deterministic=True) : argmax des logits.
En mode exploration : sampling de la distribution.
Args:
obs: Vecteur d'observation (obs_dim,) en float32
Returns:
Tuple (action, log_prob) où action ∈ {0, 1, 2}
"""
if not TORCH_AVAILABLE or not self.is_trained or self.network is None:
return 0, 0.0
self.network.eval()
with torch.no_grad():
obs_t = torch.from_numpy(obs.astype(np.float32)).unsqueeze(0)
action, log_prob, _, _ = self.network.get_action_and_value(obs_t)
return int(action.item()), float(log_prob.item())
def predict_deterministic(self, obs: np.ndarray) -> int:
"""
Prédit l'action déterministe (argmax) sans exploration.
Args:
obs: Vecteur d'observation (obs_dim,) en float32
Returns:
Action déterministe (argmax des logits) ∈ {0, 1, 2}
"""
if not TORCH_AVAILABLE or not self.is_trained or self.network is None:
return 0
self.network.eval()
with torch.no_grad():
obs_t = torch.from_numpy(obs.astype(np.float32)).unsqueeze(0)
logits, _ = self.network(obs_t)
action = int(logits.argmax(dim=-1).item())
return action
def get_action_probas(self, obs: np.ndarray) -> np.ndarray:
"""
Retourne les probabilités de chaque action via softmax des logits.
Args:
obs: Vecteur d'observation (obs_dim,) en float32
Returns:
np.ndarray de forme (3,) : [P(HOLD), P(LONG), P(SHORT)]
"""
if not TORCH_AVAILABLE or not self.is_trained or self.network is None:
return np.array([1.0, 0.0, 0.0], dtype=np.float32)
self.network.eval()
with torch.no_grad():
obs_t = torch.from_numpy(obs.astype(np.float32)).unsqueeze(0)
logits, _ = self.network(obs_t)
probas = torch.softmax(logits, dim=-1).squeeze(0).numpy()
return probas.astype(np.float32)
# ──────────────────────────────────────────────────────────────────────────
# Sauvegarde / Chargement
# ──────────────────────────────────────────────────────────────────────────
def save(self, symbol: str = 'EURUSD', timeframe: str = '1h') -> Path:
"""
Sauvegarde le modèle et ses métadonnées sur disque.
Format :
{symbol}_{timeframe}.pt — state_dict + config PyTorch
{symbol}_{timeframe}_meta.json — métadonnées JSON
Args:
symbol: Paire tradée (ex: 'EURUSD')
timeframe: Timeframe (ex: '1h')
Returns:
Path vers le fichier .pt sauvegardé
Raises:
RuntimeError si PyTorch non disponible ou modèle non entraîné
"""
if not TORCH_AVAILABLE:
raise RuntimeError("PyTorch non disponible")
if not self.is_trained or self.network is None:
raise RuntimeError("Modèle non entraîné — appeler train() avant save()")
model_id = f"{symbol}_{timeframe}"
model_path = self.MODELS_DIR / f"{model_id}.pt"
meta_path = self.MODELS_DIR / f"{model_id}_meta.json"
torch.save(
{
'state_dict': self.network.state_dict(),
'config': {
'obs_dim': self.obs_dim,
'n_actions': self.n_actions,
'lr': self.lr,
'n_steps': self.n_steps,
'n_epochs': self.n_epochs,
'batch_size': self.batch_size,
'gamma': self.gamma,
'gae_lambda': self.gae_lambda,
'clip_eps': self.clip_eps,
'entropy_coef': self.entropy_coef,
'value_coef': self.value_coef,
'max_grad_norm': self.max_grad_norm,
},
'metadata': self.metadata,
},
model_path
)
with open(meta_path, 'w') as f:
json.dump(self.metadata, f, indent=2, default=str)
logger.info(f"Modèle PPO sauvegardé : {model_path}")
return model_path
@classmethod
def load(cls, symbol: str, timeframe: str) -> 'PPOModel':
"""
Charge un modèle PPO existant depuis le disque.
Args:
symbol: Paire (ex: 'EURUSD')
timeframe: Timeframe (ex: '1h')
Returns:
Instance PPOModel prête à prédire
Raises:
RuntimeError si PyTorch non disponible
FileNotFoundError si le modèle n'existe pas
"""
if not TORCH_AVAILABLE:
raise RuntimeError("PyTorch non disponible")
model_path = MODELS_DIR / f"{symbol}_{timeframe}.pt"
if not model_path.exists():
raise FileNotFoundError(f"Modèle PPO non trouvé : {model_path}")
checkpoint = torch.load(model_path, map_location='cpu')
cfg = checkpoint.get('config', {})
instance = cls(
obs_dim = cfg.get('obs_dim', 20),
n_actions = cfg.get('n_actions', 3),
lr = cfg.get('lr', 3e-4),
n_steps = cfg.get('n_steps', 2048),
n_epochs = cfg.get('n_epochs', 10),
batch_size = cfg.get('batch_size', 64),
gamma = cfg.get('gamma', 0.99),
gae_lambda = cfg.get('gae_lambda', 0.95),
clip_eps = cfg.get('clip_eps', 0.2),
entropy_coef = cfg.get('entropy_coef', 0.01),
value_coef = cfg.get('value_coef', 0.5),
max_grad_norm = cfg.get('max_grad_norm', 0.5),
)
instance.network.load_state_dict(checkpoint['state_dict'])
instance.network.eval()
instance.is_trained = True
instance.metadata = checkpoint.get('metadata', {})
logger.info(f"Modèle PPO chargé depuis {model_path}")
return instance
# ──────────────────────────────────────────────────────────────────────────
# Calcul des avantages (GAE)
# ──────────────────────────────────────────────────────────────────────────
def _compute_gae(
self,
rewards: np.ndarray,
values: np.ndarray,
dones: np.ndarray,
last_value: float,
) -> Tuple[np.ndarray, np.ndarray]:
"""
Calcule les avantages généralisés (GAE) et les retours cibles.
Formule GAE :
δt = rt + γ·V(st+1)·(1-dt) - V(st)
Ât = δt + γλ·Ât+1·(1-dt)
Rt = Ât + V(st)
Args:
rewards: Récompenses du rollout (n_steps,)
values: Valeurs estimées par le critic (n_steps,)
dones: Indicateurs de fin d'épisode (n_steps,)
last_value: Valeur bootstrap du dernier état (scalaire)
Returns:
Tuple (advantages, returns) de forme (n_steps,)
"""
n = len(rewards)
advantages = np.zeros(n, dtype=np.float32)
last_gae = 0.0
for t in reversed(range(n)):
if t == n - 1:
next_non_terminal = 1.0 - dones[t]
next_value = last_value
else:
next_non_terminal = 1.0 - dones[t]
next_value = values[t + 1]
delta = rewards[t] + self.gamma * next_value * next_non_terminal - values[t]
last_gae = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae
advantages[t] = last_gae
returns = advantages + values
return advantages, returns

View File

@@ -0,0 +1,557 @@
"""
RLStrategyModel — Interface unifiée pour l'agent PPO de trading.
Interface identique à MLStrategyModel et CNNImageStrategyModel :
model = RLStrategyModel(symbol='EURUSD', timeframe='1h')
result = model.train(df_ohlcv)
signal = model.predict(df) # {signal, confidence, probas, tradeable}
model.save()
model.load(symbol, timeframe)
model.list_trained_models()
model.get_feature_importance() # retourne [] (RL = boîte noire)
Pipeline d'entraînement :
1. Validation des données OHLCV (minimum 500 barres recommandées)
2. Création de TradingEnv sur les données d'entraînement
3. Entraînement PPO (total_timesteps configurable)
4. Évaluation sur holdout (20% temporel final)
5. Sauvegarde du modèle PPO + métadonnées JSON
"""
import json
import logging
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
import pandas as pd
from src.ml.rl.trading_env import TradingEnv, GYM_AVAILABLE, N_FEATURES
from src.ml.rl.ppo_model import PPOModel, TORCH_AVAILABLE, MODELS_DIR
logger = logging.getLogger(__name__)
# Décodage des actions RL en signaux de trading
# Action 0 (HOLD) → signal 0 (NEUTRAL)
# Action 1 (LONG) → signal 1 (LONG)
# Action 2 (SHORT) → signal -1 (SHORT)
ACTION_TO_SIGNAL = {0: 0, 1: 1, 2: -1}
# Disponibilité globale du module RL
RL_AVAILABLE = TORCH_AVAILABLE and GYM_AVAILABLE
class RLStrategyModel:
"""
Modèle de trading basé sur un agent PPO (Reinforcement Learning).
L'agent apprend directement à maximiser le PnL via interaction avec
un environnement de trading simulé (TradingEnv), sans supervision
explicite de labels.
Avantages par rapport aux modèles supervisés :
- Pas besoin de labels manuels (LONG/SHORT/NEUTRAL)
- Optimise directement l'objectif de trading (PnL)
- Apprend à gérer les positions (durée, timing de sortie)
Args:
symbol: Paire tradée (ex: 'EURUSD')
timeframe: Timeframe (ex: '1h', '15m')
total_timesteps: Nombre de timesteps d'entraînement PPO (défaut: 100_000)
sl_atr_mult: Multiplicateur ATR pour le stop-loss (défaut: 1.0)
tp_atr_mult: Multiplicateur ATR pour le take-profit (défaut: 2.0)
min_confidence: Seuil de confiance minimum pour signaux (défaut: 0.50)
train_ratio: Fraction des données pour l'entraînement (défaut: 0.80)
initial_capital: Capital initial pour la simulation (défaut: 10_000)
lr: Taux d'apprentissage Adam (défaut: 3e-4)
n_steps: Transitions par rollout (défaut: 2048)
n_epochs: Passes d'optimisation par rollout (défaut: 10)
batch_size: Taille des mini-batchs (défaut: 64)
"""
MODELS_DIR = MODELS_DIR
def __init__(
self,
symbol: str = 'EURUSD',
timeframe: str = '1h',
total_timesteps: int = 100_000,
sl_atr_mult: float = 1.0,
tp_atr_mult: float = 2.0,
min_confidence: float = 0.50,
train_ratio: float = 0.80,
initial_capital: float = 10_000.0,
lr: float = 3e-4,
n_steps: int = 2048,
n_epochs: int = 10,
batch_size: int = 64,
):
self.symbol = symbol
self.timeframe = timeframe
self.total_timesteps = total_timesteps
self.sl_atr_mult = sl_atr_mult
self.tp_atr_mult = tp_atr_mult
self.min_confidence = min_confidence
self.train_ratio = train_ratio
self.initial_capital = initial_capital
self.lr = lr
self.n_steps = n_steps
self.n_epochs = n_epochs
self.batch_size = batch_size
self.ppo_model: Optional[PPOModel] = None
self.is_trained = False
self.metadata: Dict = {}
self.MODELS_DIR.mkdir(parents=True, exist_ok=True)
# ──────────────────────────────────────────────────────────────────────────
# Entraînement
# ──────────────────────────────────────────────────────────────────────────
def train(self, data: pd.DataFrame) -> Dict:
"""
Entraîne l'agent PPO sur les données OHLCV fournies.
Étapes :
1. Validation des données (minimum 200 barres)
2. Split temporel train/holdout (80/20 par défaut)
3. Création du TradingEnv sur le jeu d'entraînement
4. Entraînement PPO (total_timesteps)
5. Évaluation sur le jeu de holdout
6. Sauvegarde automatique du modèle
Args:
data: DataFrame OHLCV (colonnes : open/high/low/close/volume)
Minimum 200 barres, idéalement 2000+ pour un bon apprentissage.
Returns:
Dict avec métriques :
- total_timesteps, n_episodes, mean_ep_return (entraînement)
- eval_total_pnl, eval_return_pct, eval_sharpe (évaluation holdout)
- n_samples, trained_at, error (si échec)
"""
if not TORCH_AVAILABLE:
return {'error': 'PyTorch non disponible — installer torch>=2.0.0'}
if not GYM_AVAILABLE:
return {'error': 'gymnasium non disponible — installer gymnasium>=0.26'}
logger.info(
f"Début entraînement RLStrategyModel pour "
f"{self.symbol}/{self.timeframe} — total_timesteps={self.total_timesteps}"
)
# 1. Validation et normalisation des données
df = data.copy()
df.columns = [c.lower() for c in df.columns]
required_cols = {'open', 'high', 'low', 'close', 'volume'}
missing = required_cols - set(df.columns)
if missing:
return {'error': f'Colonnes manquantes : {missing}'}
df = df.dropna(subset=list(required_cols))
if len(df) < 200:
return {
'error': f'Données insuffisantes : {len(df)} barres (minimum 200)'
}
# 2. Split temporel train / holdout (respecte l'ordre chronologique)
n_train = int(len(df) * self.train_ratio)
df_train = df.iloc[:n_train].reset_index(drop=True)
df_holdout = df.iloc[n_train:].reset_index(drop=True)
logger.info(
f" Split : train={len(df_train)} barres, holdout={len(df_holdout)} barres"
)
if len(df_train) < 100:
return {'error': f'Données d\'entraînement insuffisantes : {len(df_train)} barres'}
# 3. Création de l'environnement d'entraînement
train_env = TradingEnv(
df = df_train,
sl_atr_mult = self.sl_atr_mult,
tp_atr_mult = self.tp_atr_mult,
initial_capital = self.initial_capital,
)
# 4. Initialisation et entraînement PPO
self.ppo_model = PPOModel(
obs_dim = N_FEATURES,
n_actions = 3,
lr = self.lr,
n_steps = min(self.n_steps, len(df_train) - 30), # Sécurité
n_epochs = self.n_epochs,
batch_size = self.batch_size,
)
train_metrics = self.ppo_model.train(
env = train_env,
total_timesteps = self.total_timesteps,
symbol = self.symbol,
timeframe = self.timeframe,
)
if 'error' in train_metrics:
return train_metrics
self.is_trained = True
# 5. Évaluation sur holdout
eval_metrics = {}
if len(df_holdout) >= 50:
eval_metrics = self._evaluate_on_holdout(df_holdout)
logger.info(
f" Évaluation holdout : "
f"PnL={eval_metrics.get('eval_total_pnl', 0):.4f}, "
f"Return={eval_metrics.get('eval_return_pct', 0):.2%}, "
f"Sharpe≈{eval_metrics.get('eval_sharpe', 0):.2f}"
)
# 6. Assemblage des métadonnées
self.metadata = {
'symbol': self.symbol,
'timeframe': self.timeframe,
'trained_at': datetime.utcnow().isoformat(),
'n_samples': len(df),
'n_train': len(df_train),
'n_holdout': len(df_holdout),
'total_timesteps': train_metrics.get('total_timesteps', self.total_timesteps),
'n_episodes': train_metrics.get('n_episodes', 0),
'n_updates': train_metrics.get('n_updates', 0),
'mean_ep_return': train_metrics.get('mean_ep_return', 0.0),
'policy_loss': train_metrics.get('policy_loss', 0.0),
'value_loss': train_metrics.get('value_loss', 0.0),
'entropy_loss': train_metrics.get('entropy_loss', 0.0),
'sl_atr_mult': self.sl_atr_mult,
'tp_atr_mult': self.tp_atr_mult,
'min_confidence': self.min_confidence,
'hyperparams': train_metrics.get('hyperparams', {}),
**eval_metrics,
}
# 7. Sauvegarde automatique
self.save()
logger.info(
f"Entraînement RL terminé pour {self.symbol}/{self.timeframe}"
f"N_episodes={self.metadata['n_episodes']}, "
f"MeanReturn={self.metadata['mean_ep_return']:.4f}"
)
return self.metadata
# ──────────────────────────────────────────────────────────────────────────
# Prédiction
# ──────────────────────────────────────────────────────────────────────────
def predict(self, data: pd.DataFrame) -> Dict:
"""
Prédit le signal de trading pour la dernière barre disponible.
L'agent évalue les 20 dernières barres (fenêtre lookback) et retourne
son action déterministe (HOLD/LONG/SHORT) avec les probabilités associées.
Args:
data: DataFrame OHLCV récent (minimum 25 barres pour les indicateurs)
Returns:
Dict : {
'signal': 1 (LONG) / -1 (SHORT) / 0 (NEUTRAL/HOLD),
'confidence': float [0..1] — max(P_LONG, P_SHORT),
'probas': {'hold': float, 'long': float, 'short': float},
'tradeable': bool — confidence >= min_confidence et signal != 0,
}
"""
if not TORCH_AVAILABLE:
return {
'signal': 0, 'confidence': 0.0, 'tradeable': False,
'error': 'PyTorch non disponible'
}
if not self.is_trained or self.ppo_model is None:
return {
'signal': 0, 'confidence': 0.0, 'tradeable': False,
'error': 'Modèle non entraîné — appeler train() d\'abord'
}
df = data.copy()
df.columns = [c.lower() for c in df.columns]
if len(df) < 25:
return {
'signal': 0, 'confidence': 0.0, 'tradeable': False,
'error': f'Pas assez de données : {len(df)} barres (minimum 25)'
}
try:
# Créer un environnement temporaire pour obtenir l'observation
temp_env = TradingEnv(
df = df,
sl_atr_mult = self.sl_atr_mult,
tp_atr_mult = self.tp_atr_mult,
initial_capital = self.initial_capital,
)
# Positionner l'environnement à la dernière barre
temp_env._current_step = len(df) - 1
obs = temp_env._get_observation()
except Exception as e:
return {
'signal': 0, 'confidence': 0.0, 'tradeable': False,
'error': f'Erreur construction observation : {e}'
}
# Prédiction déterministe (argmax des logits)
action = self.ppo_model.predict_deterministic(obs)
probas = self.ppo_model.get_action_probas(obs) # [P(HOLD), P(LONG), P(SHORT)]
signal = ACTION_TO_SIGNAL.get(action, 0)
confidence = float(max(probas[1], probas[2])) # max(P_LONG, P_SHORT)
return {
'signal': signal,
'confidence': confidence,
'probas': {
'hold': float(probas[0]),
'long': float(probas[1]),
'short': float(probas[2]),
},
'tradeable': confidence >= self.min_confidence and signal != 0,
}
# ──────────────────────────────────────────────────────────────────────────
# Sauvegarde / Chargement
# ──────────────────────────────────────────────────────────────────────────
def save(self) -> Path:
"""
Sauvegarde le modèle PPO et les métadonnées sur disque.
Returns:
Path vers le fichier .pt sauvegardé
Raises:
RuntimeError si le modèle n'est pas entraîné
"""
if not self.is_trained or self.ppo_model is None:
raise RuntimeError("Modèle non entraîné — appeler train() avant save()")
# Sauvegarde du réseau PPO
model_path = self.ppo_model.save(self.symbol, self.timeframe)
# Sauvegarde des métadonnées complètes (incluant les params RLStrategyModel)
meta = self.metadata.copy()
meta.update({
'rl_config': {
'sl_atr_mult': self.sl_atr_mult,
'tp_atr_mult': self.tp_atr_mult,
'min_confidence': self.min_confidence,
'train_ratio': self.train_ratio,
'initial_capital': self.initial_capital,
'total_timesteps': self.total_timesteps,
}
})
meta_path = self.MODELS_DIR / f"{self.symbol}_{self.timeframe}_meta.json"
with open(meta_path, 'w') as f:
json.dump(meta, f, indent=2, default=str)
return model_path
@classmethod
def load(cls, symbol: str, timeframe: str) -> 'RLStrategyModel':
"""
Charge un modèle RL existant depuis le disque.
Args:
symbol: Paire (ex: 'EURUSD')
timeframe: Timeframe (ex: '1h')
Returns:
Instance RLStrategyModel prête à prédire
Raises:
RuntimeError si PyTorch ou gymnasium non disponible
FileNotFoundError si le modèle n'existe pas
"""
if not TORCH_AVAILABLE:
raise RuntimeError("PyTorch non disponible")
meta_path = MODELS_DIR / f"{symbol}_{timeframe}_meta.json"
rl_config = {}
if meta_path.exists():
try:
with open(meta_path) as f:
saved_meta = json.load(f)
rl_config = saved_meta.get('rl_config', {})
except Exception:
pass
instance = cls(
symbol = symbol,
timeframe = timeframe,
sl_atr_mult = rl_config.get('sl_atr_mult', 1.0),
tp_atr_mult = rl_config.get('tp_atr_mult', 2.0),
min_confidence = rl_config.get('min_confidence', 0.50),
train_ratio = rl_config.get('train_ratio', 0.80),
initial_capital = rl_config.get('initial_capital', 10_000.0),
total_timesteps = rl_config.get('total_timesteps', 100_000),
)
# Chargement du modèle PPO
instance.ppo_model = PPOModel.load(symbol, timeframe)
instance.is_trained = True
if meta_path.exists():
try:
with open(meta_path) as f:
instance.metadata = json.load(f)
except Exception:
pass
logger.info(f"Modèle RL chargé depuis {MODELS_DIR}/{symbol}_{timeframe}.pt")
return instance
@staticmethod
def list_trained_models() -> List[Dict]:
"""
Retourne la liste des modèles RL entraînés disponibles.
Returns:
Liste de dicts avec symbol, timeframe, trained_at, n_samples,
mean_ep_return, eval_return_pct
"""
if not MODELS_DIR.exists():
return []
models = []
for f in MODELS_DIR.glob("*_meta.json"):
try:
with open(f) as fp:
meta = json.load(fp)
stem = f.stem.replace('_meta', '') # ex: EURUSD_1h
parts = stem.split('_', 1)
models.append({
'symbol': meta.get('symbol', parts[0] if parts else '?'),
'timeframe': meta.get('timeframe', parts[1] if len(parts) > 1 else '?'),
'trained_at': meta.get('trained_at', '?'),
'n_samples': meta.get('n_samples', 0),
'total_timesteps': meta.get('total_timesteps', 0),
'n_episodes': meta.get('n_episodes', 0),
'mean_ep_return': meta.get('mean_ep_return', 0.0),
'eval_return_pct': meta.get('eval_return_pct', 0.0),
'eval_sharpe': meta.get('eval_sharpe', 0.0),
})
except Exception:
pass
return models
def get_feature_importance(self) -> List[Dict]:
"""
Retourne l'importance des features.
Note : Les agents RL (PPO) sont des boîtes noires — pas d'importance
de feature interprétable. Retourne les noms des 20 features pour
information, sans score numérique.
Returns:
Liste de dicts avec feature_name et description
"""
feature_names = [
{'rank': 1, 'feature': 'body_ratio', 'description': '(close-open)/ATR — corps de la bougie'},
{'rank': 2, 'feature': 'amplitude', 'description': '(high-low)/ATR — amplitude de la bougie'},
{'rank': 3, 'feature': 'momentum_1', 'description': '(close-close[-1])/ATR — momentum 1 barre'},
{'rank': 4, 'feature': 'momentum_5', 'description': '(close-close[-5])/ATR — momentum 5 barres'},
{'rank': 5, 'feature': 'rsi_norm', 'description': 'RSI(14)/100 — RSI normalisé'},
{'rank': 6, 'feature': 'bb_position', 'description': '(close-BB_upper)/ATR — position vs Bollinger'},
{'rank': 7, 'feature': 'bb_width', 'description': '(BB_upper-BB_lower)/ATR — largeur bandes'},
{'rank': 8, 'feature': 'macd_norm', 'description': 'MACD/ATR — MACD normalisé'},
{'rank': 9, 'feature': 'macd_signal_norm', 'description': 'Signal MACD/ATR — signal MACD normalisé'},
{'rank': 10, 'feature': 'vol_ratio', 'description': 'Volume/MA20(Volume) — ratio volume'},
{'rank': 11, 'feature': 'position', 'description': 'Position courante (-1=short, 0=flat, 1=long)'},
{'rank': 12, 'feature': 'unrealized_pnl', 'description': 'PnL_non_réalisé/ATR — gain/perte courant'},
{'rank': 13, 'feature': 'bars_in_trade', 'description': 'Durée du trade normalisée'},
{'rank': 14, 'feature': 'hour', 'description': 'Heure de la journée normalisée'},
{'rank': 15, 'feature': 'day_of_week', 'description': 'Jour de semaine normalisé'},
{'rank': 16, 'feature': 'ema9_dist', 'description': '(close-EMA9)/ATR — distance EMA9'},
{'rank': 17, 'feature': 'ema21_dist', 'description': '(close-EMA21)/ATR — distance EMA21'},
{'rank': 18, 'feature': 'ema9_ema21_cross', 'description': '(EMA9-EMA21)/ATR — croisement EMA court'},
{'rank': 19, 'feature': 'ema50_dist', 'description': '(close-EMA50)/ATR — distance EMA50'},
{'rank': 20, 'feature': 'ema200_dist', 'description': '(close-EMA200)/ATR — distance EMA200'},
]
return feature_names
# ──────────────────────────────────────────────────────────────────────────
# Évaluation interne
# ──────────────────────────────────────────────────────────────────────────
def _evaluate_on_holdout(self, df_holdout: pd.DataFrame) -> Dict:
"""
Évalue l'agent entraîné sur le jeu de holdout.
Lance un épisode complet sur les données de holdout en mode déterministe
(argmax des logits, pas de sampling). Retourne les métriques de performance.
Args:
df_holdout: DataFrame OHLCV du jeu de holdout
Returns:
Dict avec eval_total_pnl, eval_return_pct, eval_sharpe,
eval_n_trades, eval_win_rate
"""
try:
eval_env = TradingEnv(
df = df_holdout,
sl_atr_mult = self.sl_atr_mult,
tp_atr_mult = self.tp_atr_mult,
initial_capital = self.initial_capital,
)
obs, _ = eval_env.reset()
done = False
rewards_hist = []
n_trades = 0
win_trades = 0
while not done:
action = self.ppo_model.predict_deterministic(obs)
obs, reward, terminated, truncated, info = eval_env.step(action)
done = terminated or truncated
if abs(reward) > 0.001: # Fermeture de position
n_trades += 1
if reward > 0:
win_trades += 1
rewards_hist.append(reward)
stats = eval_env.get_episode_stats()
win_rate = win_trades / max(n_trades, 1)
rewards_arr = np.array(rewards_hist)
sharpe = (
float(rewards_arr.mean() / (rewards_arr.std() + 1e-8) * np.sqrt(252))
if len(rewards_arr) > 1 else 0.0
)
return {
'eval_total_pnl': float(stats.get('total_pnl', 0.0)),
'eval_return_pct': float(stats.get('return_pct', 0.0)),
'eval_sharpe': float(sharpe),
'eval_n_trades': n_trades,
'eval_win_rate': float(win_rate),
'eval_n_steps': int(stats.get('n_steps', 0)),
}
except Exception as e:
logger.warning(f"Évaluation holdout échouée : {e}")
return {
'eval_total_pnl': 0.0,
'eval_return_pct': 0.0,
'eval_sharpe': 0.0,
'eval_n_trades': 0,
'eval_win_rate': 0.0,
}

524
src/ml/rl/trading_env.py Normal file
View File

@@ -0,0 +1,524 @@
"""
TradingEnv — Environnement Gymnasium pour l'agent RL de trading.
Conforme à l'API gymnasium (reset/step) et adapté aux marchés forex.
Espace d'observation : 20 features normalisées (fenêtre glissante de 20 barres)
Espace d'action : Discrete(3) → {0=HOLD, 1=LONG, 2=SHORT}
Récompense : PnL réalisé + pénalité drawdown + bonus Sharpe
L'environnement gère :
- Les transitions de position (flat → long, flat → short, inversions)
- Le stop-loss / take-profit automatiques basés sur l'ATR
- La normalisation des observations par l'ATR
- La fenêtre glissante de 20 barres (lookback)
"""
import logging
from typing import Optional, Tuple
import numpy as np
import pandas as pd
logger = logging.getLogger(__name__)
# Imports conditionnels gymnasium / gym
try:
import gymnasium as gym
from gymnasium import spaces
GYM_AVAILABLE = True
GYM_MODULE = 'gymnasium'
except ImportError:
try:
import gym
from gym import spaces
GYM_AVAILABLE = True
GYM_MODULE = 'gym'
except ImportError:
gym = None
spaces = None
GYM_AVAILABLE = False
GYM_MODULE = None
# ──────────────────────────────────────────────────────────────────────────────
# Actions
# ──────────────────────────────────────────────────────────────────────────────
ACTION_HOLD = 0
ACTION_LONG = 1
ACTION_SHORT = 2
# Nombre de features dans le vecteur d'observation
N_FEATURES = 20
# Fenêtre lookback (nombre de barres passées incluses dans l'observation)
LOOKBACK = 20
class TradingEnv:
"""
Environnement de trading conforme à l'API gymnasium pour l'agent PPO.
Chaque step() représente la décision de trading à la fermeture d'une bougie.
L'agent reçoit 20 features normalisées décrivant le contexte de marché et
sa position courante, puis choisit HOLD / LONG / SHORT.
Gestion des positions :
- Une seule position à la fois (pas de pyramiding)
- Inversion directe possible (SHORT → LONG sans passer par FLAT)
- SL/TP ATR-based (sl_atr_mult×ATR et tp_atr_mult×ATR)
Récompense :
- PnL réalisé à la clôture de position (en multiple d'ATR)
- Pénalité proportionnelle au drawdown courant
- Pas de bonus pour maintenir une position (évite le surtrading)
Args:
df: DataFrame OHLCV (colonnes minuscules : open/high/low/close/volume)
sl_atr_mult: Multiplicateur ATR pour le stop-loss (défaut: 1.0)
tp_atr_mult: Multiplicateur ATR pour le take-profit (défaut: 2.0)
atr_period: Période de calcul de l'ATR (défaut: 14)
initial_capital: Capital initial (pour calcul drawdown) (défaut: 10000)
drawdown_penalty: Coefficient de pénalité drawdown (défaut: 0.1)
"""
metadata = {'render_modes': []}
def __init__(
self,
df: pd.DataFrame,
sl_atr_mult: float = 1.0,
tp_atr_mult: float = 2.0,
atr_period: int = 14,
initial_capital: float = 10_000.0,
drawdown_penalty: float = 0.1,
):
if not GYM_AVAILABLE:
raise RuntimeError(
"gymnasium ou gym requis — installer gymnasium>=0.26 ou gym>=0.21"
)
self.df = df.copy().reset_index(drop=True)
self.df.columns = [c.lower() for c in self.df.columns]
self.sl_atr_mult = sl_atr_mult
self.tp_atr_mult = tp_atr_mult
self.atr_period = atr_period
self.initial_capital = initial_capital
self.drawdown_penalty = drawdown_penalty
# Calcul de l'ATR sur toute la série (optimisation : évite le recalcul dans step)
self._atr_series = self._compute_atr_series()
# Calcul des EMAs sur toute la série
self._ema9 = self.df['close'].ewm(span=9, adjust=False).mean()
self._ema21 = self.df['close'].ewm(span=21, adjust=False).mean()
self._ema50 = self.df['close'].ewm(span=50, adjust=False).mean()
self._ema200 = self.df['close'].ewm(span=200, adjust=False).mean()
# Calcul des Bandes de Bollinger (20, 2σ)
bb_ma = self.df['close'].rolling(20).mean()
bb_std = self.df['close'].rolling(20).std()
self._bb_upper = bb_ma + 2 * bb_std
self._bb_lower = bb_ma - 2 * bb_std
# Calcul du RSI(14)
self._rsi = self._compute_rsi(period=14)
# Calcul du MACD (12, 26, 9)
ema12 = self.df['close'].ewm(span=12, adjust=False).mean()
ema26 = self.df['close'].ewm(span=26, adjust=False).mean()
self._macd = ema12 - ema26
self._macd_sig = self._macd.ewm(span=9, adjust=False).mean()
# Volume MA(20) pour normalisation
self._vol_ma20 = self.df['volume'].rolling(20).mean().fillna(
self.df['volume'].mean()
)
# Espaces gymnasium
self.observation_space = spaces.Box(
low = -10.0,
high = 10.0,
shape = (N_FEATURES,),
dtype = np.float32,
)
self.action_space = spaces.Discrete(3)
# État interne (initialisé dans reset())
self._current_step = LOOKBACK
self._position = 0 # -1=short, 0=flat, 1=long
self._entry_price = 0.0
self._entry_atr = 0.0
self._bars_in_trade = 0
self._capital = initial_capital
self._peak_capital = initial_capital
self._total_pnl = 0.0
self._pnl_history = [] # Pour calcul Sharpe en fin d'épisode
self._done = False
# ──────────────────────────────────────────────────────────────────────────
# Interface gymnasium
# ──────────────────────────────────────────────────────────────────────────
def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
"""
Réinitialise l'environnement au début d'un épisode.
Args:
seed: Graine aléatoire (ignorée ici, données déterministes)
options: Options additionnelles (non utilisées)
Returns:
Tuple (observation, info) conforme gymnasium
"""
if seed is not None:
np.random.seed(seed)
self._current_step = LOOKBACK
self._position = 0
self._entry_price = 0.0
self._entry_atr = 0.0
self._bars_in_trade = 0
self._capital = self.initial_capital
self._peak_capital = self.initial_capital
self._total_pnl = 0.0
self._pnl_history = []
self._done = False
obs = self._get_observation()
info = {}
return obs, info
def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, dict]:
"""
Exécute une action et retourne la transition.
Args:
action: 0=HOLD, 1=LONG, 2=SHORT
Returns:
Tuple (observation, reward, terminated, truncated, info) conforme gymnasium
"""
if self._done:
obs, info = self.reset()
return obs, 0.0, True, False, info
reward = 0.0
info = {}
current_close = float(self.df['close'].iloc[self._current_step])
current_atr = float(self._atr_series.iloc[self._current_step])
if current_atr <= 0:
current_atr = float(self.df['close'].iloc[self._current_step]) * 0.001
# ── Vérification SL/TP avant d'appliquer la nouvelle action ──────────
if self._position != 0:
reward += self._check_sl_tp(current_close, current_atr)
# ── Application de la nouvelle action ─────────────────────────────────
desired_position = self._action_to_position(action)
if desired_position != self._position:
# Fermeture de la position courante (si ouverte)
if self._position != 0:
close_reward = self._close_position(current_close, current_atr)
reward += close_reward
# Ouverture d'une nouvelle position (si pas HOLD)
if desired_position != 0:
self._open_position(desired_position, current_close, current_atr)
else:
# Même position : incrémenter le compteur de barres
if self._position != 0:
self._bars_in_trade += 1
# ── Pénalité drawdown ─────────────────────────────────────────────────
if self._capital > self._peak_capital:
self._peak_capital = self._capital
drawdown = (self._peak_capital - self._capital) / max(self._peak_capital, 1.0)
if drawdown > 0:
reward -= self.drawdown_penalty * drawdown
# ── Avance d'une barre ────────────────────────────────────────────────
self._current_step += 1
terminated = self._current_step >= len(self.df) - 1
truncated = False
if terminated:
# Fermeture forcée de la position à la fin de l'épisode
if self._position != 0:
final_close = float(self.df['close'].iloc[-1])
final_atr = float(self._atr_series.iloc[-1])
reward += self._close_position(final_close, final_atr)
self._done = True
obs = self._get_observation()
# Sauvegarde de la récompense pour calcul Sharpe
self._pnl_history.append(reward)
info = {
'position': self._position,
'capital': self._capital,
'drawdown': drawdown,
'bars_in_trade': self._bars_in_trade,
'step': self._current_step,
}
return obs, float(reward), terminated, truncated, info
def render(self):
"""Affichage minimal de l'état courant."""
pos_str = {0: 'FLAT', 1: 'LONG', -1: 'SHORT'}.get(self._position, '?')
logger.debug(
f"Step={self._current_step} | Pos={pos_str} | "
f"Capital={self._capital:.2f} | PnL={self._total_pnl:.4f}"
)
# ──────────────────────────────────────────────────────────────────────────
# Observation
# ──────────────────────────────────────────────────────────────────────────
def _get_observation(self) -> np.ndarray:
"""
Construit le vecteur d'observation de 20 features normalisées.
Toutes les features de prix sont normalisées par l'ATR courant pour
rendre l'observation invariante à l'échelle du prix.
Returns:
np.ndarray de forme (20,) avec dtype float32
"""
i = self._current_step
# Sécurité : ne pas dépasser les bornes
i = max(LOOKBACK, min(i, len(self.df) - 1))
close = float(self.df['close'].iloc[i])
open_ = float(self.df['open'].iloc[i])
high = float(self.df['high'].iloc[i])
low = float(self.df['low'].iloc[i])
vol = float(self.df['volume'].iloc[i])
atr = float(self._atr_series.iloc[i])
if atr <= 0:
atr = close * 0.001
# Close i-1 et i-5 pour le momentum
close_1 = float(self.df['close'].iloc[max(0, i - 1)])
close_5 = float(self.df['close'].iloc[max(0, i - 5)])
rsi = float(self._rsi.iloc[i]) if not np.isnan(self._rsi.iloc[i]) else 50.0
bb_upper = float(self._bb_upper.iloc[i]) if not np.isnan(self._bb_upper.iloc[i]) else close
bb_lower = float(self._bb_lower.iloc[i]) if not np.isnan(self._bb_lower.iloc[i]) else close
macd = float(self._macd.iloc[i]) if not np.isnan(self._macd.iloc[i]) else 0.0
macd_signal = float(self._macd_sig.iloc[i]) if not np.isnan(self._macd_sig.iloc[i]) else 0.0
vol_ma20 = float(self._vol_ma20.iloc[i])
ema9 = float(self._ema9.iloc[i])
ema21 = float(self._ema21.iloc[i])
ema50 = float(self._ema50.iloc[i])
ema200 = float(self._ema200.iloc[i])
vol_ratio = vol / max(vol_ma20, 1.0)
# PnL non-réalisé courant
if self._position != 0 and self._entry_price > 0:
unrealized = self._position * (close - self._entry_price)
else:
unrealized = 0.0
features = np.array([
# Prix relatifs normalisés par ATR
(close - open_) / atr, # 0 : corps de la bougie
(high - low) / atr, # 1 : amplitude bougie (volatilité relative)
(close - close_1) / atr, # 2 : momentum 1 barre
(close - close_5) / atr, # 3 : momentum 5 barres
# Indicateurs techniques normalisés
rsi / 100.0, # 4 : RSI [0..1]
(close - bb_upper) / atr, # 5 : position vs BB haute
(bb_upper - bb_lower) / atr, # 6 : largeur des bandes de Bollinger
macd / atr, # 7 : MACD normalisé
macd_signal / atr, # 8 : signal MACD normalisé
# Volume
np.clip(vol_ratio, 0.0, 5.0), # 9 : ratio volume / MA20 (plafonné à 5)
# Position et état du trade
float(self._position), # 10: position courante (-1, 0, 1)
unrealized / atr, # 11: PnL non-réalisé en ATR
min(self._bars_in_trade / 50.0, 1.0), # 12: durée du trade (normalisée)
# Temporel (si index datetime)
self._get_hour(i), # 13: heure normalisée [0..1]
self._get_dow(i), # 14: jour de semaine [0..1]
# Moyennes mobiles (distance close - EMA, normalisée par ATR)
(close - ema9) / atr, # 15: distance EMA9
(close - ema21) / atr, # 16: distance EMA21
(ema9 - ema21) / atr, # 17: croisement EMA9/21
(close - ema50) / atr, # 18: distance EMA50
(close - ema200) / atr, # 19: distance EMA200
], dtype=np.float32)
# Clip pour éviter les valeurs extrêmes (protection robustesse)
features = np.clip(features, -10.0, 10.0)
# Remplacement des NaN résiduels
features = np.nan_to_num(features, nan=0.0, posinf=10.0, neginf=-10.0)
return features
# ──────────────────────────────────────────────────────────────────────────
# Gestion des positions
# ──────────────────────────────────────────────────────────────────────────
def _open_position(self, direction: int, price: float, atr: float) -> None:
"""
Ouvre une position dans la direction donnée.
Args:
direction: 1=LONG, -1=SHORT
price: Prix d'entrée (close de la barre courante)
atr: ATR courant pour le calcul SL/TP
"""
self._position = direction
self._entry_price = price
self._entry_atr = atr
self._bars_in_trade = 0
def _close_position(self, price: float, atr: float) -> float:
"""
Ferme la position courante et calcule la récompense.
La récompense est le PnL en multiple d'ATR (normalisé et sans unité).
Cette normalisation rend la récompense comparable entre différents
symboles et périodes de volatilité.
Args:
price: Prix de sortie
atr: ATR courant (pour normalisation)
Returns:
Récompense (PnL normalisé par ATR)
"""
if self._position == 0 or self._entry_price <= 0:
self._position = 0
self._bars_in_trade = 0
return 0.0
raw_pnl = self._position * (price - self._entry_price)
# Normalisation par l'ATR d'entrée pour une récompense sans unité
entry_atr = max(self._entry_atr, price * 0.0001)
reward = raw_pnl / entry_atr
# Mise à jour du capital (simulation simplifiée avec 1 lot fixe)
self._capital += raw_pnl
self._total_pnl += raw_pnl
self._position = 0
self._entry_price = 0.0
self._bars_in_trade = 0
return float(reward)
def _check_sl_tp(self, current_close: float, current_atr: float) -> float:
"""
Vérifie si le SL ou TP est atteint pour la position courante.
Utilise l'ATR d'entrée pour les niveaux SL/TP (stabilité).
Si le SL ou TP est touché, la position est fermée.
Args:
current_close: Prix de clôture de la barre courante
current_atr: ATR courant
Returns:
Récompense si fermeture forcée, 0.0 sinon
"""
if self._position == 0 or self._entry_price <= 0:
return 0.0
entry_atr = max(self._entry_atr, self._entry_price * 0.0001)
sl_dist = self.sl_atr_mult * entry_atr
tp_dist = self.tp_atr_mult * entry_atr
if self._position == 1: # LONG
sl_level = self._entry_price - sl_dist
tp_level = self._entry_price + tp_dist
if current_close <= sl_level or current_close >= tp_level:
return self._close_position(current_close, current_atr)
elif self._position == -1: # SHORT
sl_level = self._entry_price + sl_dist
tp_level = self._entry_price - tp_dist
if current_close >= sl_level or current_close <= tp_level:
return self._close_position(current_close, current_atr)
return 0.0
# ──────────────────────────────────────────────────────────────────────────
# Helpers
# ──────────────────────────────────────────────────────────────────────────
@staticmethod
def _action_to_position(action: int) -> int:
"""Convertit l'action discrète en direction de position."""
return {ACTION_HOLD: 0, ACTION_LONG: 1, ACTION_SHORT: -1}.get(action, 0)
def _get_hour(self, i: int) -> float:
"""Retourne l'heure normalisée [0..1] depuis l'index, ou 0.5 si non datetime."""
try:
idx = self.df.index[i]
if hasattr(idx, 'hour'):
return float(idx.hour) / 24.0
except Exception:
pass
return 0.5
def _get_dow(self, i: int) -> float:
"""Retourne le jour de semaine normalisé [0..1] depuis l'index, ou 0.5 si non datetime."""
try:
idx = self.df.index[i]
if hasattr(idx, 'dayofweek'):
return float(idx.dayofweek) / 5.0
except Exception:
pass
return 0.5
def _compute_atr_series(self) -> pd.Series:
"""Calcule l'ATR(14) sur toute la série OHLCV."""
h = self.df['high']
l = self.df['low']
pc = self.df['close'].shift(1)
tr = pd.concat([h - l, (h - pc).abs(), (l - pc).abs()], axis=1).max(axis=1)
atr = tr.ewm(span=self.atr_period, adjust=False).mean()
# Remplir les premières valeurs NaN par la plage high-low
atr = atr.fillna(h - l)
return atr
def _compute_rsi(self, period: int = 14) -> pd.Series:
"""Calcule le RSI sur toute la série de closes."""
delta = self.df['close'].diff()
gain = delta.clip(lower=0).ewm(span=period, adjust=False).mean()
loss = (-delta.clip(upper=0)).ewm(span=period, adjust=False).mean()
rs = gain / loss.replace(0, np.nan)
rsi = 100.0 - (100.0 / (1.0 + rs))
return rsi.fillna(50.0)
def get_episode_stats(self) -> dict:
"""
Retourne les statistiques de l'épisode courant.
Utile pour le logging et l'évaluation du modèle PPO.
Returns:
Dict avec total_pnl, n_steps, capital, Sharpe approximatif
"""
rewards = np.array(self._pnl_history)
if len(rewards) > 1 and rewards.std() > 0:
sharpe = float(rewards.mean() / rewards.std() * np.sqrt(252))
else:
sharpe = 0.0
return {
'total_pnl': self._total_pnl,
'capital': self._capital,
'return_pct': (self._capital - self.initial_capital) / self.initial_capital,
'n_steps': self._current_step,
'sharpe_approx': sharpe,
}

View File

@@ -0,0 +1,10 @@
"""
Module RL-Driven Strategy — Stratégie de trading pilotée par un agent PPO.
L'agent PPO apprend à trader par renforcement : il maximise directement
le PnL sans supervision explicite (pas de labels LONG/SHORT/NEUTRAL).
"""
from src.strategies.rl_driven.rl_strategy import RLDrivenStrategy, RL_AVAILABLE
__all__ = ['RLDrivenStrategy', 'RL_AVAILABLE']

View File

@@ -0,0 +1,259 @@
"""
RL-Driven Strategy — Stratégie pilotée par un agent de Reinforcement Learning (PPO).
Contrairement aux stratégies supervisées (XGBoost, CNN) qui apprennent depuis des
labels pré-calculés, cette stratégie utilise un agent PPO (Proximal Policy
Optimization) entraîné via interaction directe avec un environnement de trading
(TradingEnv). L'agent apprend à maximiser la récompense cumulative (profit ajusté
par le risque) sans avoir besoin de labels explicites.
Fonctionnement :
1. Le modèle RLStrategyModel est chargé depuis le disque
(entraîné via POST /trading/train-rl)
2. À chaque barre, les seq_len dernières bougies sont fournies à l'agent
3. L'agent PPO retourne une action : LONG (1) / SHORT (-1) / NEUTRAL (0)
avec un score de confiance basé sur les probabilités de la politique
4. Si confidence >= min_confidence, un signal est émis avec SL/TP basés sur ATR
Intégration :
- Compatible avec StrategyEngine (même interface que CNNImageDrivenStrategy)
- Chargé automatiquement si un modèle entraîné existe pour le symbole/timeframe
- Requiert PyTorch (CPU ou CUDA)
"""
import logging
from datetime import datetime, timezone
from typing import Dict, Optional
import numpy as np
import pandas as pd
from src.strategies.base_strategy import BaseStrategy, Signal, StrategyConfig
# Import conditionnel — RL_AVAILABLE reste False si PyTorch ou stable-baselines3
# ne sont pas installés dans le container
try:
from src.ml.rl.rl_strategy_model import RLStrategyModel
RL_AVAILABLE = True
except ImportError:
RLStrategyModel = None
RL_AVAILABLE = False
logger = logging.getLogger(__name__)
class RLDrivenStrategy(BaseStrategy):
"""
Stratégie de trading pilotée par un agent PPO (Reinforcement Learning).
L'agent apprend à maximiser le profit cumulatif en interagissant avec un
environnement de trading simulé (TradingEnv). Aucun label supervisé n'est
nécessaire : la récompense est définie par le P&L réalisé et les pénalités
de risque (drawdown, over-trading).
Args:
config: Dict de configuration (timeframe, risk_per_trade, symbol, etc.)
Config keys supplémentaires (optionnelles) :
min_confidence: Seuil de confiance minimum [0..1] (défaut: 0.55)
tp_atr_mult: Multiplicateur ATR pour TP (défaut: 2.0)
sl_atr_mult: Multiplicateur ATR pour SL (défaut: 1.0)
seq_len: Fenêtre d'observation RL (barres) (défaut: 20)
auto_load: Charger automatiquement le modèle existant (défaut: True)
"""
STRATEGY_NAME = 'rl_driven'
def __init__(self, config: Dict):
super().__init__(config)
self.symbol = config.get('symbol', 'EURUSD')
self.min_confidence = config.get('min_confidence', 0.55)
self.tp_atr_mult = config.get('tp_atr_mult', 2.0)
self.sl_atr_mult = config.get('sl_atr_mult', 1.0)
self.seq_len = config.get('seq_len', 20)
self.rl_model: Optional['RLStrategyModel'] = None
if not RL_AVAILABLE:
logger.warning("RL Strategy non disponible (PyTorch / stable-baselines3 requis)")
return
# Tentative de chargement automatique du modèle existant
if config.get('auto_load', True):
self._try_load_model()
# -------------------------------------------------------------------------
# Interface BaseStrategy
# -------------------------------------------------------------------------
def analyze(self, market_data: pd.DataFrame) -> Optional[Signal]:
"""
Génère un signal de trading via l'agent PPO.
Args:
market_data: DataFrame OHLCV (minimum seq_len barres)
Returns:
Signal si l'agent est confiant, None sinon
"""
if not RL_AVAILABLE:
logger.debug("RL Strategy : PyTorch / stable-baselines3 non disponible, aucun signal")
return None
if self.rl_model is None or not self.rl_model.is_trained:
logger.debug("RL Strategy : modèle non chargé, aucun signal")
return None
if len(market_data) < self.seq_len:
return None
try:
result = self.rl_model.predict(market_data)
except Exception as e:
logger.warning(f"RL Strategy predict error : {e}")
return None
if not result.get('tradeable', False):
return None
signal_dir = result['signal'] # 1 = LONG, -1 = SHORT
confidence = result['confidence']
# Prix et ATR pour le calcul des niveaux SL/TP
last_close = float(market_data['close'].iloc[-1])
atr = self._compute_atr(market_data)
if atr <= 0:
return None
if signal_dir == 1:
direction = 'LONG'
stop_loss = last_close - self.sl_atr_mult * atr
take_profit = last_close + self.tp_atr_mult * atr
elif signal_dir == -1:
direction = 'SHORT'
stop_loss = last_close + self.sl_atr_mult * atr
take_profit = last_close - self.tp_atr_mult * atr
else:
return None
signal = Signal(
symbol = self.symbol,
direction = direction,
entry_price = last_close,
stop_loss = stop_loss,
take_profit = take_profit,
confidence = confidence,
timestamp = datetime.now(timezone.utc),
strategy = self.STRATEGY_NAME,
metadata = {
'probas': result.get('probas', {}),
'seq_len': self.seq_len,
'atr': atr,
'tp_atr_mult': self.tp_atr_mult,
'sl_atr_mult': self.sl_atr_mult,
'avg_reward': result.get('avg_reward'),
'total_timesteps': result.get('total_timesteps'),
},
)
logger.info(
f"RL Signal : {direction} {self.symbol} | "
f"entry={last_close:.5f} SL={stop_loss:.5f} TP={take_profit:.5f} | "
f"confidence={confidence:.2%}"
)
return signal
def calculate_indicators(self, data: pd.DataFrame) -> pd.DataFrame:
"""Retourne les données telles quelles — l'agent RL travaille sur les OHLCV bruts."""
return data
def update_params(self, params: Dict) -> None:
"""Mise à jour dynamique des paramètres (depuis API ou Optuna)."""
if 'min_confidence' in params:
self.min_confidence = params['min_confidence']
if self.rl_model:
self.rl_model.min_confidence = params['min_confidence']
if 'tp_atr_mult' in params:
self.tp_atr_mult = params['tp_atr_mult']
if 'sl_atr_mult' in params:
self.sl_atr_mult = params['sl_atr_mult']
if 'seq_len' in params:
self.seq_len = params['seq_len']
logger.info(f"RL Strategy params mis à jour : {params}")
# -------------------------------------------------------------------------
# Gestion du modèle
# -------------------------------------------------------------------------
def load_model(self, symbol: Optional[str] = None, timeframe: Optional[str] = None) -> bool:
"""
Charge un modèle RL depuis le disque.
Args:
symbol: Paire (défaut: self.symbol)
timeframe: Timeframe (défaut: self.config.timeframe)
Returns:
True si chargement réussi
"""
if not RL_AVAILABLE:
logger.warning("RL Strategy non disponible (PyTorch / stable-baselines3 requis)")
return False
sym = symbol or self.symbol
tf = timeframe or self.config.timeframe
try:
self.rl_model = RLStrategyModel.load(sym, tf)
logger.info(f"Modèle RL chargé : {sym}/{tf}")
return True
except FileNotFoundError:
logger.info(f"Aucun modèle RL trouvé pour {sym}/{tf}")
return False
except Exception as e:
logger.error(f"Erreur chargement modèle RL : {e}")
return False
def attach_model(self, model: 'RLStrategyModel') -> None:
"""Attache directement un modèle RL (après entraînement via API)."""
self.rl_model = model
self.symbol = model.symbol
logger.info(f"Modèle RL attaché : {model.symbol}/{model.timeframe}")
def is_ready(self) -> bool:
"""Retourne True si le modèle RL est chargé et entraîné."""
if not RL_AVAILABLE:
return False
return self.rl_model is not None and self.rl_model.is_trained
def get_model_info(self) -> Dict:
"""Retourne les métadonnées du modèle RL actif."""
if not RL_AVAILABLE:
return {'status': 'PyTorch / stable-baselines3 non disponible'}
if not self.is_ready():
return {'status': 'non entraîné'}
meta = self.rl_model.metadata.copy()
meta['is_ready'] = True
meta['seq_len'] = self.seq_len
return meta
# -------------------------------------------------------------------------
# Helpers
# -------------------------------------------------------------------------
def _try_load_model(self) -> None:
"""Tente un chargement silencieux du modèle au démarrage."""
try:
self.load_model()
except Exception:
pass
@staticmethod
def _compute_atr(df: pd.DataFrame, period: int = 14) -> float:
"""Calcule l'ATR moyen sur les dernières barres."""
if len(df) < period + 1:
return float(df['high'].iloc[-1] - df['low'].iloc[-1])
h, l, pc = df['high'], df['low'], df['close'].shift(1)
tr = pd.concat([h - l, (h - pc).abs(), (l - pc).abs()], axis=1).max(axis=1)
atr = tr.rolling(period).mean().iloc[-1]
return float(atr) if not np.isnan(atr) else float(df['high'].iloc[-1] - df['low'].iloc[-1])