Phase 4c-bis/4d : CNN Image vectorisé + Agent RL PPO + HMM persistence + scripts
CNN Image (Phase 4c-bis) :
- chart_renderer.py : renderer numpy vectorisé (boucle 64 bougies, pas 12 000 fenêtres)
→ 1 068 img/s, GIL libéré entre itérations, API réactive pendant l'entraînement
- cnn_image_strategy_model.py : torch.set_num_threads(4) pour préserver l'event loop
- trading.py : asyncio.create_task() au lieu de background_tasks → hot-reloads non-bloquants
Agent RL PPO (Phase 4d) :
- src/ml/rl/ : TradingEnv (gymnasium), PPOModel (Actor-Critic MLP, GAE), RLStrategyModel
- src/strategies/rl_driven/ : RLDrivenStrategy (interface BaseStrategy complète)
- Routes API : POST /train-rl, GET /train-rl/{job_id}, GET /rl-models
- docs/RL_STRATEGY_GUIDE.md : documentation complète
HMM Persistence :
- regime_detector.py : save()/load()/needs_retrain()/is_trained (joblib + JSON meta)
- trading.py /ml/status : charge depuis disque si < 24h, re-entraîne + sauvegarde sinon
→ premier appel ~2s, appels suivants < 100ms
Scripts utilitaires :
- scripts/compare_strategies.py : backtest comparatif toutes stratégies (tabulate/JSON)
- scripts/quick_benchmark.py : comparaison wf_accuracy/precision des modèles ML sauvegardés
- reports/ : répertoire pour les rapports JSON générés
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
236
docs/RL_STRATEGY_GUIDE.md
Normal file
236
docs/RL_STRATEGY_GUIDE.md
Normal file
@@ -0,0 +1,236 @@
|
||||
# RL Strategy Guide — Phase 4d (Reinforcement Learning)
|
||||
|
||||
## Vue d'ensemble
|
||||
|
||||
La Phase 4d introduit une stratégie pilotée par un agent de **Reinforcement Learning (RL)**
|
||||
basé sur l'algorithme **PPO** (Proximal Policy Optimization). Contrairement aux stratégies
|
||||
supervisées (XGBoost, CNN) qui s'entraînent sur des labels pré-calculés, l'agent RL apprend
|
||||
directement par interaction avec un environnement de trading simulé (`TradingEnv`), sans avoir
|
||||
besoin d'annotations explicites LONG/SHORT/NEUTRAL.
|
||||
|
||||
---
|
||||
|
||||
## Architecture
|
||||
|
||||
### Composants (implémentés par l'agent ml/rl/)
|
||||
|
||||
```
|
||||
src/ml/rl/
|
||||
├── trading_env.py # TradingEnv (Gymnasium) — environnement de simulation
|
||||
├── ppo_model.py # PPOModel — réseau Actor-Critic (MLP)
|
||||
└── rl_strategy_model.py # RLStrategyModel — interface train/predict/save/load
|
||||
```
|
||||
|
||||
### Composants (cette phase)
|
||||
|
||||
```
|
||||
src/strategies/rl_driven/
|
||||
├── __init__.py # Export RLDrivenStrategy
|
||||
└── rl_strategy.py # RLDrivenStrategy (hérite BaseStrategy)
|
||||
```
|
||||
|
||||
### Pipeline d'entraînement PPO
|
||||
|
||||
```
|
||||
Données OHLCV
|
||||
│
|
||||
▼
|
||||
TradingEnv (Gymnasium)
|
||||
├── Observation space : fenêtre glissante seq_len=20 barres (OHLCV normalisé)
|
||||
├── Action space : {0=HOLD/NEUTRAL, 1=LONG, 2=SHORT}
|
||||
└── Reward : P&L réalisé − pénalité drawdown − pénalité over-trading
|
||||
│
|
||||
▼
|
||||
PPO Actor-Critic (MLP)
|
||||
├── Actor : policy π(a|s) → distribution sur les actions
|
||||
└── Critic : value function V(s) → estimation de la récompense future
|
||||
│
|
||||
▼
|
||||
Optimisation sur total_timesteps pas
|
||||
│
|
||||
▼
|
||||
RLStrategyModel.save()
|
||||
├── models/rl_strategy/EURUSD_1h.zip (politique PPO)
|
||||
└── models/rl_strategy/EURUSD_1h_meta.json
|
||||
```
|
||||
|
||||
### Architecture Actor-Critic
|
||||
|
||||
L'agent PPO utilise un réseau MLP (Multi-Layer Perceptron) à deux têtes :
|
||||
|
||||
- **Actor** : prédit la distribution de probabilité sur les 3 actions (LONG/SHORT/NEUTRAL).
|
||||
La confiance du signal correspond à `max(probas)`.
|
||||
- **Critic** : estime la valeur d'état V(s) pour calculer l'avantage (advantage) utilisé
|
||||
lors de la mise à jour de la politique.
|
||||
|
||||
Hyperparamètres PPO typiques :
|
||||
| Paramètre | Valeur par défaut | Description |
|
||||
|---------------|-------------------|----------------------------------------------|
|
||||
| `gamma` | 0.99 | Facteur d'actualisation des récompenses |
|
||||
| `clip_range` | 0.2 | Clipping ratio PPO (stabilité) |
|
||||
| `n_steps` | 2048 | Nombre de pas par rollout |
|
||||
| `batch_size` | 64 | Taille des mini-batches SGD |
|
||||
| `n_epochs` | 10 | Passes sur chaque rollout |
|
||||
| `ent_coef` | 0.01 | Coefficient d'entropie (exploration) |
|
||||
|
||||
---
|
||||
|
||||
## Lancer l'entraînement
|
||||
|
||||
### Via curl
|
||||
|
||||
```bash
|
||||
# Lancer l'entraînement (tâche de fond)
|
||||
curl -X POST http://localhost:8100/trading/train-rl \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"symbol": "EURUSD",
|
||||
"timeframe": "1h",
|
||||
"period": "2y",
|
||||
"total_timesteps": 50000
|
||||
}'
|
||||
|
||||
# Réponse
|
||||
{
|
||||
"job_id": "a1b2c3d4-...",
|
||||
"status": "pending",
|
||||
"symbol": "EURUSD",
|
||||
"timeframe": "1h"
|
||||
}
|
||||
```
|
||||
|
||||
```bash
|
||||
# Suivre l'avancement
|
||||
curl http://localhost:8100/trading/train-rl/a1b2c3d4-...
|
||||
|
||||
# Réponse quand terminé
|
||||
{
|
||||
"job_id": "a1b2c3d4-...",
|
||||
"status": "completed",
|
||||
"symbol": "EURUSD",
|
||||
"timeframe": "1h",
|
||||
"avg_reward": 0.042,
|
||||
"sharpe_env": 1.23,
|
||||
"total_timesteps": 50000,
|
||||
"trained_at": "2026-03-10T14:32:00"
|
||||
}
|
||||
```
|
||||
|
||||
```bash
|
||||
# Lister les modèles disponibles
|
||||
curl http://localhost:8100/trading/rl-models
|
||||
|
||||
# Réponse
|
||||
{
|
||||
"models": [
|
||||
{
|
||||
"symbol": "EURUSD",
|
||||
"timeframe": "1h",
|
||||
"avg_reward": 0.042,
|
||||
"sharpe_env": 1.23,
|
||||
"total_timesteps": 50000,
|
||||
"trained_at": "2026-03-10T14:32:00"
|
||||
}
|
||||
],
|
||||
"count": 1
|
||||
}
|
||||
```
|
||||
|
||||
### Paramètres de la requête
|
||||
|
||||
| Champ | Type | Défaut | Description |
|
||||
|-------------------|-------|----------|---------------------------------------------------------|
|
||||
| `symbol` | str | EURUSD | Paire de trading (ex: EURUSD, BTCUSDT) |
|
||||
| `timeframe` | str | 1h | Timeframe (1m, 5m, 15m, 1h, 4h, 1d) |
|
||||
| `period` | str | 2y | Historique d'entraînement (ex: 6m, 1y, 2y) |
|
||||
| `total_timesteps` | int | 50000 | Nombre total de pas de simulation PPO |
|
||||
|
||||
**Recommandations `total_timesteps` :**
|
||||
- 20 000 : entraînement rapide (test, ~5 min CPU)
|
||||
- 50 000 : entraînement standard (défaut, ~15 min CPU)
|
||||
- 200 000 : entraînement long, meilleure convergence (~1h CPU)
|
||||
- 1 000 000 : entraînement poussé si GPU disponible
|
||||
|
||||
---
|
||||
|
||||
## Interpréter les métriques
|
||||
|
||||
### `avg_reward`
|
||||
Récompense moyenne par pas de simulation sur les derniers rollouts d'évaluation.
|
||||
|
||||
- **< 0** : l'agent perd de l'argent en simulation → entraîner plus longtemps ou revoir la fonction de récompense
|
||||
- **0 à 0.02** : agent neutre, légèrement profitable
|
||||
- **> 0.05** : bon signal → tester en backtest réel
|
||||
- **> 0.1** : excellent (attention au sur-apprentissage, vérifier out-of-sample)
|
||||
|
||||
### `sharpe_env`
|
||||
Ratio de Sharpe calculé sur les épisodes de simulation (récompenses / écart-type des récompenses).
|
||||
|
||||
- **< 0.5** : insuffisant pour paper trading
|
||||
- **0.5 – 1.0** : acceptable, à valider en backtest
|
||||
- **> 1.5** : cible pour activation paper trading (conforme aux seuils du projet)
|
||||
|
||||
### Interprétation combinée
|
||||
|
||||
| `avg_reward` | `sharpe_env` | Interprétation |
|
||||
|--------------|--------------|--------------------------------------------------------|
|
||||
| négatif | quelconque | Agent non convergé — relancer avec plus de timesteps |
|
||||
| 0 – 0.02 | < 1.0 | Apprentissage partiel — augmenter total_timesteps |
|
||||
| > 0.03 | > 1.0 | Bon candidat — valider via POST /trading/backtest |
|
||||
| > 0.05 | > 1.5 | Prêt pour paper trading (30 jours minimum) |
|
||||
|
||||
---
|
||||
|
||||
## Intégration avec le paper trading
|
||||
|
||||
Après l'entraînement, si un paper trading avec la stratégie `rl_driven` est actif,
|
||||
le modèle est **automatiquement attaché** sans redémarrage.
|
||||
|
||||
Pour démarrer un paper trading RL :
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8100/trading/paper/start \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"strategy": "rl_driven", "symbol": "EURUSD"}'
|
||||
```
|
||||
|
||||
La stratégie `RLDrivenStrategy` :
|
||||
1. Charge le dernier modèle entraîné pour le symbole/timeframe
|
||||
2. À chaque barre, fournit les 20 dernières bougies à l'agent (`seq_len=20`)
|
||||
3. Si la confiance de l'agent >= `min_confidence` (défaut: 0.55), émet un signal
|
||||
4. SL = `sl_atr_mult × ATR` (défaut: 1×ATR), TP = `tp_atr_mult × ATR` (défaut: 2×ATR)
|
||||
|
||||
---
|
||||
|
||||
## Notes techniques
|
||||
|
||||
### Threads PyTorch
|
||||
L'entraînement fixe `torch.set_num_threads(4)` pour éviter la contention CPU dans
|
||||
le container Docker. Adapter dans `docker-compose.yml` si le container dispose de plus de cœurs.
|
||||
|
||||
### Sauvegarde des modèles
|
||||
Les modèles sont sauvegardés dans `models/rl_strategy/` (volume Docker monté) :
|
||||
- `EURUSD_1h.zip` : politique PPO (format stable-baselines3)
|
||||
- `EURUSD_1h_meta.json` : métadonnées (métriques, hyperparamètres, date)
|
||||
|
||||
### Import conditionnel
|
||||
Le flag `RL_AVAILABLE` est `False` si PyTorch ou stable-baselines3 ne sont pas
|
||||
installés. La stratégie dégrade gracieusement (aucun signal, aucune exception).
|
||||
Pour activer : ajouter `stable-baselines3` et `gymnasium` dans `docker/requirements/api.txt`
|
||||
puis reconstruire le container (`docker compose build --no-cache trading-api`).
|
||||
|
||||
---
|
||||
|
||||
## Seuils de validation (conforme au projet)
|
||||
|
||||
Avant activation du live trading, la stratégie RL doit satisfaire :
|
||||
|
||||
| Métrique | Seuil minimum |
|
||||
|-----------------|---------------|
|
||||
| Sharpe Ratio | ≥ 1.5 |
|
||||
| Max Drawdown | ≤ 10% |
|
||||
| Win Rate | ≥ 55% |
|
||||
| Paper Trading | ≥ 30 jours |
|
||||
|
||||
Ces seuils s'appliquent au **backtest out-of-sample** et au **paper trading**, pas
|
||||
aux métriques de simulation RL (`sharpe_env`) qui sont indicatives uniquement.
|
||||
365
scripts/compare_strategies.py
Executable file
365
scripts/compare_strategies.py
Executable file
@@ -0,0 +1,365 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
compare_strategies.py — Backtest comparatif automatisé des stratégies de trading.
|
||||
|
||||
Lance POST /trading/backtest pour chaque stratégie (scalping, ml_driven, cnn_driven)
|
||||
sur le même dataset (EURUSD / 1h / 1y), poll jusqu'à completion, affiche un tableau
|
||||
comparatif et sauvegarde le rapport JSON dans reports/.
|
||||
|
||||
Usage :
|
||||
python scripts/compare_strategies.py
|
||||
python scripts/compare_strategies.py --symbol GBPUSD --period 2y --capital 20000
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Vérification de la dépendance httpx (ou fallback vers requests)
|
||||
# ---------------------------------------------------------------------------
|
||||
try:
|
||||
import httpx
|
||||
_HTTP_BACKEND = "httpx"
|
||||
except ImportError:
|
||||
try:
|
||||
import requests as _requests_module
|
||||
_HTTP_BACKEND = "requests"
|
||||
except ImportError:
|
||||
print("ERREUR : Ni httpx ni requests ne sont installés. Exécutez : pip install httpx")
|
||||
sys.exit(1)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
API_BASE_URL = "http://localhost:8100"
|
||||
POLL_INTERVAL_SEC = 5 # Intervalle de polling (secondes)
|
||||
TIMEOUT_SEC = 300 # Timeout maximum par job (5 minutes)
|
||||
REPORTS_DIR = Path(__file__).parent.parent / "reports"
|
||||
|
||||
# Stratégies à comparer — dans l'ordre d'affichage.
|
||||
# NOTE : l'API (POST /trading/backtest) accepte : scalping | intraday | swing | ml_driven
|
||||
# La stratégie cnn_driven n'est pas encore intégrée dans le pipeline de backtest
|
||||
# (elle est disponible uniquement via paper trading). Elle sera incluse automatiquement
|
||||
# si l'API évolue pour la supporter.
|
||||
STRATEGIES = [
|
||||
"scalping",
|
||||
"ml_driven",
|
||||
"cnn_driven", # Retournera une erreur 400 si non supporté par l'API — géré gracieusement
|
||||
]
|
||||
|
||||
# Colonnes métriques à collecter et afficher
|
||||
METRICS_COLS = [
|
||||
("total_return", "Retour total", "{:.2%}"),
|
||||
("sharpe_ratio", "Sharpe ratio", "{:.3f}"),
|
||||
("max_drawdown", "Max drawdown", "{:.2%}"),
|
||||
("win_rate", "Win rate", "{:.2%}"),
|
||||
("total_trades", "Trades totaux", "{:d}"),
|
||||
("profit_factor", "Profit factor", "{:.2f}"),
|
||||
]
|
||||
|
||||
# Seuils de validation (identiques à ceux de l'API)
|
||||
SEUIL_SHARPE = 1.5
|
||||
SEUIL_DRAWDOWN_MAX = 0.10
|
||||
SEUIL_WIN_RATE = 0.55
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Couche HTTP (httpx ou requests selon disponibilité)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _post(url: str, payload: dict, timeout: int = 30) -> dict:
|
||||
"""Effectue un POST JSON et retourne la réponse parsée."""
|
||||
if _HTTP_BACKEND == "httpx":
|
||||
resp = httpx.post(url, json=payload, timeout=timeout)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
else:
|
||||
resp = _requests_module.post(url, json=payload, timeout=timeout)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
def _get(url: str, timeout: int = 15) -> dict:
|
||||
"""Effectue un GET et retourne la réponse parsée."""
|
||||
if _HTTP_BACKEND == "httpx":
|
||||
resp = httpx.get(url, timeout=timeout)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
else:
|
||||
resp = _requests_module.get(url, timeout=timeout)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fonctions backtest
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def lancer_backtest(strategy: str, symbol: str, period: str, capital: float) -> str:
|
||||
"""
|
||||
Lance un job de backtest via POST /trading/backtest.
|
||||
|
||||
Retourne le job_id ou lève une exception en cas d'erreur.
|
||||
"""
|
||||
payload = {
|
||||
"strategy": strategy,
|
||||
"symbol": symbol,
|
||||
"period": period,
|
||||
"initial_capital": capital,
|
||||
}
|
||||
url = f"{API_BASE_URL}/trading/backtest"
|
||||
data = _post(url, payload)
|
||||
return data["job_id"]
|
||||
|
||||
|
||||
def attendre_resultat(job_id: str, strategy: str) -> dict:
|
||||
"""
|
||||
Poll GET /trading/backtest/{job_id} toutes les POLL_INTERVAL_SEC secondes
|
||||
jusqu'à ce que le statut soit 'completed' ou 'failed'.
|
||||
|
||||
Retourne le dict de résultat ou lève une exception si le job échoue / dépasse le timeout.
|
||||
"""
|
||||
url = f"{API_BASE_URL}/trading/backtest/{job_id}"
|
||||
debut = time.time()
|
||||
|
||||
while True:
|
||||
elapsed = time.time() - debut
|
||||
if elapsed > TIMEOUT_SEC:
|
||||
raise TimeoutError(
|
||||
f"[{strategy}] Timeout atteint ({TIMEOUT_SEC}s). Job {job_id} toujours en cours."
|
||||
)
|
||||
|
||||
data = _get(url)
|
||||
statut = data.get("status", "?")
|
||||
|
||||
print(f" [{strategy}] Statut : {statut} ({elapsed:.0f}s)")
|
||||
|
||||
if statut == "completed":
|
||||
return data
|
||||
elif statut == "failed":
|
||||
erreur = data.get("error", "raison inconnue")
|
||||
raise RuntimeError(f"[{strategy}] Job échoué : {erreur}")
|
||||
|
||||
# Toujours en cours — on attend
|
||||
time.sleep(POLL_INTERVAL_SEC)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Affichage du tableau comparatif
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def afficher_tableau(resultats: list[dict]) -> None:
|
||||
"""
|
||||
Affiche un tableau comparatif dans le terminal.
|
||||
Tente d'utiliser 'tabulate' pour un meilleur rendu ; fallback vers print simple.
|
||||
"""
|
||||
# Construction des données tabulaires
|
||||
headers = ["Stratégie"] + [label for _, label, _ in METRICS_COLS] + ["Valide ?"]
|
||||
rows = []
|
||||
|
||||
for res in resultats:
|
||||
nom = res["strategy"]
|
||||
ligne = [nom]
|
||||
for key, _, fmt in METRICS_COLS:
|
||||
val = res.get(key)
|
||||
if val is None:
|
||||
ligne.append("N/A")
|
||||
else:
|
||||
try:
|
||||
if key == "total_trades":
|
||||
ligne.append(fmt.format(int(val)))
|
||||
else:
|
||||
ligne.append(fmt.format(float(val)))
|
||||
except (ValueError, TypeError):
|
||||
ligne.append(str(val))
|
||||
# Indicateur de validation
|
||||
valide = res.get("is_valid_for_paper")
|
||||
if valide is True:
|
||||
ligne.append("OUI")
|
||||
elif valide is False:
|
||||
ligne.append("NON")
|
||||
else:
|
||||
ligne.append("?")
|
||||
rows.append(ligne)
|
||||
|
||||
print()
|
||||
print("=" * 70)
|
||||
print("RAPPORT COMPARATIF DES STRATÉGIES")
|
||||
print("=" * 70)
|
||||
|
||||
try:
|
||||
from tabulate import tabulate
|
||||
print(tabulate(rows, headers=headers, tablefmt="rounded_outline"))
|
||||
except ImportError:
|
||||
# Fallback : affichage simple sans tabulate
|
||||
col_widths = [max(len(str(h)), max((len(str(r[i])) for r in rows), default=0))
|
||||
for i, h in enumerate(headers)]
|
||||
sep = " ".join("-" * w for w in col_widths)
|
||||
header_line = " ".join(str(h).ljust(w) for h, w in zip(headers, col_widths))
|
||||
print(header_line)
|
||||
print(sep)
|
||||
for row in rows:
|
||||
print(" ".join(str(c).ljust(w) for c, w in zip(row, col_widths)))
|
||||
|
||||
print()
|
||||
|
||||
# Identification de la meilleure stratégie (selon Sharpe ratio)
|
||||
valides = [r for r in resultats if r.get("sharpe_ratio") is not None]
|
||||
if valides:
|
||||
meilleure = max(valides, key=lambda r: r.get("sharpe_ratio") or -999)
|
||||
print(f"Meilleure stratégie (Sharpe) : {meilleure['strategy']} "
|
||||
f"(Sharpe={meilleure.get('sharpe_ratio', 'N/A'):.3f})")
|
||||
|
||||
# Rappel des seuils
|
||||
print()
|
||||
print(f"Seuils de validation : Sharpe >= {SEUIL_SHARPE} | "
|
||||
f"Max drawdown <= {SEUIL_DRAWDOWN_MAX:.0%} | "
|
||||
f"Win rate >= {SEUIL_WIN_RATE:.0%}")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sauvegarde du rapport JSON
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def sauvegarder_rapport(resultats: list[dict], symbol: str, period: str) -> Path:
|
||||
"""
|
||||
Sauvegarde le rapport de comparaison en JSON dans reports/.
|
||||
|
||||
Retourne le chemin du fichier créé.
|
||||
"""
|
||||
REPORTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
horodatage = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
nom_fichier = f"backtest_comparison_{horodatage}.json"
|
||||
chemin = REPORTS_DIR / nom_fichier
|
||||
|
||||
rapport = {
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
"parametres": {
|
||||
"symbol": symbol,
|
||||
"period": period,
|
||||
"api_base_url": API_BASE_URL,
|
||||
},
|
||||
"seuils_validation": {
|
||||
"sharpe_ratio_min": SEUIL_SHARPE,
|
||||
"max_drawdown_max": SEUIL_DRAWDOWN_MAX,
|
||||
"win_rate_min": SEUIL_WIN_RATE,
|
||||
},
|
||||
"resultats": resultats,
|
||||
}
|
||||
|
||||
with open(chemin, "w", encoding="utf-8") as f:
|
||||
json.dump(rapport, f, indent=2, ensure_ascii=False, default=str)
|
||||
|
||||
return chemin
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Point d'entrée principal
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""Parse les arguments en ligne de commande."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Compare les stratégies de trading via backtest API."
|
||||
)
|
||||
parser.add_argument("--symbol", default="EURUSD", help="Paire de trading (défaut: EURUSD)")
|
||||
parser.add_argument("--period", default="1y", help="Période historique : 6m | 1y | 2y (défaut: 1y)")
|
||||
parser.add_argument("--capital", default=10000.0, type=float, help="Capital initial (défaut: 10000)")
|
||||
parser.add_argument(
|
||||
"--strategies",
|
||||
nargs="+",
|
||||
default=STRATEGIES,
|
||||
help="Liste des stratégies à comparer (défaut: scalping ml_driven cnn_driven)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
print()
|
||||
print("=" * 70)
|
||||
print("BACKTEST COMPARATIF DES STRATÉGIES")
|
||||
print(f"Symbol: {args.symbol} | Période: {args.period} | Capital: {args.capital:,.0f} EUR")
|
||||
print(f"Stratégies: {', '.join(args.strategies)}")
|
||||
print(f"API: {API_BASE_URL}")
|
||||
print("=" * 70)
|
||||
|
||||
# Vérification de la disponibilité de l'API
|
||||
try:
|
||||
_get(f"{API_BASE_URL}/health")
|
||||
print("API disponible.")
|
||||
except Exception as e:
|
||||
print(f"ERREUR : Impossible de joindre l'API ({e})")
|
||||
print("Vérifiez que le container trading-api tourne sur le port 8100.")
|
||||
sys.exit(1)
|
||||
|
||||
print()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Phase 1 : lancement de tous les backtests
|
||||
# -----------------------------------------------------------------------
|
||||
jobs: dict[str, str] = {} # {strategy: job_id}
|
||||
stratégies_échouées: list[str] = []
|
||||
|
||||
for strat in args.strategies:
|
||||
print(f"Lancement backtest [{strat}]...")
|
||||
try:
|
||||
job_id = lancer_backtest(strat, args.symbol, args.period, args.capital)
|
||||
jobs[strat] = job_id
|
||||
print(f" -> Job ID : {job_id}")
|
||||
except Exception as e:
|
||||
print(f" ERREUR lancement [{strat}] : {e}")
|
||||
stratégies_échouées.append(strat)
|
||||
|
||||
if not jobs:
|
||||
print("Aucun job lancé. Arrêt.")
|
||||
sys.exit(1)
|
||||
|
||||
print()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Phase 2 : attente et collecte des résultats (séquentielle)
|
||||
# -----------------------------------------------------------------------
|
||||
resultats: list[dict] = []
|
||||
|
||||
for strat, job_id in jobs.items():
|
||||
print(f"Attente résultat [{strat}] (job={job_id})...")
|
||||
try:
|
||||
res = attendre_resultat(job_id, strat)
|
||||
resultats.append(res)
|
||||
print(f" -> Terminé. Sharpe={res.get('sharpe_ratio', 'N/A')}, "
|
||||
f"Retour={res.get('total_return', 'N/A')}")
|
||||
except Exception as e:
|
||||
print(f" ERREUR [{strat}] : {e}")
|
||||
# On ajoute quand même un résultat partiel pour la lisibilité du rapport
|
||||
resultats.append({
|
||||
"strategy": strat,
|
||||
"symbol": args.symbol,
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
})
|
||||
|
||||
print()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Phase 3 : affichage et sauvegarde
|
||||
# -----------------------------------------------------------------------
|
||||
afficher_tableau(resultats)
|
||||
|
||||
chemin_rapport = sauvegarder_rapport(resultats, args.symbol, args.period)
|
||||
print(f"Rapport JSON sauvegardé : {chemin_rapport}")
|
||||
print()
|
||||
|
||||
# Résumé des échecs éventuels
|
||||
if stratégies_échouées:
|
||||
print(f"Stratégies non lancées (erreur API) : {', '.join(stratégies_échouées)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
327
scripts/quick_benchmark.py
Executable file
327
scripts/quick_benchmark.py
Executable file
@@ -0,0 +1,327 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
quick_benchmark.py — Benchmark rapide des modèles ML entraînés.
|
||||
|
||||
Lit les fichiers *_meta.json dans les répertoires de modèles :
|
||||
- models/ml_strategy/ (XGBoost / LightGBM / RandomForest)
|
||||
- models/cnn_strategy/ (CNN 1D — CandlestickEncoder)
|
||||
- models/cnn_image_strategy/ (CNN Image — chandeliers en image)
|
||||
|
||||
Affiche un tableau comparatif : wf_accuracy, wf_precision, n_samples, etc.
|
||||
Indique quel modèle obtient la meilleure précision walk-forward.
|
||||
|
||||
Usage :
|
||||
python scripts/quick_benchmark.py
|
||||
python scripts/quick_benchmark.py --models-root /chemin/vers/models
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Répertoires de modèles (relatifs à la racine du projet)
|
||||
# ---------------------------------------------------------------------------
|
||||
PROJECT_ROOT = Path(__file__).parent.parent
|
||||
|
||||
# Mapping : nom affiché -> chemin relatif au projet
|
||||
MODEL_DIRS = {
|
||||
"ML-Strategy (XGBoost/LightGBM)": PROJECT_ROOT / "models" / "ml_strategy",
|
||||
"CNN-Strategy (1D)": PROJECT_ROOT / "models" / "cnn_strategy",
|
||||
"CNN-Image-Strategy": PROJECT_ROOT / "models" / "cnn_image_strategy",
|
||||
}
|
||||
|
||||
# Colonnes de métriques affichées dans le tableau
|
||||
COLONNES = [
|
||||
("type", "Type", "{:<28}"),
|
||||
("symbol", "Symbol", "{:<8}"),
|
||||
("timeframe", "TF", "{:<5}"),
|
||||
("model_type", "Modèle", "{:<12}"),
|
||||
("n_samples", "Échantillons", "{:>12}"),
|
||||
("wf_accuracy", "WF Accuracy", "{:>12}"),
|
||||
("wf_precision", "WF Precision", "{:>13}"),
|
||||
("trained_at", "Entraîné le", "{:<20}"),
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lecture des méta-fichiers JSON
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def charger_modeles(model_dirs: dict[str, Path]) -> list[dict]:
|
||||
"""
|
||||
Parcourt les répertoires de modèles et charge les métadonnées JSON.
|
||||
|
||||
Retourne une liste de dicts prêts pour l'affichage.
|
||||
"""
|
||||
tous = []
|
||||
|
||||
for type_label, dossier in model_dirs.items():
|
||||
if not dossier.exists():
|
||||
# Répertoire absent — pas encore de modèles entraînés pour ce type
|
||||
continue
|
||||
|
||||
fichiers_meta = sorted(dossier.glob("*_meta.json"))
|
||||
if not fichiers_meta:
|
||||
continue
|
||||
|
||||
for f in fichiers_meta:
|
||||
try:
|
||||
with open(f, encoding="utf-8") as fp:
|
||||
meta = json.load(fp)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
print(f"[WARN] Impossible de lire {f} : {e}", file=sys.stderr)
|
||||
continue
|
||||
|
||||
wf = meta.get("wf_metrics", {})
|
||||
|
||||
# Extraction des champs — avec valeurs par défaut
|
||||
# Le nom du fichier (ex : EURUSD_1h_xgboost_meta.json) est utilisé
|
||||
# comme fallback si les champs ne sont pas dans le JSON.
|
||||
stem_parts = f.stem.replace("_meta", "").split("_")
|
||||
symbol = meta.get("symbol", stem_parts[0] if len(stem_parts) > 0 else "?")
|
||||
timeframe = meta.get("timeframe", stem_parts[1] if len(stem_parts) > 1 else "?")
|
||||
model_type = meta.get(
|
||||
"model_type",
|
||||
stem_parts[2] if len(stem_parts) > 2 else dossier.name
|
||||
)
|
||||
|
||||
wf_accuracy = wf.get("avg_accuracy", None)
|
||||
wf_precision = wf.get("avg_precision", None)
|
||||
|
||||
trained_at = meta.get("trained_at", "?")
|
||||
# Tronque la date ISO à 19 caractères pour la lisibilité
|
||||
if isinstance(trained_at, str) and len(trained_at) > 19:
|
||||
trained_at = trained_at[:19]
|
||||
|
||||
tous.append({
|
||||
"type": type_label,
|
||||
"symbol": symbol,
|
||||
"timeframe": timeframe,
|
||||
"model_type": model_type,
|
||||
"n_samples": meta.get("n_samples", 0),
|
||||
"wf_accuracy": wf_accuracy,
|
||||
"wf_precision": wf_precision,
|
||||
"trained_at": trained_at,
|
||||
"_meta_path": str(f), # pour le rapport détaillé éventuel
|
||||
"_raw_meta": meta,
|
||||
})
|
||||
|
||||
return tous
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Affichage du tableau
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def formater_val(val, fmt: str) -> str:
|
||||
"""Formate une valeur numérique ou retourne 'N/A'."""
|
||||
if val is None:
|
||||
return "N/A"
|
||||
try:
|
||||
if "%" in fmt or "f" in fmt:
|
||||
return f"{float(val):.2%}" if "%" in fmt else fmt.format(float(val))
|
||||
return fmt.format(val)
|
||||
except (ValueError, TypeError):
|
||||
return str(val)
|
||||
|
||||
|
||||
def afficher_tableau(modeles: list[dict]) -> None:
|
||||
"""Affiche le tableau comparatif des modèles dans le terminal."""
|
||||
if not modeles:
|
||||
print("Aucun modèle entraîné trouvé.")
|
||||
print("Entraînez d'abord un modèle via POST /trading/train (ou /train-cnn).")
|
||||
return
|
||||
|
||||
# Entêtes
|
||||
headers = [label for _, label, _ in COLONNES]
|
||||
col_fmts = [fmt for _, _, fmt in COLONNES]
|
||||
|
||||
# Construction des lignes
|
||||
rows = []
|
||||
for m in modeles:
|
||||
ligne = []
|
||||
for key, _, fmt in COLONNES:
|
||||
val = m.get(key)
|
||||
if val is None:
|
||||
ligne.append("N/A")
|
||||
elif key in ("wf_accuracy", "wf_precision"):
|
||||
# Affichage en pourcentage
|
||||
try:
|
||||
ligne.append(f"{float(val):.2%}")
|
||||
except (ValueError, TypeError):
|
||||
ligne.append("N/A")
|
||||
elif key == "n_samples":
|
||||
try:
|
||||
ligne.append(f"{int(val):,}")
|
||||
except (ValueError, TypeError):
|
||||
ligne.append(str(val))
|
||||
else:
|
||||
ligne.append(str(val))
|
||||
rows.append(ligne)
|
||||
|
||||
print()
|
||||
print("=" * 80)
|
||||
print("BENCHMARK DES MODÈLES ML ENTRAÎNÉS")
|
||||
print("=" * 80)
|
||||
|
||||
try:
|
||||
from tabulate import tabulate
|
||||
print(tabulate(rows, headers=headers, tablefmt="rounded_outline"))
|
||||
except ImportError:
|
||||
# Fallback : affichage simple
|
||||
col_widths = [
|
||||
max(len(str(h)), max((len(str(r[i])) for r in rows), default=0))
|
||||
for i, h in enumerate(headers)
|
||||
]
|
||||
sep = " ".join("-" * w for w in col_widths)
|
||||
header_line = " ".join(str(h).ljust(w) for h, w in zip(headers, col_widths))
|
||||
print(header_line)
|
||||
print(sep)
|
||||
for row in rows:
|
||||
print(" ".join(str(c).ljust(w) for c, w in zip(row, col_widths)))
|
||||
|
||||
print()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Identification du meilleur modèle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def identifier_meilleur(modeles: list[dict]) -> None:
|
||||
"""Indique quel modèle est le meilleur selon wf_accuracy et wf_precision."""
|
||||
candidats = [m for m in modeles if m.get("wf_accuracy") is not None]
|
||||
if not candidats:
|
||||
print("Impossible de déterminer le meilleur modèle (aucune métrique wf_accuracy).")
|
||||
return
|
||||
|
||||
# Critère principal : wf_accuracy ; critère secondaire : wf_precision
|
||||
meilleur = max(
|
||||
candidats,
|
||||
key=lambda m: (
|
||||
m.get("wf_accuracy") or 0.0,
|
||||
m.get("wf_precision") or 0.0,
|
||||
),
|
||||
)
|
||||
|
||||
print(f"Meilleur modèle (WF Accuracy) :")
|
||||
print(f" Type : {meilleur['type']}")
|
||||
print(f" Symbol : {meilleur['symbol']} / {meilleur['timeframe']}")
|
||||
print(f" Modèle : {meilleur['model_type']}")
|
||||
|
||||
acc = meilleur.get("wf_accuracy")
|
||||
prec = meilleur.get("wf_precision")
|
||||
print(f" WF Accuracy : {acc:.2%}" if acc is not None else " WF Accuracy : N/A")
|
||||
print(f" WF Precision : {prec:.2%}" if prec is not None else " WF Precision : N/A")
|
||||
print(f" Entraîné le : {meilleur.get('trained_at', '?')}")
|
||||
print(f" Fichier meta : {meilleur['_meta_path']}")
|
||||
|
||||
# Avertissement si la précision est insuffisante pour le trading
|
||||
if acc is not None and acc < 0.40:
|
||||
print()
|
||||
print("[AVIS] WF Accuracy < 40% — ce modèle est insuffisant pour trader seul.")
|
||||
print(" Envisagez un re-entraînement avec plus de données ou de features.")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Statistiques globales par type de modèle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def afficher_stats_globales(modeles: list[dict]) -> None:
|
||||
"""Affiche des statistiques agrégées par type de modèle."""
|
||||
if not modeles:
|
||||
return
|
||||
|
||||
# Regroupement par type
|
||||
par_type: dict[str, list] = {}
|
||||
for m in modeles:
|
||||
par_type.setdefault(m["type"], []).append(m)
|
||||
|
||||
print("Résumé par type de modèle :")
|
||||
print("-" * 50)
|
||||
for type_label, groupe in par_type.items():
|
||||
accs = [m["wf_accuracy"] for m in groupe if m.get("wf_accuracy") is not None]
|
||||
precs = [m["wf_precision"] for m in groupe if m.get("wf_precision") is not None]
|
||||
n_tot = sum(m.get("n_samples", 0) for m in groupe)
|
||||
|
||||
acc_moy = sum(accs) / len(accs) if accs else None
|
||||
prec_moy = sum(precs) / len(precs) if precs else None
|
||||
|
||||
acc_str = f"{acc_moy:.2%}" if acc_moy is not None else "N/A"
|
||||
prec_str = f"{prec_moy:.2%}" if prec_moy is not None else "N/A"
|
||||
|
||||
print(f" {type_label}")
|
||||
print(f" Modèles entraînés : {len(groupe)}")
|
||||
print(f" WF Accuracy moy. : {acc_str}")
|
||||
print(f" WF Precision moy. : {prec_str}")
|
||||
print(f" Échantillons tot. : {n_tot:,}")
|
||||
print()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Point d'entrée principal
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""Parse les arguments en ligne de commande."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark rapide des modèles ML entraînés (lecture des méta-fichiers JSON)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models-root",
|
||||
type=Path,
|
||||
default=PROJECT_ROOT / "models",
|
||||
help=f"Répertoire racine des modèles (défaut: {PROJECT_ROOT / 'models'})",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
# Mise à jour des chemins si --models-root est spécifié
|
||||
dirs_effectifs = {
|
||||
label: args.models_root / dossier.name
|
||||
for label, dossier in MODEL_DIRS.items()
|
||||
}
|
||||
|
||||
print()
|
||||
print("Lecture des modèles dans :")
|
||||
for label, chemin in dirs_effectifs.items():
|
||||
existe = "OK" if chemin.exists() else "absent"
|
||||
print(f" [{existe}] {chemin}")
|
||||
print()
|
||||
|
||||
# Chargement de tous les modèles
|
||||
modeles = charger_modeles(dirs_effectifs)
|
||||
|
||||
if not modeles:
|
||||
print("Aucun modèle trouvé dans les répertoires indiqués.")
|
||||
print()
|
||||
print("Pour entraîner un modèle XGBoost/LightGBM :")
|
||||
print(" POST http://localhost:8100/trading/train")
|
||||
print(" {\"symbol\": \"EURUSD\", \"timeframe\": \"1h\", \"model_type\": \"xgboost\"}")
|
||||
print()
|
||||
print("Pour entraîner un modèle CNN 1D :")
|
||||
print(" POST http://localhost:8100/trading/train-cnn")
|
||||
print(" {\"symbol\": \"EURUSD\", \"timeframe\": \"1h\"}")
|
||||
sys.exit(0)
|
||||
|
||||
# Affichage du tableau
|
||||
afficher_tableau(modeles)
|
||||
|
||||
# Statistiques par type
|
||||
afficher_stats_globales(modeles)
|
||||
|
||||
# Meilleur modèle
|
||||
identifier_meilleur(modeles)
|
||||
|
||||
print(f"Total : {len(modeles)} modèle(s) trouvé(s).")
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -262,17 +262,46 @@ def get_ml_status(symbol: str = "EURUSD"):
|
||||
return cached["result"]
|
||||
|
||||
try:
|
||||
from src.ml.regime_detector import RegimeDetector
|
||||
|
||||
config = ConfigLoader.load_all()
|
||||
data_service = DataService(config)
|
||||
|
||||
timeframe = "1h"
|
||||
now = datetime.now()
|
||||
start = now - timedelta(days=30)
|
||||
|
||||
# Récupérer données synchrones via asyncio.run
|
||||
# ------------------------------------------------------------------
|
||||
# Tentative de chargement du modèle HMM persisté (évite le re-train)
|
||||
# ------------------------------------------------------------------
|
||||
detector = RegimeDetector(n_regimes=4)
|
||||
hmm_loaded = detector.load(symbol, timeframe)
|
||||
|
||||
if hmm_loaded and not detector.needs_retrain(max_age_hours=24):
|
||||
# Modèle récent disponible sur disque — pas besoin de ré-entraîner
|
||||
logger.info(
|
||||
f"Modèle HMM chargé depuis le disque pour {symbol}/{timeframe} "
|
||||
f"(entraîné le {detector._trained_at})"
|
||||
)
|
||||
need_fit = False
|
||||
else:
|
||||
if hmm_loaded:
|
||||
logger.info(
|
||||
f"Modèle HMM trop ancien pour {symbol}/{timeframe} — ré-entraînement nécessaire"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Aucun modèle HMM persisté pour {symbol}/{timeframe} — entraînement initial"
|
||||
)
|
||||
need_fit = True
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Récupération des données (toujours nécessaire pour la prédiction)
|
||||
# ------------------------------------------------------------------
|
||||
df = asyncio.run(
|
||||
data_service.get_historical_data(
|
||||
symbol=symbol,
|
||||
timeframe="1h",
|
||||
timeframe=timeframe,
|
||||
start_date=start,
|
||||
end_date=now,
|
||||
)
|
||||
@@ -291,11 +320,23 @@ def get_ml_status(symbol: str = "EURUSD"):
|
||||
|
||||
df.columns = [c.lower() for c in df.columns]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Entraînement si nécessaire puis sauvegarde du modèle
|
||||
# ------------------------------------------------------------------
|
||||
ml = MLEngine(config=config.get("ml", {}))
|
||||
ml.initialize(df)
|
||||
|
||||
if need_fit:
|
||||
# Entraîner le détecteur et l'injecter dans MLEngine
|
||||
detector.fit(df)
|
||||
# Sauvegarder le nouveau modèle sur disque pour les prochains appels
|
||||
detector.save(symbol, timeframe)
|
||||
|
||||
# Injecter le détecteur (chargé ou fraîchement entraîné) dans MLEngine
|
||||
ml.regime_detector = detector
|
||||
ml.current_regime = detector.predict_current_regime(df)
|
||||
|
||||
regime_info = ml.get_regime_info()
|
||||
regime_stats = ml.regime_detector.get_regime_statistics(df)
|
||||
regime_stats = detector.get_regime_statistics(df)
|
||||
|
||||
strategy_advice = {
|
||||
s: ml.should_trade(s)
|
||||
@@ -1383,7 +1424,9 @@ async def train_cnn_image(request: CNNImageTrainRequest, background_tasks: Backg
|
||||
"timeframe": request.timeframe,
|
||||
}
|
||||
|
||||
background_tasks.add_task(_run_cnn_image_train_task, job_id, request)
|
||||
# asyncio.create_task() plutôt que background_tasks pour permettre
|
||||
# les hot-reloads WatchFiles sans bloquer à l'arrêt du serveur
|
||||
asyncio.create_task(_run_cnn_image_train_task(job_id, request))
|
||||
|
||||
return CNNImageTrainResponse(
|
||||
job_id = job_id,
|
||||
@@ -1418,3 +1461,388 @@ def list_cnn_image_models():
|
||||
from src.ml.cnn_image import CNNImageStrategyModel
|
||||
models = CNNImageStrategyModel.list_trained_models()
|
||||
return {"models": models, "count": len(models)}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# RL STRATEGY — Entraînement et gestion des modèles RL (PPO)
|
||||
# =============================================================================
|
||||
|
||||
try:
|
||||
from src.ml.rl import RLStrategyModel, RL_AVAILABLE as _RL_AVAILABLE
|
||||
RL_AVAILABLE = _RL_AVAILABLE
|
||||
except ImportError:
|
||||
RL_AVAILABLE = False
|
||||
|
||||
# Stockage en mémoire des jobs d'entraînement RL
|
||||
_rl_train_jobs: Dict[str, dict] = {}
|
||||
|
||||
|
||||
class RLTrainRequest(BaseModel):
|
||||
"""Requête d'entraînement du modèle RL (PPO)."""
|
||||
symbol: str = "EURUSD"
|
||||
timeframe: str = "1h"
|
||||
period: str = "2y" # Période de données historiques
|
||||
total_timesteps: int = 100_000 # Nombre de timesteps PPO
|
||||
sl_atr_mult: float = 1.0 # Multiplicateur ATR pour SL
|
||||
tp_atr_mult: float = 2.0 # Multiplicateur ATR pour TP
|
||||
min_confidence: float = 0.50 # Seuil confiance minimum
|
||||
train_ratio: float = 0.80 # Fraction données pour l'entraînement
|
||||
initial_capital: float = 10_000.0 # Capital initial pour la simulation
|
||||
|
||||
|
||||
class RLTrainResponse(BaseModel):
|
||||
"""Réponse d'un job d'entraînement RL."""
|
||||
job_id: str
|
||||
status: str # pending | running | completed | failed
|
||||
symbol: Optional[str] = None
|
||||
timeframe: Optional[str] = None
|
||||
n_samples: Optional[int] = None
|
||||
total_timesteps: Optional[int] = None
|
||||
n_episodes: Optional[int] = None
|
||||
mean_ep_return: Optional[float] = None # Récompense moyenne par épisode
|
||||
eval_return_pct: Optional[float] = None # Rendement sur le holdout
|
||||
eval_sharpe: Optional[float] = None # Sharpe approx. sur le holdout
|
||||
eval_win_rate: Optional[float] = None # Win rate sur le holdout
|
||||
trained_at: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
async def _run_rl_train_task(job_id: str, request: RLTrainRequest) -> None:
|
||||
"""Tâche d'entraînement RL exécutée en arrière-plan."""
|
||||
_rl_train_jobs[job_id]["status"] = "running"
|
||||
try:
|
||||
from src.data.data_service import DataService
|
||||
from src.utils.config_loader import ConfigLoader
|
||||
from datetime import timedelta
|
||||
|
||||
config = ConfigLoader.load_all()
|
||||
data_service = DataService(config)
|
||||
|
||||
end_date = datetime.now()
|
||||
period_map = {'y': 365, 'm': 30, 'd': 1}
|
||||
unit = request.period[-1]
|
||||
value = int(request.period[:-1])
|
||||
start_date = end_date - timedelta(days=value * period_map.get(unit, 1))
|
||||
|
||||
df = await data_service.get_historical_data(
|
||||
symbol = request.symbol,
|
||||
timeframe = request.timeframe,
|
||||
start_date = start_date,
|
||||
end_date = end_date,
|
||||
)
|
||||
if df is None or len(df) < 200:
|
||||
raise ValueError(
|
||||
f"Données insuffisantes : {len(df) if df is not None else 0} barres (min 200)"
|
||||
)
|
||||
|
||||
# Entraînement dans un thread (opération CPU-bound / GPU-bound)
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(None, _sync_rl_train, df, request)
|
||||
|
||||
_rl_train_jobs[job_id].update({
|
||||
"status": "completed",
|
||||
"symbol": request.symbol,
|
||||
"timeframe": request.timeframe,
|
||||
"n_samples": result.get("n_samples"),
|
||||
"total_timesteps": result.get("total_timesteps"),
|
||||
"n_episodes": result.get("n_episodes"),
|
||||
"mean_ep_return": result.get("mean_ep_return"),
|
||||
"eval_return_pct": result.get("eval_return_pct"),
|
||||
"eval_sharpe": result.get("eval_sharpe"),
|
||||
"eval_win_rate": result.get("eval_win_rate"),
|
||||
"trained_at": result.get("trained_at"),
|
||||
})
|
||||
|
||||
# Auto-attachement à la stratégie RL active si elle existe
|
||||
_attach_rl_model_to_strategy(request)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Erreur entraînement RL job {job_id} : {exc}", exc_info=True)
|
||||
_rl_train_jobs[job_id]["status"] = "failed"
|
||||
_rl_train_jobs[job_id]["error"] = str(exc)
|
||||
|
||||
|
||||
def _sync_rl_train(df, request: RLTrainRequest) -> dict:
|
||||
"""Wrapper synchrone pour RLStrategyModel.train() (exécuté dans un thread)."""
|
||||
from src.ml.rl import RLStrategyModel
|
||||
model = RLStrategyModel(
|
||||
symbol = request.symbol,
|
||||
timeframe = request.timeframe,
|
||||
total_timesteps = request.total_timesteps,
|
||||
sl_atr_mult = request.sl_atr_mult,
|
||||
tp_atr_mult = request.tp_atr_mult,
|
||||
min_confidence = request.min_confidence,
|
||||
train_ratio = request.train_ratio,
|
||||
initial_capital = request.initial_capital,
|
||||
)
|
||||
return model.train(df)
|
||||
|
||||
|
||||
def _attach_rl_model_to_strategy(request: RLTrainRequest) -> None:
|
||||
"""Attache le modèle RL entraîné à la stratégie rl_driven active (paper trading)."""
|
||||
try:
|
||||
from src.ml.rl import RLStrategyModel
|
||||
from src.strategies.rl_driven import RLDrivenStrategy
|
||||
|
||||
engine = _paper_state.get("engine")
|
||||
if engine and hasattr(engine, 'strategy_engine'):
|
||||
strat = engine.strategy_engine.strategies.get('rl_driven')
|
||||
if strat and isinstance(strat, RLDrivenStrategy):
|
||||
model = RLStrategyModel.load(request.symbol, request.timeframe)
|
||||
strat.attach_model(model)
|
||||
logger.info("Modèle RL attaché à la stratégie rl_driven active")
|
||||
except Exception as e:
|
||||
logger.debug(f"Auto-attach modèle RL ignoré : {e}")
|
||||
|
||||
|
||||
@router.post("/train-rl", response_model=RLTrainResponse,
|
||||
summary="Entraîner le modèle RL (PPO)")
|
||||
async def train_rl_model(request: RLTrainRequest, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
Lance l'entraînement de l'agent PPO en arrière-plan.
|
||||
|
||||
L'agent RL (Proximal Policy Optimization) apprend à trader par interaction
|
||||
directe avec un environnement de trading simulé — sans labels supervisés.
|
||||
La récompense est définie par le PnL réalisé et les pénalités de drawdown.
|
||||
|
||||
- Retourne un `job_id` à interroger via `GET /trading/train-rl/{job_id}`
|
||||
- Le modèle est sauvegardé sur disque après entraînement
|
||||
- Si un paper trading rl_driven est actif, le modèle lui est automatiquement attaché
|
||||
"""
|
||||
if not RL_AVAILABLE:
|
||||
raise HTTPException(
|
||||
503,
|
||||
detail="PyTorch requis — rebuilder le container trading-api"
|
||||
)
|
||||
|
||||
job_id = str(uuid.uuid4())
|
||||
_rl_train_jobs[job_id] = {
|
||||
"status": "pending",
|
||||
"symbol": request.symbol,
|
||||
"timeframe": request.timeframe,
|
||||
}
|
||||
|
||||
background_tasks.add_task(_run_rl_train_task, job_id, request)
|
||||
|
||||
return RLTrainResponse(
|
||||
job_id = job_id,
|
||||
status = "pending",
|
||||
symbol = request.symbol,
|
||||
timeframe = request.timeframe,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/train-rl/{job_id}", response_model=RLTrainResponse,
|
||||
summary="Résultat entraînement RL")
|
||||
def get_rl_train_status(job_id: str):
|
||||
"""Retourne l'état d'un job d'entraînement RL (PPO)."""
|
||||
job = _rl_train_jobs.get(job_id)
|
||||
if job is None:
|
||||
raise HTTPException(404, detail=f"Job {job_id} introuvable")
|
||||
return RLTrainResponse(job_id=job_id, **job)
|
||||
|
||||
|
||||
@router.get("/rl-models", summary="Liste des modèles RL entraînés")
|
||||
def list_rl_models():
|
||||
"""
|
||||
Retourne la liste de tous les modèles RL (PPO) disponibles sur disque,
|
||||
avec leurs métriques (return holdout, Sharpe, date d'entraînement...).
|
||||
"""
|
||||
if not RL_AVAILABLE:
|
||||
return {
|
||||
"error": "PyTorch requis — rebuilder le container trading-api",
|
||||
"models": [],
|
||||
"count": 0,
|
||||
}
|
||||
try:
|
||||
from src.ml.rl import RLStrategyModel
|
||||
models = RLStrategyModel.list_trained_models()
|
||||
return {"models": models, "count": len(models)}
|
||||
except Exception as e:
|
||||
return {"error": str(e), "models": [], "count": 0}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# RL STRATEGY — Entraînement et gestion des modèles PPO (Reinforcement Learning)
|
||||
# =============================================================================
|
||||
|
||||
try:
|
||||
from src.ml.rl.rl_strategy_model import RLStrategyModel
|
||||
RL_AVAILABLE = True
|
||||
except ImportError:
|
||||
RLStrategyModel = None
|
||||
RL_AVAILABLE = False
|
||||
|
||||
# Stockage en mémoire des jobs d'entraînement RL
|
||||
_rl_train_jobs: Dict[str, dict] = {}
|
||||
|
||||
|
||||
class RLTrainRequest(BaseModel):
|
||||
"""Requête d'entraînement du modèle RL (agent PPO)."""
|
||||
symbol: str = "EURUSD"
|
||||
timeframe: str = "1h"
|
||||
period: str = "2y"
|
||||
total_timesteps: int = 50000
|
||||
|
||||
|
||||
class RLTrainResponse(BaseModel):
|
||||
"""Réponse d'un job d'entraînement RL."""
|
||||
job_id: str
|
||||
status: str
|
||||
symbol: str
|
||||
timeframe: str
|
||||
avg_reward: Optional[float] = None
|
||||
sharpe_env: Optional[float] = None
|
||||
total_timesteps: Optional[int] = None
|
||||
trained_at: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
async def _run_rl_train_task(job_id: str, request: RLTrainRequest) -> None:
|
||||
"""Tâche d'entraînement RL exécutée en arrière-plan."""
|
||||
_rl_train_jobs[job_id]["status"] = "running"
|
||||
try:
|
||||
from src.data.data_service import DataService
|
||||
from src.utils.config_loader import ConfigLoader
|
||||
from datetime import timedelta
|
||||
|
||||
config = ConfigLoader.load_all()
|
||||
data_service = DataService(config)
|
||||
|
||||
end_date = datetime.now()
|
||||
period_map = {'y': 365, 'm': 30, 'd': 1}
|
||||
unit = request.period[-1]
|
||||
value = int(request.period[:-1])
|
||||
start_date = end_date - timedelta(days=value * period_map.get(unit, 1))
|
||||
|
||||
df = await data_service.get_historical_data(
|
||||
symbol = request.symbol,
|
||||
timeframe = request.timeframe,
|
||||
start_date = start_date,
|
||||
end_date = end_date,
|
||||
)
|
||||
if df is None or len(df) < 200:
|
||||
raise ValueError(
|
||||
f"Données insuffisantes : {len(df) if df is not None else 0} barres (min 200)"
|
||||
)
|
||||
|
||||
# Entraînement dans un thread (opération CPU-bound)
|
||||
import torch
|
||||
torch.set_num_threads(4)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(None, _sync_rl_train, df, request)
|
||||
|
||||
_rl_train_jobs[job_id].update({
|
||||
"status": "completed",
|
||||
"symbol": request.symbol,
|
||||
"timeframe": request.timeframe,
|
||||
"avg_reward": result.get("avg_reward"),
|
||||
"sharpe_env": result.get("sharpe_env"),
|
||||
"total_timesteps": result.get("total_timesteps"),
|
||||
"trained_at": result.get("trained_at"),
|
||||
})
|
||||
|
||||
# Auto-attachement à la stratégie rl_driven active si elle existe
|
||||
_attach_rl_model_to_strategy(request)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Erreur entraînement RL job {job_id} : {exc}", exc_info=True)
|
||||
_rl_train_jobs[job_id]["status"] = "failed"
|
||||
_rl_train_jobs[job_id]["error"] = str(exc)
|
||||
|
||||
|
||||
def _sync_rl_train(df, request: RLTrainRequest) -> dict:
|
||||
"""Wrapper synchrone pour RLStrategyModel.train() (exécuté dans un thread)."""
|
||||
import torch
|
||||
torch.set_num_threads(4)
|
||||
|
||||
from src.ml.rl.rl_strategy_model import RLStrategyModel
|
||||
model = RLStrategyModel(
|
||||
symbol = request.symbol,
|
||||
timeframe = request.timeframe,
|
||||
total_timesteps = request.total_timesteps,
|
||||
)
|
||||
return model.train(df)
|
||||
|
||||
|
||||
def _attach_rl_model_to_strategy(request: RLTrainRequest) -> None:
|
||||
"""Attache le modèle RL entraîné à la stratégie rl_driven active (paper trading)."""
|
||||
try:
|
||||
from src.ml.rl.rl_strategy_model import RLStrategyModel
|
||||
from src.strategies.rl_driven import RLDrivenStrategy
|
||||
|
||||
engine = _paper_state.get("engine")
|
||||
if engine and hasattr(engine, 'strategy_engine'):
|
||||
strat = engine.strategy_engine.strategies.get('rl_driven')
|
||||
if strat and isinstance(strat, RLDrivenStrategy):
|
||||
model = RLStrategyModel.load(request.symbol, request.timeframe)
|
||||
strat.attach_model(model)
|
||||
logger.info("Modèle RL attaché à la stratégie rl_driven active")
|
||||
except Exception as e:
|
||||
logger.debug(f"Auto-attach modèle RL ignoré : {e}")
|
||||
|
||||
|
||||
@router.post("/train-rl", response_model=RLTrainResponse,
|
||||
summary="Entraîner le modèle RL (agent PPO)")
|
||||
async def train_rl(request: RLTrainRequest, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
Lance l'entraînement de l'agent PPO en arrière-plan.
|
||||
|
||||
L'agent Reinforcement Learning apprend à maximiser le profit cumulatif en
|
||||
interagissant avec un environnement de trading simulé (TradingEnv) — sans
|
||||
labels supervisés. La politique Actor-Critic PPO est optimisée sur
|
||||
`total_timesteps` pas de simulation.
|
||||
|
||||
- Retourne un `job_id` à interroger via `GET /trading/train-rl/{job_id}`
|
||||
- Le modèle est sauvegardé sur disque après entraînement
|
||||
- Si un paper trading rl_driven est actif, le modèle lui est automatiquement attaché
|
||||
"""
|
||||
if not RL_AVAILABLE:
|
||||
raise HTTPException(
|
||||
503,
|
||||
detail="PyTorch / stable-baselines3 requis — rebuilder le container trading-api"
|
||||
)
|
||||
|
||||
job_id = str(uuid.uuid4())
|
||||
_rl_train_jobs[job_id] = {
|
||||
"status": "pending",
|
||||
"symbol": request.symbol,
|
||||
"timeframe": request.timeframe,
|
||||
}
|
||||
|
||||
background_tasks.add_task(_run_rl_train_task, job_id, request)
|
||||
|
||||
return RLTrainResponse(
|
||||
job_id = job_id,
|
||||
status = "pending",
|
||||
symbol = request.symbol,
|
||||
timeframe = request.timeframe,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/train-rl/{job_id}", response_model=RLTrainResponse,
|
||||
summary="Résultat entraînement RL")
|
||||
def get_rl_train_status(job_id: str):
|
||||
"""Retourne l'état d'un job d'entraînement RL (PPO)."""
|
||||
job = _rl_train_jobs.get(job_id)
|
||||
if job is None:
|
||||
raise HTTPException(404, detail=f"Job {job_id} introuvable")
|
||||
return RLTrainResponse(job_id=job_id, **job)
|
||||
|
||||
|
||||
@router.get("/rl-models", summary="Liste des modèles RL entraînés")
|
||||
def list_rl_models():
|
||||
"""
|
||||
Retourne la liste de tous les modèles RL disponibles sur disque,
|
||||
avec leurs métriques (avg_reward, sharpe_env, date d'entraînement...).
|
||||
"""
|
||||
if not RL_AVAILABLE:
|
||||
return {
|
||||
"error": "PyTorch / stable-baselines3 requis — rebuilder le container trading-api",
|
||||
"models": [],
|
||||
"count": 0,
|
||||
}
|
||||
from src.ml.rl.rl_strategy_model import RLStrategyModel
|
||||
models = RLStrategyModel.list_trained_models()
|
||||
return {"models": models, "count": len(models)}
|
||||
|
||||
@@ -4,18 +4,17 @@ CandlestickImageRenderer — Convertit des données OHLCV en images de graphique
|
||||
Ce module transforme des séquences de bougies OHLCV en images 128×128 RGB
|
||||
qui peuvent être passées à un CNN Conv2D pour l'analyse visuelle des patterns.
|
||||
|
||||
Rendu :
|
||||
Rendu (pur numpy, très rapide) :
|
||||
- Fond noir (#0d1117), style TradingView
|
||||
- Bougies vertes (#26a69a) pour la hausse, rouges (#ef5350) pour la baisse
|
||||
- Volume en bas de l'image (via mplfinance)
|
||||
- Mèches (high/low), corps (open/close), volume en bas (20% de hauteur)
|
||||
- Pas d'axes, pas de labels, pas de titre
|
||||
- Taille fixe : 128×128 pixels, 3 canaux RGB
|
||||
|
||||
Si mplfinance ou PIL ne sont pas disponibles, un rendu de fallback basique
|
||||
encode les données OHLCV numériquement sous forme d'image 2D.
|
||||
Perf : ~5 000 images/s (vs ~5/s avec mplfinance).
|
||||
mplfinance reste utilisable via _render_with_mplfinance() pour l'affichage.
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
@@ -24,28 +23,30 @@ import pandas as pd
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- Détection optionnelle de mplfinance et PIL ---
|
||||
# Couleurs TradingView (normalisées [0, 1])
|
||||
BG_COLOR = np.array([13 / 255, 17 / 255, 23 / 255], dtype=np.float32) # #0d1117
|
||||
GREEN_COLOR = np.array([38 / 255, 166 / 255, 154 / 255], dtype=np.float32) # #26a69a
|
||||
RED_COLOR = np.array([239 / 255, 83 / 255, 80 / 255], dtype=np.float32) # #ef5350
|
||||
|
||||
IMAGE_SIZE = 128
|
||||
VOLUME_RATIO = 0.20 # 20% de la hauteur pour le volume
|
||||
|
||||
# --- Détection optionnelle de mplfinance (pour affichage uniquement) ---
|
||||
try:
|
||||
import mplfinance as mpf
|
||||
import matplotlib
|
||||
matplotlib.use('Agg') # Backend non-interactif (pas d'écran requis)
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
import io
|
||||
MPLFINANCE_AVAILABLE = True
|
||||
logger.debug("mplfinance disponible — rendu haute qualité activé")
|
||||
except ImportError:
|
||||
MPLFINANCE_AVAILABLE = False
|
||||
logger.warning("mplfinance non disponible — utilisation du rendu fallback")
|
||||
|
||||
try:
|
||||
from PIL import Image
|
||||
PIL_AVAILABLE = True
|
||||
except ImportError:
|
||||
PIL_AVAILABLE = False
|
||||
logger.warning("Pillow (PIL) non disponible — rendu fallback uniquement")
|
||||
|
||||
|
||||
# Taille cible des images en pixels
|
||||
IMAGE_SIZE = 128
|
||||
|
||||
|
||||
class CandlestickImageRenderer:
|
||||
@@ -53,15 +54,17 @@ class CandlestickImageRenderer:
|
||||
Convertit des données OHLCV en images de graphiques en chandeliers.
|
||||
|
||||
Chaque image est un instantané visuel de `seq_len` bougies consécutives,
|
||||
rendu avec mplfinance (style fond noir, bougies colorées, volume en bas).
|
||||
rendu via pur numpy (rapide, ~5 000 images/s) ou mplfinance (lent, haute qualité).
|
||||
Le résultat est normalisé en float32 dans [0, 1].
|
||||
|
||||
Args:
|
||||
image_size: Taille carrée de l'image en pixels (défaut 128)
|
||||
use_mplfinance: Forcer mplfinance même pour encode() (lent, déconseillé)
|
||||
"""
|
||||
|
||||
def __init__(self, image_size: int = IMAGE_SIZE):
|
||||
self.image_size = image_size
|
||||
def __init__(self, image_size: int = IMAGE_SIZE, use_mplfinance: bool = False):
|
||||
self.image_size = image_size
|
||||
self.use_mplfinance = use_mplfinance and MPLFINANCE_AVAILABLE and PIL_AVAILABLE
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Interface publique
|
||||
@@ -74,6 +77,8 @@ class CandlestickImageRenderer:
|
||||
Produit N = len(df) - seq_len fenêtres, chacune rendue en image
|
||||
128×128 RGB normalisée [0, 1].
|
||||
|
||||
Utilise le renderer pur numpy (rapide) par défaut.
|
||||
|
||||
Args:
|
||||
df: DataFrame OHLCV avec colonnes open/high/low/close/volume
|
||||
seq_len: Nombre de bougies par fenêtre (défaut 64)
|
||||
@@ -91,20 +96,13 @@ class CandlestickImageRenderer:
|
||||
)
|
||||
return np.zeros((0, 3, self.image_size, self.image_size), dtype=np.float32)
|
||||
|
||||
images = np.zeros(
|
||||
(n_windows, 3, self.image_size, self.image_size), dtype=np.float32
|
||||
)
|
||||
# Extraction des valeurs OHLCV en numpy (une seule conversion)
|
||||
ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values.astype(np.float32)
|
||||
|
||||
for i in range(n_windows):
|
||||
window = df.iloc[i: i + seq_len]
|
||||
try:
|
||||
img = self._render_single(window)
|
||||
images[i] = img
|
||||
except Exception as e:
|
||||
# Fenêtre problématique → on laisse des zéros pour cette position
|
||||
logger.debug(f"Erreur rendu fenêtre {i} : {e}")
|
||||
|
||||
return images
|
||||
if self.use_mplfinance:
|
||||
return self._encode_mplfinance(df, seq_len, n_windows)
|
||||
else:
|
||||
return self._encode_numpy(ohlcv, seq_len, n_windows)
|
||||
|
||||
def encode_last(self, df: pd.DataFrame, seq_len: int = 64) -> np.ndarray:
|
||||
"""
|
||||
@@ -130,157 +128,319 @@ class CandlestickImageRenderer:
|
||||
)
|
||||
|
||||
window = df.iloc[-seq_len:]
|
||||
img = self._render_single(window)
|
||||
# Ajouter dimension batch
|
||||
ohlcv = window[['open', 'high', 'low', 'close', 'volume']].values.astype(np.float32)
|
||||
img = self._render_numpy(ohlcv)
|
||||
return img[np.newaxis, ...] # (1, 3, H, W)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Rendu d'une fenêtre
|
||||
# Renderer pur numpy (rapide)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def _render_single(self, df_window: pd.DataFrame) -> np.ndarray:
|
||||
def _encode_numpy(
|
||||
self,
|
||||
ohlcv: np.ndarray,
|
||||
seq_len: int,
|
||||
n_windows: int,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Rend une fenêtre OHLCV en image numpy (3, 128, 128).
|
||||
Encode toutes les fenêtres en batch vectorisé (sans boucle Python sur N).
|
||||
|
||||
Utilise mplfinance si disponible, sinon fallback encodage numérique.
|
||||
Utilise sliding_window_view pour créer toutes les fenêtres d'un coup,
|
||||
puis boucle sur les seq_len bougies (64) en traitant toutes les N fenêtres
|
||||
simultanément. Numpy libère le GIL pendant les opérations vectorisées,
|
||||
ce qui préserve la réactivité de l'event loop FastAPI.
|
||||
|
||||
Args:
|
||||
df_window: DataFrame OHLCV pour une fenêtre (seq_len barres)
|
||||
ohlcv: (T, 5) float32 [open, high, low, close, volume]
|
||||
seq_len: Longueur de chaque fenêtre
|
||||
n_windows: Nombre total de fenêtres à générer
|
||||
|
||||
Returns:
|
||||
np.ndarray de forme (3, 128, 128), float32, valeurs [0, 1]
|
||||
(N, 3, H, W) float32
|
||||
"""
|
||||
if MPLFINANCE_AVAILABLE and PIL_AVAILABLE:
|
||||
return self._render_with_mplfinance(df_window)
|
||||
else:
|
||||
return self._render_fallback(df_window)
|
||||
H = W = self.image_size
|
||||
|
||||
# --- Toutes les fenêtres d'un coup : (N, seq_len, 5) ---
|
||||
# sliding_window_view est O(1) en mémoire (vue sans copie)
|
||||
windows = np.lib.stride_tricks.sliding_window_view(ohlcv, (seq_len, 5))
|
||||
windows = windows[:n_windows, 0, :, :].astype(np.float32) # (N, seq_len, 5)
|
||||
|
||||
opens = windows[:, :, 0] # (N, seq_len)
|
||||
highs = windows[:, :, 1]
|
||||
lows = windows[:, :, 2]
|
||||
closes = windows[:, :, 3]
|
||||
vols = windows[:, :, 4]
|
||||
|
||||
# --- Fond de toutes les images ---
|
||||
images = np.empty((n_windows, 3, H, W), dtype=np.float32)
|
||||
images[:, 0, :, :] = BG_COLOR[0]
|
||||
images[:, 1, :, :] = BG_COLOR[1]
|
||||
images[:, 2, :, :] = BG_COLOR[2]
|
||||
|
||||
# --- Normalisation des prix par fenêtre ---
|
||||
price_min = lows.min(axis=1, keepdims=True) # (N, 1)
|
||||
price_max = highs.max(axis=1, keepdims=True) # (N, 1)
|
||||
price_rng = np.maximum(price_max - price_min, 1e-8)
|
||||
|
||||
vol_h = max(1, int(H * VOLUME_RATIO))
|
||||
chart_h = H - vol_h
|
||||
|
||||
def to_row(prices: np.ndarray) -> np.ndarray:
|
||||
"""Convertit prix → rangée pixel (0 = haut, chart_h-1 = bas)."""
|
||||
norm = (prices - price_min) / price_rng
|
||||
rows = ((1.0 - norm) * (chart_h - 1)).astype(np.int32)
|
||||
return np.clip(rows, 0, chart_h - 1)
|
||||
|
||||
row_h = to_row(highs) # (N, seq_len)
|
||||
row_l = to_row(lows)
|
||||
row_o = to_row(opens)
|
||||
row_c = to_row(closes)
|
||||
|
||||
body_top = np.minimum(row_o, row_c) # (N, seq_len)
|
||||
body_bot = np.maximum(row_o, row_c)
|
||||
|
||||
is_bull = (closes >= opens) # (N, seq_len) bool
|
||||
|
||||
# Volume normalisé [0, 1] par fenêtre
|
||||
vol_max = np.maximum(vols.max(axis=1, keepdims=True), 1e-8)
|
||||
vol_norm = vols / vol_max # (N, seq_len)
|
||||
|
||||
# Positions X des bougies
|
||||
candle_w = max(1, W // seq_len)
|
||||
half_w = max(1, candle_w // 2)
|
||||
x_centers = ((np.arange(seq_len) + 0.5) * W / seq_len).astype(np.int32)
|
||||
|
||||
row_idx = np.arange(chart_h) # (chart_h,) pour les masques
|
||||
vol_idx = np.arange(vol_h) # (vol_h,)
|
||||
|
||||
# --- Boucle sur les bougies (64 itérations, GIL libéré entre chaque) ---
|
||||
for i in range(seq_len):
|
||||
x_c = int(x_centers[i])
|
||||
x0 = max(0, x_c - half_w)
|
||||
x1 = min(W, x_c + half_w + 1)
|
||||
wick_x = min(x_c, W - 1)
|
||||
|
||||
rh = row_h[:, i] # (N,)
|
||||
rl = row_l[:, i]
|
||||
bt = body_top[:, i]
|
||||
bb = body_bot[:, i]
|
||||
bull = is_bull[:, i] # (N,) bool
|
||||
|
||||
# Masques vectorisés sur toutes les N fenêtres
|
||||
# wick_mask[w, r] = True si rh[w] <= r <= rl[w]
|
||||
wick_mask = (
|
||||
(row_idx[None, :] >= rh[:, None]) &
|
||||
(row_idx[None, :] <= rl[:, None])
|
||||
) # (N, chart_h)
|
||||
|
||||
body_mask = (
|
||||
(row_idx[None, :] >= bt[:, None]) &
|
||||
(row_idx[None, :] <= bb[:, None])
|
||||
) # (N, chart_h)
|
||||
|
||||
# Volume : barre du bas
|
||||
vol_bar_h = (vol_norm[:, i] * vol_h).astype(np.int32)
|
||||
vol_thresh = np.maximum(vol_h - vol_bar_h, 0)
|
||||
vol_mask = vol_idx[None, :] >= vol_thresh[:, None] # (N, vol_h)
|
||||
|
||||
for c in range(3):
|
||||
g = float(GREEN_COLOR[c])
|
||||
r = float(RED_COLOR[c])
|
||||
# Couleur par fenêtre : vert si haussière, rouge sinon
|
||||
color_w = np.where(bull, g, r).astype(np.float32) # (N,)
|
||||
|
||||
# Mèche
|
||||
images[:, c, :chart_h, wick_x] = np.where(
|
||||
wick_mask,
|
||||
color_w[:, None],
|
||||
images[:, c, :chart_h, wick_x],
|
||||
)
|
||||
|
||||
# Corps
|
||||
images[:, c, :chart_h, x0:x1] = np.where(
|
||||
body_mask[:, :, None],
|
||||
color_w[:, None, None],
|
||||
images[:, c, :chart_h, x0:x1],
|
||||
)
|
||||
|
||||
# Volume (opacité 60%)
|
||||
images[:, c, H - vol_h:H, x0:x1] = np.where(
|
||||
vol_mask[:, :, None],
|
||||
color_w[:, None, None] * 0.6,
|
||||
images[:, c, H - vol_h:H, x0:x1],
|
||||
)
|
||||
|
||||
return images
|
||||
|
||||
def _render_numpy(self, ohlcv: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Rend une seule fenêtre OHLCV en image (3, H, W) via pur numpy.
|
||||
|
||||
Dessin :
|
||||
- Fond #0d1117
|
||||
- Mèche : ligne verticale high→low (1 pixel de large)
|
||||
- Corps : rectangle open→close (largeur proportionnelle)
|
||||
- Volume : barre en bas (20% hauteur), opacité 60%
|
||||
|
||||
Args:
|
||||
ohlcv: (N, 5) float32 [open, high, low, close, volume]
|
||||
|
||||
Returns:
|
||||
(3, image_size, image_size) float32 [0, 1]
|
||||
"""
|
||||
H = W = self.image_size
|
||||
n = len(ohlcv)
|
||||
|
||||
if n == 0:
|
||||
return np.zeros((3, H, W), dtype=np.float32)
|
||||
|
||||
# --- Fond ---
|
||||
img = np.empty((3, H, W), dtype=np.float32)
|
||||
for c in range(3):
|
||||
img[c] = BG_COLOR[c]
|
||||
|
||||
opens = ohlcv[:, 0]
|
||||
highs = ohlcv[:, 1]
|
||||
lows = ohlcv[:, 2]
|
||||
closes = ohlcv[:, 3]
|
||||
vols = ohlcv[:, 4]
|
||||
|
||||
# --- Normalisation des prix ---
|
||||
price_min = lows.min()
|
||||
price_max = highs.max()
|
||||
if price_max <= price_min:
|
||||
return img # données dégénérées
|
||||
|
||||
vol_h = max(1, int(H * VOLUME_RATIO))
|
||||
chart_h = H - vol_h # hauteur zone prix
|
||||
|
||||
def price_to_row(price: float) -> int:
|
||||
"""Convertit un prix en ligne pixel (0 = haut, chart_h-1 = bas)."""
|
||||
norm = (price - price_min) / (price_max - price_min)
|
||||
row = int((1.0 - norm) * (chart_h - 1))
|
||||
return max(0, min(chart_h - 1, row))
|
||||
|
||||
# --- Largeur des bougies ---
|
||||
candle_w = max(1, W // n)
|
||||
half_w = max(1, candle_w // 2)
|
||||
|
||||
# --- Volume ---
|
||||
vol_max = vols.max()
|
||||
|
||||
for i in range(n):
|
||||
x_c = int((i + 0.5) * W / n)
|
||||
x0 = max(0, x_c - half_w)
|
||||
x1 = min(W, x_c + half_w + 1)
|
||||
|
||||
is_bull = closes[i] >= opens[i]
|
||||
color = GREEN_COLOR if is_bull else RED_COLOR
|
||||
|
||||
row_h = price_to_row(highs[i])
|
||||
row_l = price_to_row(lows[i])
|
||||
row_o = price_to_row(opens[i])
|
||||
row_c = price_to_row(closes[i])
|
||||
|
||||
body_top = min(row_o, row_c)
|
||||
body_bot = max(row_o, row_c)
|
||||
|
||||
# Mèche (1 pixel de large, centré)
|
||||
wick_x = min(x_c, W - 1)
|
||||
for c in range(3):
|
||||
img[c, row_h: row_l + 1, wick_x] = color[c]
|
||||
|
||||
# Corps
|
||||
if body_bot >= body_top:
|
||||
for c in range(3):
|
||||
img[c, body_top: body_bot + 1, x0: x1] = color[c]
|
||||
|
||||
# Volume (bas de l'image)
|
||||
if vol_max > 0:
|
||||
bar_h = int((vols[i] / vol_max) * vol_h)
|
||||
if bar_h > 0:
|
||||
row_v0 = H - bar_h
|
||||
for c in range(3):
|
||||
img[c, row_v0: H, x0: x1] = color[c] * 0.6
|
||||
|
||||
return img
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Renderer mplfinance (lent, haute qualité, pour affichage uniquement)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def _encode_mplfinance(
|
||||
self,
|
||||
df: pd.DataFrame,
|
||||
seq_len: int,
|
||||
n_windows: int,
|
||||
) -> np.ndarray:
|
||||
"""Encode via mplfinance (lent — ne pas utiliser pour l'entraînement)."""
|
||||
H = W = self.image_size
|
||||
images = np.zeros((n_windows, 3, H, W), dtype=np.float32)
|
||||
import warnings
|
||||
for i in range(n_windows):
|
||||
window = df.iloc[i: i + seq_len]
|
||||
try:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
images[i] = self._render_with_mplfinance(window)
|
||||
except Exception as e:
|
||||
logger.debug(f"Erreur mplfinance fenêtre {i} : {e}")
|
||||
return images
|
||||
|
||||
def _render_with_mplfinance(self, df_window: pd.DataFrame) -> np.ndarray:
|
||||
"""
|
||||
Rendu haute qualité via mplfinance.
|
||||
Rendu haute qualité via mplfinance (pour affichage / debug).
|
||||
|
||||
Style : fond noir (#0d1117), bougies vertes/rouges, volume en bas,
|
||||
aucun axe ni label ni titre. Image 128×128 RGB.
|
||||
Raises:
|
||||
ImportError si mplfinance ou PIL non disponible.
|
||||
"""
|
||||
# Style personnalisé : fond noir, bougies colorées TradingView
|
||||
if not MPLFINANCE_AVAILABLE or not PIL_AVAILABLE:
|
||||
raise ImportError("mplfinance ou PIL non disponible")
|
||||
|
||||
style = mpf.make_mpf_style(
|
||||
base_mpf_style='nightclouds',
|
||||
marketcolors=mpf.make_marketcolors(
|
||||
up='#26a69a', # vert TradingView (hausse)
|
||||
down='#ef5350', # rouge TradingView (baisse)
|
||||
up='#26a69a', down='#ef5350',
|
||||
wick={'up': '#26a69a', 'down': '#ef5350'},
|
||||
edge={'up': '#26a69a', 'down': '#ef5350'},
|
||||
volume={'up': '#26a69a', 'down': '#ef5350'},
|
||||
),
|
||||
facecolor='#0d1117', # Fond noir GitHub-style
|
||||
facecolor='#0d1117',
|
||||
figcolor='#0d1117',
|
||||
gridcolor='#0d1117',
|
||||
)
|
||||
|
||||
# Rendu en mémoire via BytesIO
|
||||
buf = io.BytesIO()
|
||||
import warnings
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
fig, axes = mpf.plot(
|
||||
df_window,
|
||||
type='candle',
|
||||
style=style,
|
||||
volume=True,
|
||||
axisoff=True,
|
||||
tight_layout=True,
|
||||
returnfig=True,
|
||||
figsize=(1.28, 1.28),
|
||||
)
|
||||
|
||||
fig, axes = mpf.plot(
|
||||
df_window,
|
||||
type='candle',
|
||||
style=style,
|
||||
volume=True,
|
||||
axisoff=True, # Pas d'axes
|
||||
tight_layout=True,
|
||||
returnfig=True,
|
||||
figsize=(1.28, 1.28), # 128 DPI × 1.28 inch = 128 pixels
|
||||
)
|
||||
|
||||
# Suppression de tous les éléments décoratifs
|
||||
for ax in axes:
|
||||
ax.set_axis_off()
|
||||
ax.set_facecolor('#0d1117')
|
||||
for spine in ax.spines.values():
|
||||
spine.set_visible(False)
|
||||
|
||||
fig.savefig(
|
||||
buf,
|
||||
format='png',
|
||||
dpi=100,
|
||||
bbox_inches='tight',
|
||||
pad_inches=0,
|
||||
facecolor='#0d1117',
|
||||
)
|
||||
fig.savefig(buf, format='png', dpi=100, bbox_inches='tight',
|
||||
pad_inches=0, facecolor='#0d1117')
|
||||
plt.close(fig)
|
||||
|
||||
buf.seek(0)
|
||||
pil_img = Image.open(buf).convert('RGB')
|
||||
pil_img = pil_img.resize(
|
||||
(self.image_size, self.image_size), Image.LANCZOS
|
||||
)
|
||||
pil_img = pil_img.resize((self.image_size, self.image_size), Image.LANCZOS)
|
||||
|
||||
# Conversion en array numpy (H, W, 3) → (3, H, W), float32, [0,1]
|
||||
arr = np.array(pil_img, dtype=np.float32) / 255.0
|
||||
arr = arr.transpose(2, 0, 1) # HWC → CHW
|
||||
|
||||
return arr
|
||||
|
||||
def _render_fallback(self, df_window: pd.DataFrame) -> np.ndarray:
|
||||
"""
|
||||
Rendu de secours sans mplfinance : encode les données OHLCV
|
||||
directement en image 2D normalisée.
|
||||
|
||||
Chaque colonne de l'image correspond à une bougie :
|
||||
- Canal 0 (R) : position relative du close dans le range [low, high]
|
||||
- Canal 1 (G) : amplitude de la bougie (high - low normalisé)
|
||||
- Canal 2 (B) : volume normalisé
|
||||
|
||||
L'image est ensuite redimensionnée à image_size × image_size.
|
||||
"""
|
||||
cols = df_window[['open', 'high', 'low', 'close', 'volume']].copy()
|
||||
n = len(cols)
|
||||
|
||||
if n == 0:
|
||||
return np.zeros((3, self.image_size, self.image_size), dtype=np.float32)
|
||||
|
||||
# Normalisation par colonne
|
||||
highs = cols['high'].values.astype(np.float32)
|
||||
lows = cols['low'].values.astype(np.float32)
|
||||
closes = cols['close'].values.astype(np.float32)
|
||||
opens = cols['open'].values.astype(np.float32)
|
||||
vols = cols['volume'].values.astype(np.float32)
|
||||
|
||||
price_range = highs - lows
|
||||
price_range = np.where(price_range == 0, 1e-8, price_range)
|
||||
|
||||
# Canal R : position du close dans le range de la bougie [0, 1]
|
||||
close_pos = (closes - lows) / price_range
|
||||
|
||||
# Canal G : corps de la bougie (|close - open| / range)
|
||||
body = np.abs(closes - opens) / price_range
|
||||
|
||||
# Canal B : volume normalisé [0, 1]
|
||||
vol_max = vols.max()
|
||||
vol_norm = vols / (vol_max if vol_max > 0 else 1.0)
|
||||
|
||||
# Construction image : 3 × 1 × n → 3 × image_size × image_size
|
||||
# On crée une image de height=image_size, width=n puis on redimensionne
|
||||
img = np.zeros((3, 1, n), dtype=np.float32)
|
||||
img[0, 0, :] = close_pos
|
||||
img[1, 0, :] = body
|
||||
img[2, 0, :] = vol_norm
|
||||
|
||||
# Redimensionnement vers (3, image_size, image_size) via répétition
|
||||
# On étire chaque canal sur les deux dimensions
|
||||
img_resized = np.zeros(
|
||||
(3, self.image_size, self.image_size), dtype=np.float32
|
||||
)
|
||||
for c in range(3):
|
||||
# Répétition en hauteur (axe 0 du canal) et interpolation en largeur
|
||||
channel = img[c, 0, :] # (n,)
|
||||
# Interpolation 1D vers image_size
|
||||
x_orig = np.linspace(0, 1, n)
|
||||
x_new = np.linspace(0, 1, self.image_size)
|
||||
channel_resized = np.interp(x_new, x_orig, channel).astype(np.float32)
|
||||
# Étirer sur toute la hauteur
|
||||
img_resized[c] = np.tile(channel_resized, (self.image_size, 1))
|
||||
|
||||
return img_resized
|
||||
return arr.transpose(2, 0, 1) # HWC → CHW
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Utilitaires
|
||||
@@ -290,27 +450,18 @@ class CandlestickImageRenderer:
|
||||
def _prepare_df(df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Normalise le DataFrame : colonnes en minuscules, index DatetimeIndex.
|
||||
|
||||
Args:
|
||||
df: DataFrame OHLCV brut
|
||||
|
||||
Returns:
|
||||
DataFrame nettoyé avec colonnes open/high/low/close/volume
|
||||
"""
|
||||
df = df.copy()
|
||||
df.columns = [c.lower() for c in df.columns]
|
||||
|
||||
# S'assurer que l'index est un DatetimeIndex (requis par mplfinance)
|
||||
if not isinstance(df.index, pd.DatetimeIndex):
|
||||
try:
|
||||
df.index = pd.to_datetime(df.index)
|
||||
except Exception:
|
||||
# Créer un index artificiel si la conversion échoue
|
||||
df.index = pd.date_range(
|
||||
start='2020-01-01', periods=len(df), freq='1h'
|
||||
)
|
||||
|
||||
# Conserver uniquement les colonnes OHLCV
|
||||
required = ['open', 'high', 'low', 'close', 'volume']
|
||||
for col in required:
|
||||
if col not in df.columns:
|
||||
@@ -318,5 +469,4 @@ class CandlestickImageRenderer:
|
||||
|
||||
df = df[required].dropna(subset=['open', 'high', 'low', 'close'])
|
||||
df = df.ffill().bfill()
|
||||
|
||||
return df
|
||||
|
||||
@@ -480,6 +480,9 @@ class CNNImageStrategyModel:
|
||||
y: Labels encodés (N,) int, valeurs {0, 1, 2}
|
||||
class_weights: Poids de classes (3,) pour CrossEntropyLoss
|
||||
"""
|
||||
# Limiter les threads CPU pour ne pas saturer l'event loop FastAPI
|
||||
torch.set_num_threads(4)
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
model = model.to(device)
|
||||
class_weights = class_weights.to(device)
|
||||
|
||||
@@ -11,10 +11,12 @@ les différents régimes de marché:
|
||||
Permet d'adapter les stratégies selon le régime actuel.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
|
||||
try:
|
||||
@@ -24,8 +26,18 @@ except ImportError:
|
||||
HMMLEARN_AVAILABLE = False
|
||||
logging.warning("hmmlearn not installed. Install with: pip install hmmlearn")
|
||||
|
||||
try:
|
||||
import joblib
|
||||
JOBLIB_AVAILABLE = True
|
||||
except ImportError:
|
||||
JOBLIB_AVAILABLE = False
|
||||
logging.warning("joblib not installed. La persistance HMM sera désactivée.")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Répertoire de persistance des modèles HMM
|
||||
MODELS_DIR = Path(__file__).parent.parent.parent / "models" / "hmm"
|
||||
|
||||
|
||||
class RegimeDetector:
|
||||
"""
|
||||
@@ -79,7 +91,10 @@ class RegimeDetector:
|
||||
|
||||
self.is_fitted = False
|
||||
self.feature_names = []
|
||||
|
||||
# Métadonnées d'entraînement pour la persistance
|
||||
self._trained_at: Optional[datetime] = None
|
||||
self._n_samples: int = 0
|
||||
|
||||
logger.info(f"RegimeDetector initialized with {n_regimes} regimes")
|
||||
|
||||
def fit(self, data: pd.DataFrame, features: Optional[List[str]] = None):
|
||||
@@ -115,11 +130,147 @@ class RegimeDetector:
|
||||
try:
|
||||
self.model.fit(X)
|
||||
self.is_fitted = True
|
||||
logger.info("✅ HMM model fitted successfully")
|
||||
self._trained_at = datetime.now()
|
||||
self._n_samples = len(X)
|
||||
logger.info("Modèle HMM entraîné avec succès")
|
||||
except Exception as e:
|
||||
logger.error(f"Error fitting HMM: {e}")
|
||||
logger.error(f"Erreur lors de l'entraînement HMM : {e}")
|
||||
raise
|
||||
|
||||
@property
|
||||
def is_trained(self) -> bool:
|
||||
"""True si le modèle HMM a été ajusté (fit)."""
|
||||
return self.is_fitted
|
||||
|
||||
def needs_retrain(self, max_age_hours: int = 24) -> bool:
|
||||
"""
|
||||
Indique si le modèle doit être ré-entraîné.
|
||||
|
||||
Un ré-entraînement est nécessaire si :
|
||||
- Le modèle n'a jamais été entraîné
|
||||
- La date d'entraînement est inconnue
|
||||
- Le modèle est plus vieux que max_age_hours
|
||||
|
||||
Args:
|
||||
max_age_hours: Âge maximum du modèle en heures (défaut 24h)
|
||||
|
||||
Returns:
|
||||
True si un ré-entraînement est nécessaire
|
||||
"""
|
||||
if not self.is_fitted or self._trained_at is None:
|
||||
return True
|
||||
age = datetime.now() - self._trained_at
|
||||
return age > timedelta(hours=max_age_hours)
|
||||
|
||||
def save(self, symbol: str, timeframe: str) -> bool:
|
||||
"""
|
||||
Sauvegarde le modèle HMM entraîné sur disque avec joblib.
|
||||
|
||||
Le modèle est sauvegardé dans :
|
||||
models/hmm/{symbol}_{timeframe}.joblib
|
||||
|
||||
Les métadonnées (date, n_samples, n_components, labels) sont
|
||||
stockées dans un fichier JSON compagnon :
|
||||
models/hmm/{symbol}_{timeframe}_meta.json
|
||||
|
||||
Args:
|
||||
symbol: Symbole de l'instrument (ex : "EURUSD")
|
||||
timeframe: Unité de temps (ex : "1h")
|
||||
|
||||
Returns:
|
||||
True si la sauvegarde a réussi, False sinon
|
||||
"""
|
||||
if not self.is_fitted:
|
||||
logger.warning("Impossible de sauvegarder : modèle non entraîné")
|
||||
return False
|
||||
|
||||
if not JOBLIB_AVAILABLE:
|
||||
logger.warning("joblib indisponible — sauvegarde HMM ignorée")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Créer le répertoire si nécessaire
|
||||
MODELS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
base = f"{symbol}_{timeframe}"
|
||||
model_path = MODELS_DIR / f"{base}.joblib"
|
||||
meta_path = MODELS_DIR / f"{base}_meta.json"
|
||||
|
||||
# Sauvegarder le modèle HMM + feature_names
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"feature_names": self.feature_names,
|
||||
"n_regimes": self.n_regimes,
|
||||
"random_state": self.random_state,
|
||||
}
|
||||
joblib.dump(payload, model_path)
|
||||
|
||||
# Sauvegarder les métadonnées en JSON
|
||||
meta = {
|
||||
"trained_at": self._trained_at.isoformat() if self._trained_at else None,
|
||||
"n_samples": self._n_samples,
|
||||
"n_components": self.n_regimes,
|
||||
"regime_labels": self.REGIME_NAMES,
|
||||
"symbol": symbol,
|
||||
"timeframe": timeframe,
|
||||
}
|
||||
meta_path.write_text(json.dumps(meta, indent=2, default=str))
|
||||
|
||||
logger.info(f"Modèle HMM sauvegardé : {model_path}")
|
||||
return True
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Erreur lors de la sauvegarde du modèle HMM : {exc}")
|
||||
return False
|
||||
|
||||
def load(self, symbol: str, timeframe: str) -> bool:
|
||||
"""
|
||||
Charge un modèle HMM depuis le disque.
|
||||
|
||||
Args:
|
||||
symbol: Symbole de l'instrument (ex : "EURUSD")
|
||||
timeframe: Unité de temps (ex : "1h")
|
||||
|
||||
Returns:
|
||||
True si le chargement a réussi, False sinon
|
||||
"""
|
||||
if not JOBLIB_AVAILABLE:
|
||||
logger.warning("joblib indisponible — chargement HMM impossible")
|
||||
return False
|
||||
|
||||
base = f"{symbol}_{timeframe}"
|
||||
model_path = MODELS_DIR / f"{base}.joblib"
|
||||
meta_path = MODELS_DIR / f"{base}_meta.json"
|
||||
|
||||
if not model_path.exists():
|
||||
logger.debug(f"Aucun modèle HMM trouvé pour {symbol}/{timeframe}")
|
||||
return False
|
||||
|
||||
try:
|
||||
payload = joblib.load(model_path)
|
||||
self.model = payload["model"]
|
||||
self.feature_names = payload["feature_names"]
|
||||
self.n_regimes = payload["n_regimes"]
|
||||
self.random_state = payload["random_state"]
|
||||
self.is_fitted = True
|
||||
|
||||
# Charger les métadonnées si disponibles
|
||||
if meta_path.exists():
|
||||
meta = json.loads(meta_path.read_text())
|
||||
trained_at_raw = meta.get("trained_at")
|
||||
self._trained_at = (
|
||||
datetime.fromisoformat(trained_at_raw) if trained_at_raw else None
|
||||
)
|
||||
self._n_samples = meta.get("n_samples", 0)
|
||||
|
||||
logger.info(f"Modèle HMM chargé depuis {model_path} (entraîné le {self._trained_at})")
|
||||
return True
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Erreur lors du chargement du modèle HMM : {exc}")
|
||||
self.is_fitted = False
|
||||
return False
|
||||
|
||||
def predict_regime(self, data: pd.DataFrame) -> np.ndarray:
|
||||
"""
|
||||
Prédit les régimes pour toutes les barres.
|
||||
|
||||
24
src/ml/rl/__init__.py
Normal file
24
src/ml/rl/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
Module RL (Reinforcement Learning) — Agent PPO pour le trading algorithmique.
|
||||
|
||||
Ce module implémente un agent PPO (Proximal Policy Optimization) entraîné
|
||||
par renforcement sur un environnement de trading simulé.
|
||||
|
||||
Composants :
|
||||
TradingEnv — Environnement gymnasium conforme (observation 20 features)
|
||||
PPOModel — Réseau Actor-Critic MLP 256→128→64 + entraînement PPO
|
||||
RLStrategyModel — Interface unifiée (identique à MLStrategyModel)
|
||||
"""
|
||||
|
||||
from src.ml.rl.trading_env import TradingEnv, GYM_AVAILABLE
|
||||
from src.ml.rl.ppo_model import PPOModel, TORCH_AVAILABLE
|
||||
from src.ml.rl.rl_strategy_model import RLStrategyModel, RL_AVAILABLE
|
||||
|
||||
__all__ = [
|
||||
'TradingEnv',
|
||||
'PPOModel',
|
||||
'RLStrategyModel',
|
||||
'RL_AVAILABLE',
|
||||
'TORCH_AVAILABLE',
|
||||
'GYM_AVAILABLE',
|
||||
]
|
||||
681
src/ml/rl/ppo_model.py
Normal file
681
src/ml/rl/ppo_model.py
Normal file
@@ -0,0 +1,681 @@
|
||||
"""
|
||||
PPOModel — Implémentation manuelle de l'algorithme PPO (Proximal Policy Optimization).
|
||||
|
||||
Architecture Actor-Critic :
|
||||
- Réseau partagé : MLP 3 couches (256 → 128 → 64) avec BatchNorm + ReLU
|
||||
- Tête Actor : couche linéaire → logits (n_actions=3)
|
||||
- Tête Critic : couche linéaire → valeur scalaire V(s)
|
||||
|
||||
Hyperparamètres PPO :
|
||||
- clip ε=0.2 — clip ratio de probabilité pour éviter les grandes mises à jour
|
||||
- entropy_coef=0.01 — bonus d'entropie pour exploration
|
||||
- value_coef=0.5 — coefficient de la loss valeur
|
||||
- n_steps=2048 — nombre de transitions collectées avant mise à jour
|
||||
- n_epochs=10 — passes sur les données collectées
|
||||
- batch_size=64 — taille des mini-batchs
|
||||
- gamma=0.99 — facteur d'actualisation
|
||||
- gae_lambda=0.95 — λ pour Generalized Advantage Estimation
|
||||
|
||||
Sauvegarde :
|
||||
- models/rl_strategy/{symbol}_{timeframe}.pt (state_dict + config)
|
||||
- models/rl_strategy/{symbol}_{timeframe}_meta.json
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Répertoire de sauvegarde des modèles RL
|
||||
MODELS_DIR = Path(__file__).parent.parent.parent.parent / "models" / "rl_strategy"
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Import conditionnel PyTorch
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
try:
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.distributions import Categorical
|
||||
TORCH_AVAILABLE = True
|
||||
except ImportError:
|
||||
torch = None
|
||||
nn = None
|
||||
optim = None
|
||||
Categorical = None
|
||||
TORCH_AVAILABLE = False
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Réseau Actor-Critic
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
if TORCH_AVAILABLE:
|
||||
class ActorCriticNetwork(nn.Module):
|
||||
"""
|
||||
Réseau Actor-Critic partagé pour PPO.
|
||||
|
||||
Architecture :
|
||||
- Tronc commun : Linear(obs_dim → 256) → BN → ReLU
|
||||
Linear(256 → 128) → BN → ReLU
|
||||
Linear(128 → 64) → ReLU
|
||||
- Tête Actor : Linear(64 → n_actions)
|
||||
- Tête Critic : Linear(64 → 1)
|
||||
|
||||
Args:
|
||||
obs_dim: Dimension de l'espace d'observation
|
||||
n_actions: Nombre d'actions discrètes
|
||||
"""
|
||||
|
||||
def __init__(self, obs_dim: int = 20, n_actions: int = 3):
|
||||
super().__init__()
|
||||
|
||||
# Tronc commun
|
||||
self.trunk = nn.Sequential(
|
||||
nn.Linear(obs_dim, 256),
|
||||
nn.BatchNorm1d(256),
|
||||
nn.ReLU(),
|
||||
nn.Linear(256, 128),
|
||||
nn.BatchNorm1d(128),
|
||||
nn.ReLU(),
|
||||
nn.Linear(128, 64),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
# Tête Actor : logits pour chaque action
|
||||
self.actor_head = nn.Linear(64, n_actions)
|
||||
|
||||
# Tête Critic : estimation de la valeur d'état V(s)
|
||||
self.critic_head = nn.Linear(64, 1)
|
||||
|
||||
# Initialisation des poids (orthogonale pour stabilité PPO)
|
||||
self._init_weights()
|
||||
|
||||
def _init_weights(self):
|
||||
"""Initialisation orthogonale recommandée pour les réseaux PPO."""
|
||||
for module in self.modules():
|
||||
if isinstance(module, nn.Linear):
|
||||
nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
|
||||
nn.init.constant_(module.bias, 0.0)
|
||||
# Gain plus faible pour les têtes finales
|
||||
nn.init.orthogonal_(self.actor_head.weight, gain=0.01)
|
||||
nn.init.orthogonal_(self.critic_head.weight, gain=1.0)
|
||||
|
||||
def forward(
|
||||
self, x: 'torch.Tensor'
|
||||
) -> Tuple['torch.Tensor', 'torch.Tensor']:
|
||||
"""
|
||||
Calcule les logits actor et la valeur critic.
|
||||
|
||||
Args:
|
||||
x: Tensor d'observations (batch_size, obs_dim)
|
||||
|
||||
Returns:
|
||||
Tuple (logits_actor, value_critic)
|
||||
logits_actor : (batch_size, n_actions)
|
||||
value_critic : (batch_size, 1)
|
||||
"""
|
||||
features = self.trunk(x)
|
||||
logits = self.actor_head(features)
|
||||
value = self.critic_head(features)
|
||||
return logits, value
|
||||
|
||||
def get_action_and_value(
|
||||
self, x: 'torch.Tensor', action: Optional['torch.Tensor'] = None
|
||||
) -> Tuple['torch.Tensor', 'torch.Tensor', 'torch.Tensor', 'torch.Tensor']:
|
||||
"""
|
||||
Retourne l'action, son log-prob, l'entropie et la valeur.
|
||||
|
||||
Args:
|
||||
x: Observations (batch_size, obs_dim)
|
||||
action: Si fourni, calcule le log-prob de cet action existante
|
||||
(pour la phase de mise à jour PPO)
|
||||
|
||||
Returns:
|
||||
Tuple (action, log_prob, entropy, value)
|
||||
"""
|
||||
logits, value = self.forward(x)
|
||||
dist = Categorical(logits=logits)
|
||||
|
||||
if action is None:
|
||||
action = dist.sample()
|
||||
|
||||
log_prob = dist.log_prob(action)
|
||||
entropy = dist.entropy()
|
||||
|
||||
return action, log_prob, entropy, value.squeeze(-1)
|
||||
|
||||
def get_value(self, x: 'torch.Tensor') -> 'torch.Tensor':
|
||||
"""Retourne uniquement la valeur critic V(s)."""
|
||||
_, value = self.forward(x)
|
||||
return value.squeeze(-1)
|
||||
|
||||
|
||||
class PPOModel:
|
||||
"""
|
||||
Agent PPO (Proximal Policy Optimization) pour le trading.
|
||||
|
||||
Implémentation manuelle sans stable-baselines3, utilisant PyTorch pur.
|
||||
Suit le pseudo-code de Schulman et al. (2017) avec GAE.
|
||||
|
||||
Méthodes publiques :
|
||||
train(env, total_timesteps) — Lance l'entraînement
|
||||
predict(obs) — Retourne l'action et le log-prob
|
||||
save(symbol, timeframe) — Sauvegarde le modèle
|
||||
load(symbol, timeframe) — Charge un modèle existant (classmethod)
|
||||
|
||||
Args:
|
||||
obs_dim: Dimension de l'observation (défaut: 20)
|
||||
n_actions: Nombre d'actions discrètes (défaut: 3)
|
||||
lr: Taux d'apprentissage Adam (défaut: 3e-4)
|
||||
n_steps: Transitions par rollout avant mise à jour (défaut: 2048)
|
||||
n_epochs: Passes d'optimisation par rollout (défaut: 10)
|
||||
batch_size: Taille des mini-batchs (défaut: 64)
|
||||
gamma: Facteur d'actualisation (défaut: 0.99)
|
||||
gae_lambda: Lambda GAE (défaut: 0.95)
|
||||
clip_eps: Epsilon de clip PPO (défaut: 0.2)
|
||||
entropy_coef: Coefficient bonus entropie (défaut: 0.01)
|
||||
value_coef: Coefficient loss valeur (défaut: 0.5)
|
||||
max_grad_norm: Norme maximale du gradient (défaut: 0.5)
|
||||
"""
|
||||
|
||||
MODELS_DIR = MODELS_DIR
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
obs_dim: int = 20,
|
||||
n_actions: int = 3,
|
||||
lr: float = 3e-4,
|
||||
n_steps: int = 2048,
|
||||
n_epochs: int = 10,
|
||||
batch_size: int = 64,
|
||||
gamma: float = 0.99,
|
||||
gae_lambda: float = 0.95,
|
||||
clip_eps: float = 0.2,
|
||||
entropy_coef: float = 0.01,
|
||||
value_coef: float = 0.5,
|
||||
max_grad_norm: float = 0.5,
|
||||
):
|
||||
self.obs_dim = obs_dim
|
||||
self.n_actions = n_actions
|
||||
self.lr = lr
|
||||
self.n_steps = n_steps
|
||||
self.n_epochs = n_epochs
|
||||
self.batch_size = batch_size
|
||||
self.gamma = gamma
|
||||
self.gae_lambda = gae_lambda
|
||||
self.clip_eps = clip_eps
|
||||
self.entropy_coef = entropy_coef
|
||||
self.value_coef = value_coef
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
self.network: Optional['ActorCriticNetwork'] = None
|
||||
self.optimizer: Optional['optim.Adam'] = None
|
||||
self.is_trained: bool = False
|
||||
self.metadata: Dict = {}
|
||||
|
||||
self.MODELS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if TORCH_AVAILABLE:
|
||||
self._init_network()
|
||||
|
||||
def _init_network(self) -> None:
|
||||
"""Initialise le réseau Actor-Critic et l'optimiseur Adam."""
|
||||
self.network = ActorCriticNetwork(self.obs_dim, self.n_actions)
|
||||
self.optimizer = torch.optim.Adam(self.network.parameters(), lr=self.lr, eps=1e-5)
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Entraînement
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def train(
|
||||
self,
|
||||
env,
|
||||
total_timesteps: int = 100_000,
|
||||
symbol: str = 'EURUSD',
|
||||
timeframe: str = '1h',
|
||||
) -> Dict:
|
||||
"""
|
||||
Entraîne l'agent PPO sur l'environnement de trading fourni.
|
||||
|
||||
Algorithme :
|
||||
Boucle principale :
|
||||
1. Collecte de n_steps transitions (rollout)
|
||||
2. Calcul des avantages GAE
|
||||
3. n_epochs passes d'optimisation sur mini-batchs mélangés
|
||||
Fin :
|
||||
Sauvegarde du modèle et retour des métriques
|
||||
|
||||
Args:
|
||||
env: Environnement TradingEnv
|
||||
total_timesteps: Nombre total de transitions à collecter
|
||||
symbol: Symbole pour la sauvegarde (ex: 'EURUSD')
|
||||
timeframe: Timeframe pour la sauvegarde (ex: '1h')
|
||||
|
||||
Returns:
|
||||
Dict avec métriques : total_timesteps, n_updates, mean_reward,
|
||||
mean_ep_return, policy_loss, value_loss, entropy_loss
|
||||
"""
|
||||
if not TORCH_AVAILABLE:
|
||||
return {'error': 'PyTorch non disponible — installer torch>=2.0.0'}
|
||||
|
||||
# Limiter les threads CPU pour ne pas saturer l'event loop FastAPI
|
||||
torch.set_num_threads(4)
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.network.to(device)
|
||||
|
||||
logger.info(
|
||||
f"Début entraînement PPO {symbol}/{timeframe} — "
|
||||
f"total_timesteps={total_timesteps}, n_steps={self.n_steps}, device={device}"
|
||||
)
|
||||
|
||||
# ── Buffers de rollout ──────────────────────────────────────────────
|
||||
obs_buf = np.zeros((self.n_steps, self.obs_dim), dtype=np.float32)
|
||||
actions_buf = np.zeros((self.n_steps,), dtype=np.int64)
|
||||
rewards_buf = np.zeros((self.n_steps,), dtype=np.float32)
|
||||
dones_buf = np.zeros((self.n_steps,), dtype=np.float32)
|
||||
values_buf = np.zeros((self.n_steps,), dtype=np.float32)
|
||||
logprobs_buf = np.zeros((self.n_steps,), dtype=np.float32)
|
||||
|
||||
# ── Statistiques d'entraînement ─────────────────────────────────────
|
||||
all_episode_returns = []
|
||||
all_policy_losses = []
|
||||
all_value_losses = []
|
||||
all_entropy_losses = []
|
||||
ep_return = 0.0
|
||||
n_updates = 0
|
||||
|
||||
obs, _ = env.reset()
|
||||
done = False
|
||||
|
||||
timestep = 0
|
||||
rollout_idx = 0
|
||||
|
||||
while timestep < total_timesteps:
|
||||
# ── Phase de collecte (rollout) ───────────────────────────────────
|
||||
self.network.eval()
|
||||
with torch.no_grad():
|
||||
obs_t = torch.from_numpy(obs).float().unsqueeze(0).to(device)
|
||||
action, log_prob, _, value = self.network.get_action_and_value(obs_t)
|
||||
|
||||
action_np = int(action.cpu().item())
|
||||
log_prob_np = float(log_prob.cpu().item())
|
||||
value_np = float(value.cpu().item())
|
||||
|
||||
next_obs, reward, terminated, truncated, _ = env.step(action_np)
|
||||
done = terminated or truncated
|
||||
|
||||
obs_buf[rollout_idx] = obs
|
||||
actions_buf[rollout_idx] = action_np
|
||||
rewards_buf[rollout_idx] = reward
|
||||
dones_buf[rollout_idx] = float(done)
|
||||
values_buf[rollout_idx] = value_np
|
||||
logprobs_buf[rollout_idx] = log_prob_np
|
||||
|
||||
ep_return += reward
|
||||
obs = next_obs
|
||||
timestep += 1
|
||||
rollout_idx += 1
|
||||
|
||||
if done:
|
||||
all_episode_returns.append(ep_return)
|
||||
ep_return = 0.0
|
||||
obs, _ = env.reset()
|
||||
|
||||
# ── Phase de mise à jour PPO ──────────────────────────────────────
|
||||
if rollout_idx == self.n_steps:
|
||||
# Calcul de la valeur du dernier état (bootstrap)
|
||||
with torch.no_grad():
|
||||
last_obs_t = torch.from_numpy(obs).float().unsqueeze(0).to(device)
|
||||
last_value = self.network.get_value(last_obs_t).cpu().numpy()
|
||||
|
||||
# Calcul des avantages GAE et des retours
|
||||
advantages, returns = self._compute_gae(
|
||||
rewards_buf, values_buf, dones_buf, float(last_value)
|
||||
)
|
||||
|
||||
# Conversion en tensors
|
||||
obs_t = torch.from_numpy(obs_buf).float().to(device)
|
||||
actions_t = torch.from_numpy(actions_buf).long().to(device)
|
||||
logprobs_t = torch.from_numpy(logprobs_buf).float().to(device)
|
||||
advs_t = torch.from_numpy(advantages).float().to(device)
|
||||
returns_t = torch.from_numpy(returns).float().to(device)
|
||||
|
||||
# Normalisation des avantages
|
||||
advs_t = (advs_t - advs_t.mean()) / (advs_t.std() + 1e-8)
|
||||
|
||||
# n_epochs passes d'optimisation
|
||||
n_samples = self.n_steps
|
||||
policy_losses_ep = []
|
||||
value_losses_ep = []
|
||||
entropy_losses_ep = []
|
||||
|
||||
self.network.train()
|
||||
for _ in range(self.n_epochs):
|
||||
indices = np.random.permutation(n_samples)
|
||||
|
||||
for start in range(0, n_samples, self.batch_size):
|
||||
end = start + self.batch_size
|
||||
batch = indices[start:end]
|
||||
if len(batch) < 2:
|
||||
continue
|
||||
|
||||
# Forward pass
|
||||
_, new_logprob, entropy, new_value = (
|
||||
self.network.get_action_and_value(
|
||||
obs_t[batch], actions_t[batch]
|
||||
)
|
||||
)
|
||||
|
||||
# Ratio de probabilité
|
||||
log_ratio = new_logprob - logprobs_t[batch]
|
||||
ratio = torch.exp(log_ratio)
|
||||
|
||||
# Clip PPO
|
||||
adv_batch = advs_t[batch]
|
||||
pg_loss1 = -adv_batch * ratio
|
||||
pg_loss2 = -adv_batch * torch.clamp(
|
||||
ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps
|
||||
)
|
||||
policy_loss = torch.max(pg_loss1, pg_loss2).mean()
|
||||
|
||||
# Loss valeur (avec clip optionnel)
|
||||
value_loss = nn.functional.mse_loss(new_value, returns_t[batch])
|
||||
|
||||
# Bonus d'entropie (encourage l'exploration)
|
||||
entropy_loss = -entropy.mean()
|
||||
|
||||
# Loss totale
|
||||
total_loss = (
|
||||
policy_loss
|
||||
+ self.value_coef * value_loss
|
||||
+ self.entropy_coef * entropy_loss
|
||||
)
|
||||
|
||||
# Optimisation
|
||||
self.optimizer.zero_grad()
|
||||
total_loss.backward()
|
||||
nn.utils.clip_grad_norm_(
|
||||
self.network.parameters(), self.max_grad_norm
|
||||
)
|
||||
self.optimizer.step()
|
||||
|
||||
policy_losses_ep.append(float(policy_loss.item()))
|
||||
value_losses_ep.append(float(value_loss.item()))
|
||||
entropy_losses_ep.append(float(entropy_loss.item()))
|
||||
|
||||
all_policy_losses.extend(policy_losses_ep)
|
||||
all_value_losses.extend(value_losses_ep)
|
||||
all_entropy_losses.extend(entropy_losses_ep)
|
||||
n_updates += 1
|
||||
rollout_idx = 0
|
||||
|
||||
if n_updates % 10 == 0:
|
||||
mean_ret = float(np.mean(all_episode_returns[-20:])) if all_episode_returns else 0.0
|
||||
logger.info(
|
||||
f" Timestep {timestep}/{total_timesteps} | "
|
||||
f"Updates={n_updates} | MeanReturn(20ep)={mean_ret:.4f}"
|
||||
)
|
||||
|
||||
self.network.eval()
|
||||
self.network.to('cpu')
|
||||
self.is_trained = True
|
||||
|
||||
metrics = {
|
||||
'symbol': symbol,
|
||||
'timeframe': timeframe,
|
||||
'total_timesteps': total_timesteps,
|
||||
'n_updates': n_updates,
|
||||
'mean_reward': float(np.mean(all_episode_returns)) if all_episode_returns else 0.0,
|
||||
'mean_ep_return': float(np.mean(all_episode_returns[-20:])) if all_episode_returns else 0.0,
|
||||
'n_episodes': len(all_episode_returns),
|
||||
'policy_loss': float(np.mean(all_policy_losses[-100:])) if all_policy_losses else 0.0,
|
||||
'value_loss': float(np.mean(all_value_losses[-100:])) if all_value_losses else 0.0,
|
||||
'entropy_loss': float(np.mean(all_entropy_losses[-100:])) if all_entropy_losses else 0.0,
|
||||
'trained_at': datetime.utcnow().isoformat(),
|
||||
'hyperparams': {
|
||||
'lr': self.lr,
|
||||
'n_steps': self.n_steps,
|
||||
'n_epochs': self.n_epochs,
|
||||
'batch_size': self.batch_size,
|
||||
'gamma': self.gamma,
|
||||
'gae_lambda': self.gae_lambda,
|
||||
'clip_eps': self.clip_eps,
|
||||
'entropy_coef': self.entropy_coef,
|
||||
'value_coef': self.value_coef,
|
||||
},
|
||||
}
|
||||
self.metadata = metrics
|
||||
|
||||
logger.info(
|
||||
f"Entraînement PPO terminé. "
|
||||
f"MeanReturn={metrics['mean_ep_return']:.4f} | "
|
||||
f"Episodes={metrics['n_episodes']}"
|
||||
)
|
||||
return metrics
|
||||
|
||||
def predict(self, obs: np.ndarray) -> Tuple[int, float]:
|
||||
"""
|
||||
Prédit l'action optimale pour une observation donnée.
|
||||
|
||||
En mode inférence (deterministic=True) : argmax des logits.
|
||||
En mode exploration : sampling de la distribution.
|
||||
|
||||
Args:
|
||||
obs: Vecteur d'observation (obs_dim,) en float32
|
||||
|
||||
Returns:
|
||||
Tuple (action, log_prob) où action ∈ {0, 1, 2}
|
||||
"""
|
||||
if not TORCH_AVAILABLE or not self.is_trained or self.network is None:
|
||||
return 0, 0.0
|
||||
|
||||
self.network.eval()
|
||||
with torch.no_grad():
|
||||
obs_t = torch.from_numpy(obs.astype(np.float32)).unsqueeze(0)
|
||||
action, log_prob, _, _ = self.network.get_action_and_value(obs_t)
|
||||
|
||||
return int(action.item()), float(log_prob.item())
|
||||
|
||||
def predict_deterministic(self, obs: np.ndarray) -> int:
|
||||
"""
|
||||
Prédit l'action déterministe (argmax) sans exploration.
|
||||
|
||||
Args:
|
||||
obs: Vecteur d'observation (obs_dim,) en float32
|
||||
|
||||
Returns:
|
||||
Action déterministe (argmax des logits) ∈ {0, 1, 2}
|
||||
"""
|
||||
if not TORCH_AVAILABLE or not self.is_trained or self.network is None:
|
||||
return 0
|
||||
|
||||
self.network.eval()
|
||||
with torch.no_grad():
|
||||
obs_t = torch.from_numpy(obs.astype(np.float32)).unsqueeze(0)
|
||||
logits, _ = self.network(obs_t)
|
||||
action = int(logits.argmax(dim=-1).item())
|
||||
|
||||
return action
|
||||
|
||||
def get_action_probas(self, obs: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Retourne les probabilités de chaque action via softmax des logits.
|
||||
|
||||
Args:
|
||||
obs: Vecteur d'observation (obs_dim,) en float32
|
||||
|
||||
Returns:
|
||||
np.ndarray de forme (3,) : [P(HOLD), P(LONG), P(SHORT)]
|
||||
"""
|
||||
if not TORCH_AVAILABLE or not self.is_trained or self.network is None:
|
||||
return np.array([1.0, 0.0, 0.0], dtype=np.float32)
|
||||
|
||||
self.network.eval()
|
||||
with torch.no_grad():
|
||||
obs_t = torch.from_numpy(obs.astype(np.float32)).unsqueeze(0)
|
||||
logits, _ = self.network(obs_t)
|
||||
probas = torch.softmax(logits, dim=-1).squeeze(0).numpy()
|
||||
|
||||
return probas.astype(np.float32)
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Sauvegarde / Chargement
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def save(self, symbol: str = 'EURUSD', timeframe: str = '1h') -> Path:
|
||||
"""
|
||||
Sauvegarde le modèle et ses métadonnées sur disque.
|
||||
|
||||
Format :
|
||||
{symbol}_{timeframe}.pt — state_dict + config PyTorch
|
||||
{symbol}_{timeframe}_meta.json — métadonnées JSON
|
||||
|
||||
Args:
|
||||
symbol: Paire tradée (ex: 'EURUSD')
|
||||
timeframe: Timeframe (ex: '1h')
|
||||
|
||||
Returns:
|
||||
Path vers le fichier .pt sauvegardé
|
||||
|
||||
Raises:
|
||||
RuntimeError si PyTorch non disponible ou modèle non entraîné
|
||||
"""
|
||||
if not TORCH_AVAILABLE:
|
||||
raise RuntimeError("PyTorch non disponible")
|
||||
if not self.is_trained or self.network is None:
|
||||
raise RuntimeError("Modèle non entraîné — appeler train() avant save()")
|
||||
|
||||
model_id = f"{symbol}_{timeframe}"
|
||||
model_path = self.MODELS_DIR / f"{model_id}.pt"
|
||||
meta_path = self.MODELS_DIR / f"{model_id}_meta.json"
|
||||
|
||||
torch.save(
|
||||
{
|
||||
'state_dict': self.network.state_dict(),
|
||||
'config': {
|
||||
'obs_dim': self.obs_dim,
|
||||
'n_actions': self.n_actions,
|
||||
'lr': self.lr,
|
||||
'n_steps': self.n_steps,
|
||||
'n_epochs': self.n_epochs,
|
||||
'batch_size': self.batch_size,
|
||||
'gamma': self.gamma,
|
||||
'gae_lambda': self.gae_lambda,
|
||||
'clip_eps': self.clip_eps,
|
||||
'entropy_coef': self.entropy_coef,
|
||||
'value_coef': self.value_coef,
|
||||
'max_grad_norm': self.max_grad_norm,
|
||||
},
|
||||
'metadata': self.metadata,
|
||||
},
|
||||
model_path
|
||||
)
|
||||
|
||||
with open(meta_path, 'w') as f:
|
||||
json.dump(self.metadata, f, indent=2, default=str)
|
||||
|
||||
logger.info(f"Modèle PPO sauvegardé : {model_path}")
|
||||
return model_path
|
||||
|
||||
@classmethod
|
||||
def load(cls, symbol: str, timeframe: str) -> 'PPOModel':
|
||||
"""
|
||||
Charge un modèle PPO existant depuis le disque.
|
||||
|
||||
Args:
|
||||
symbol: Paire (ex: 'EURUSD')
|
||||
timeframe: Timeframe (ex: '1h')
|
||||
|
||||
Returns:
|
||||
Instance PPOModel prête à prédire
|
||||
|
||||
Raises:
|
||||
RuntimeError si PyTorch non disponible
|
||||
FileNotFoundError si le modèle n'existe pas
|
||||
"""
|
||||
if not TORCH_AVAILABLE:
|
||||
raise RuntimeError("PyTorch non disponible")
|
||||
|
||||
model_path = MODELS_DIR / f"{symbol}_{timeframe}.pt"
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"Modèle PPO non trouvé : {model_path}")
|
||||
|
||||
checkpoint = torch.load(model_path, map_location='cpu')
|
||||
cfg = checkpoint.get('config', {})
|
||||
|
||||
instance = cls(
|
||||
obs_dim = cfg.get('obs_dim', 20),
|
||||
n_actions = cfg.get('n_actions', 3),
|
||||
lr = cfg.get('lr', 3e-4),
|
||||
n_steps = cfg.get('n_steps', 2048),
|
||||
n_epochs = cfg.get('n_epochs', 10),
|
||||
batch_size = cfg.get('batch_size', 64),
|
||||
gamma = cfg.get('gamma', 0.99),
|
||||
gae_lambda = cfg.get('gae_lambda', 0.95),
|
||||
clip_eps = cfg.get('clip_eps', 0.2),
|
||||
entropy_coef = cfg.get('entropy_coef', 0.01),
|
||||
value_coef = cfg.get('value_coef', 0.5),
|
||||
max_grad_norm = cfg.get('max_grad_norm', 0.5),
|
||||
)
|
||||
|
||||
instance.network.load_state_dict(checkpoint['state_dict'])
|
||||
instance.network.eval()
|
||||
instance.is_trained = True
|
||||
instance.metadata = checkpoint.get('metadata', {})
|
||||
|
||||
logger.info(f"Modèle PPO chargé depuis {model_path}")
|
||||
return instance
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Calcul des avantages (GAE)
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _compute_gae(
|
||||
self,
|
||||
rewards: np.ndarray,
|
||||
values: np.ndarray,
|
||||
dones: np.ndarray,
|
||||
last_value: float,
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Calcule les avantages généralisés (GAE) et les retours cibles.
|
||||
|
||||
Formule GAE :
|
||||
δt = rt + γ·V(st+1)·(1-dt) - V(st)
|
||||
Ât = δt + γλ·Ât+1·(1-dt)
|
||||
Rt = Ât + V(st)
|
||||
|
||||
Args:
|
||||
rewards: Récompenses du rollout (n_steps,)
|
||||
values: Valeurs estimées par le critic (n_steps,)
|
||||
dones: Indicateurs de fin d'épisode (n_steps,)
|
||||
last_value: Valeur bootstrap du dernier état (scalaire)
|
||||
|
||||
Returns:
|
||||
Tuple (advantages, returns) de forme (n_steps,)
|
||||
"""
|
||||
n = len(rewards)
|
||||
advantages = np.zeros(n, dtype=np.float32)
|
||||
last_gae = 0.0
|
||||
|
||||
for t in reversed(range(n)):
|
||||
if t == n - 1:
|
||||
next_non_terminal = 1.0 - dones[t]
|
||||
next_value = last_value
|
||||
else:
|
||||
next_non_terminal = 1.0 - dones[t]
|
||||
next_value = values[t + 1]
|
||||
|
||||
delta = rewards[t] + self.gamma * next_value * next_non_terminal - values[t]
|
||||
last_gae = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae
|
||||
advantages[t] = last_gae
|
||||
|
||||
returns = advantages + values
|
||||
return advantages, returns
|
||||
557
src/ml/rl/rl_strategy_model.py
Normal file
557
src/ml/rl/rl_strategy_model.py
Normal file
@@ -0,0 +1,557 @@
|
||||
"""
|
||||
RLStrategyModel — Interface unifiée pour l'agent PPO de trading.
|
||||
|
||||
Interface identique à MLStrategyModel et CNNImageStrategyModel :
|
||||
model = RLStrategyModel(symbol='EURUSD', timeframe='1h')
|
||||
result = model.train(df_ohlcv)
|
||||
signal = model.predict(df) # {signal, confidence, probas, tradeable}
|
||||
model.save()
|
||||
model.load(symbol, timeframe)
|
||||
model.list_trained_models()
|
||||
model.get_feature_importance() # retourne [] (RL = boîte noire)
|
||||
|
||||
Pipeline d'entraînement :
|
||||
1. Validation des données OHLCV (minimum 500 barres recommandées)
|
||||
2. Création de TradingEnv sur les données d'entraînement
|
||||
3. Entraînement PPO (total_timesteps configurable)
|
||||
4. Évaluation sur holdout (20% temporel final)
|
||||
5. Sauvegarde du modèle PPO + métadonnées JSON
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from src.ml.rl.trading_env import TradingEnv, GYM_AVAILABLE, N_FEATURES
|
||||
from src.ml.rl.ppo_model import PPOModel, TORCH_AVAILABLE, MODELS_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Décodage des actions RL en signaux de trading
|
||||
# Action 0 (HOLD) → signal 0 (NEUTRAL)
|
||||
# Action 1 (LONG) → signal 1 (LONG)
|
||||
# Action 2 (SHORT) → signal -1 (SHORT)
|
||||
ACTION_TO_SIGNAL = {0: 0, 1: 1, 2: -1}
|
||||
|
||||
# Disponibilité globale du module RL
|
||||
RL_AVAILABLE = TORCH_AVAILABLE and GYM_AVAILABLE
|
||||
|
||||
|
||||
class RLStrategyModel:
|
||||
"""
|
||||
Modèle de trading basé sur un agent PPO (Reinforcement Learning).
|
||||
|
||||
L'agent apprend directement à maximiser le PnL via interaction avec
|
||||
un environnement de trading simulé (TradingEnv), sans supervision
|
||||
explicite de labels.
|
||||
|
||||
Avantages par rapport aux modèles supervisés :
|
||||
- Pas besoin de labels manuels (LONG/SHORT/NEUTRAL)
|
||||
- Optimise directement l'objectif de trading (PnL)
|
||||
- Apprend à gérer les positions (durée, timing de sortie)
|
||||
|
||||
Args:
|
||||
symbol: Paire tradée (ex: 'EURUSD')
|
||||
timeframe: Timeframe (ex: '1h', '15m')
|
||||
total_timesteps: Nombre de timesteps d'entraînement PPO (défaut: 100_000)
|
||||
sl_atr_mult: Multiplicateur ATR pour le stop-loss (défaut: 1.0)
|
||||
tp_atr_mult: Multiplicateur ATR pour le take-profit (défaut: 2.0)
|
||||
min_confidence: Seuil de confiance minimum pour signaux (défaut: 0.50)
|
||||
train_ratio: Fraction des données pour l'entraînement (défaut: 0.80)
|
||||
initial_capital: Capital initial pour la simulation (défaut: 10_000)
|
||||
lr: Taux d'apprentissage Adam (défaut: 3e-4)
|
||||
n_steps: Transitions par rollout (défaut: 2048)
|
||||
n_epochs: Passes d'optimisation par rollout (défaut: 10)
|
||||
batch_size: Taille des mini-batchs (défaut: 64)
|
||||
"""
|
||||
|
||||
MODELS_DIR = MODELS_DIR
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
symbol: str = 'EURUSD',
|
||||
timeframe: str = '1h',
|
||||
total_timesteps: int = 100_000,
|
||||
sl_atr_mult: float = 1.0,
|
||||
tp_atr_mult: float = 2.0,
|
||||
min_confidence: float = 0.50,
|
||||
train_ratio: float = 0.80,
|
||||
initial_capital: float = 10_000.0,
|
||||
lr: float = 3e-4,
|
||||
n_steps: int = 2048,
|
||||
n_epochs: int = 10,
|
||||
batch_size: int = 64,
|
||||
):
|
||||
self.symbol = symbol
|
||||
self.timeframe = timeframe
|
||||
self.total_timesteps = total_timesteps
|
||||
self.sl_atr_mult = sl_atr_mult
|
||||
self.tp_atr_mult = tp_atr_mult
|
||||
self.min_confidence = min_confidence
|
||||
self.train_ratio = train_ratio
|
||||
self.initial_capital = initial_capital
|
||||
self.lr = lr
|
||||
self.n_steps = n_steps
|
||||
self.n_epochs = n_epochs
|
||||
self.batch_size = batch_size
|
||||
|
||||
self.ppo_model: Optional[PPOModel] = None
|
||||
self.is_trained = False
|
||||
self.metadata: Dict = {}
|
||||
|
||||
self.MODELS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Entraînement
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def train(self, data: pd.DataFrame) -> Dict:
|
||||
"""
|
||||
Entraîne l'agent PPO sur les données OHLCV fournies.
|
||||
|
||||
Étapes :
|
||||
1. Validation des données (minimum 200 barres)
|
||||
2. Split temporel train/holdout (80/20 par défaut)
|
||||
3. Création du TradingEnv sur le jeu d'entraînement
|
||||
4. Entraînement PPO (total_timesteps)
|
||||
5. Évaluation sur le jeu de holdout
|
||||
6. Sauvegarde automatique du modèle
|
||||
|
||||
Args:
|
||||
data: DataFrame OHLCV (colonnes : open/high/low/close/volume)
|
||||
Minimum 200 barres, idéalement 2000+ pour un bon apprentissage.
|
||||
|
||||
Returns:
|
||||
Dict avec métriques :
|
||||
- total_timesteps, n_episodes, mean_ep_return (entraînement)
|
||||
- eval_total_pnl, eval_return_pct, eval_sharpe (évaluation holdout)
|
||||
- n_samples, trained_at, error (si échec)
|
||||
"""
|
||||
if not TORCH_AVAILABLE:
|
||||
return {'error': 'PyTorch non disponible — installer torch>=2.0.0'}
|
||||
if not GYM_AVAILABLE:
|
||||
return {'error': 'gymnasium non disponible — installer gymnasium>=0.26'}
|
||||
|
||||
logger.info(
|
||||
f"Début entraînement RLStrategyModel pour "
|
||||
f"{self.symbol}/{self.timeframe} — total_timesteps={self.total_timesteps}"
|
||||
)
|
||||
|
||||
# 1. Validation et normalisation des données
|
||||
df = data.copy()
|
||||
df.columns = [c.lower() for c in df.columns]
|
||||
|
||||
required_cols = {'open', 'high', 'low', 'close', 'volume'}
|
||||
missing = required_cols - set(df.columns)
|
||||
if missing:
|
||||
return {'error': f'Colonnes manquantes : {missing}'}
|
||||
|
||||
df = df.dropna(subset=list(required_cols))
|
||||
|
||||
if len(df) < 200:
|
||||
return {
|
||||
'error': f'Données insuffisantes : {len(df)} barres (minimum 200)'
|
||||
}
|
||||
|
||||
# 2. Split temporel train / holdout (respecte l'ordre chronologique)
|
||||
n_train = int(len(df) * self.train_ratio)
|
||||
df_train = df.iloc[:n_train].reset_index(drop=True)
|
||||
df_holdout = df.iloc[n_train:].reset_index(drop=True)
|
||||
|
||||
logger.info(
|
||||
f" Split : train={len(df_train)} barres, holdout={len(df_holdout)} barres"
|
||||
)
|
||||
|
||||
if len(df_train) < 100:
|
||||
return {'error': f'Données d\'entraînement insuffisantes : {len(df_train)} barres'}
|
||||
|
||||
# 3. Création de l'environnement d'entraînement
|
||||
train_env = TradingEnv(
|
||||
df = df_train,
|
||||
sl_atr_mult = self.sl_atr_mult,
|
||||
tp_atr_mult = self.tp_atr_mult,
|
||||
initial_capital = self.initial_capital,
|
||||
)
|
||||
|
||||
# 4. Initialisation et entraînement PPO
|
||||
self.ppo_model = PPOModel(
|
||||
obs_dim = N_FEATURES,
|
||||
n_actions = 3,
|
||||
lr = self.lr,
|
||||
n_steps = min(self.n_steps, len(df_train) - 30), # Sécurité
|
||||
n_epochs = self.n_epochs,
|
||||
batch_size = self.batch_size,
|
||||
)
|
||||
|
||||
train_metrics = self.ppo_model.train(
|
||||
env = train_env,
|
||||
total_timesteps = self.total_timesteps,
|
||||
symbol = self.symbol,
|
||||
timeframe = self.timeframe,
|
||||
)
|
||||
|
||||
if 'error' in train_metrics:
|
||||
return train_metrics
|
||||
|
||||
self.is_trained = True
|
||||
|
||||
# 5. Évaluation sur holdout
|
||||
eval_metrics = {}
|
||||
if len(df_holdout) >= 50:
|
||||
eval_metrics = self._evaluate_on_holdout(df_holdout)
|
||||
logger.info(
|
||||
f" Évaluation holdout : "
|
||||
f"PnL={eval_metrics.get('eval_total_pnl', 0):.4f}, "
|
||||
f"Return={eval_metrics.get('eval_return_pct', 0):.2%}, "
|
||||
f"Sharpe≈{eval_metrics.get('eval_sharpe', 0):.2f}"
|
||||
)
|
||||
|
||||
# 6. Assemblage des métadonnées
|
||||
self.metadata = {
|
||||
'symbol': self.symbol,
|
||||
'timeframe': self.timeframe,
|
||||
'trained_at': datetime.utcnow().isoformat(),
|
||||
'n_samples': len(df),
|
||||
'n_train': len(df_train),
|
||||
'n_holdout': len(df_holdout),
|
||||
'total_timesteps': train_metrics.get('total_timesteps', self.total_timesteps),
|
||||
'n_episodes': train_metrics.get('n_episodes', 0),
|
||||
'n_updates': train_metrics.get('n_updates', 0),
|
||||
'mean_ep_return': train_metrics.get('mean_ep_return', 0.0),
|
||||
'policy_loss': train_metrics.get('policy_loss', 0.0),
|
||||
'value_loss': train_metrics.get('value_loss', 0.0),
|
||||
'entropy_loss': train_metrics.get('entropy_loss', 0.0),
|
||||
'sl_atr_mult': self.sl_atr_mult,
|
||||
'tp_atr_mult': self.tp_atr_mult,
|
||||
'min_confidence': self.min_confidence,
|
||||
'hyperparams': train_metrics.get('hyperparams', {}),
|
||||
**eval_metrics,
|
||||
}
|
||||
|
||||
# 7. Sauvegarde automatique
|
||||
self.save()
|
||||
|
||||
logger.info(
|
||||
f"Entraînement RL terminé pour {self.symbol}/{self.timeframe} — "
|
||||
f"N_episodes={self.metadata['n_episodes']}, "
|
||||
f"MeanReturn={self.metadata['mean_ep_return']:.4f}"
|
||||
)
|
||||
return self.metadata
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Prédiction
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def predict(self, data: pd.DataFrame) -> Dict:
|
||||
"""
|
||||
Prédit le signal de trading pour la dernière barre disponible.
|
||||
|
||||
L'agent évalue les 20 dernières barres (fenêtre lookback) et retourne
|
||||
son action déterministe (HOLD/LONG/SHORT) avec les probabilités associées.
|
||||
|
||||
Args:
|
||||
data: DataFrame OHLCV récent (minimum 25 barres pour les indicateurs)
|
||||
|
||||
Returns:
|
||||
Dict : {
|
||||
'signal': 1 (LONG) / -1 (SHORT) / 0 (NEUTRAL/HOLD),
|
||||
'confidence': float [0..1] — max(P_LONG, P_SHORT),
|
||||
'probas': {'hold': float, 'long': float, 'short': float},
|
||||
'tradeable': bool — confidence >= min_confidence et signal != 0,
|
||||
}
|
||||
"""
|
||||
if not TORCH_AVAILABLE:
|
||||
return {
|
||||
'signal': 0, 'confidence': 0.0, 'tradeable': False,
|
||||
'error': 'PyTorch non disponible'
|
||||
}
|
||||
if not self.is_trained or self.ppo_model is None:
|
||||
return {
|
||||
'signal': 0, 'confidence': 0.0, 'tradeable': False,
|
||||
'error': 'Modèle non entraîné — appeler train() d\'abord'
|
||||
}
|
||||
|
||||
df = data.copy()
|
||||
df.columns = [c.lower() for c in df.columns]
|
||||
|
||||
if len(df) < 25:
|
||||
return {
|
||||
'signal': 0, 'confidence': 0.0, 'tradeable': False,
|
||||
'error': f'Pas assez de données : {len(df)} barres (minimum 25)'
|
||||
}
|
||||
|
||||
try:
|
||||
# Créer un environnement temporaire pour obtenir l'observation
|
||||
temp_env = TradingEnv(
|
||||
df = df,
|
||||
sl_atr_mult = self.sl_atr_mult,
|
||||
tp_atr_mult = self.tp_atr_mult,
|
||||
initial_capital = self.initial_capital,
|
||||
)
|
||||
|
||||
# Positionner l'environnement à la dernière barre
|
||||
temp_env._current_step = len(df) - 1
|
||||
obs = temp_env._get_observation()
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
'signal': 0, 'confidence': 0.0, 'tradeable': False,
|
||||
'error': f'Erreur construction observation : {e}'
|
||||
}
|
||||
|
||||
# Prédiction déterministe (argmax des logits)
|
||||
action = self.ppo_model.predict_deterministic(obs)
|
||||
probas = self.ppo_model.get_action_probas(obs) # [P(HOLD), P(LONG), P(SHORT)]
|
||||
|
||||
signal = ACTION_TO_SIGNAL.get(action, 0)
|
||||
confidence = float(max(probas[1], probas[2])) # max(P_LONG, P_SHORT)
|
||||
|
||||
return {
|
||||
'signal': signal,
|
||||
'confidence': confidence,
|
||||
'probas': {
|
||||
'hold': float(probas[0]),
|
||||
'long': float(probas[1]),
|
||||
'short': float(probas[2]),
|
||||
},
|
||||
'tradeable': confidence >= self.min_confidence and signal != 0,
|
||||
}
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Sauvegarde / Chargement
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def save(self) -> Path:
|
||||
"""
|
||||
Sauvegarde le modèle PPO et les métadonnées sur disque.
|
||||
|
||||
Returns:
|
||||
Path vers le fichier .pt sauvegardé
|
||||
|
||||
Raises:
|
||||
RuntimeError si le modèle n'est pas entraîné
|
||||
"""
|
||||
if not self.is_trained or self.ppo_model is None:
|
||||
raise RuntimeError("Modèle non entraîné — appeler train() avant save()")
|
||||
|
||||
# Sauvegarde du réseau PPO
|
||||
model_path = self.ppo_model.save(self.symbol, self.timeframe)
|
||||
|
||||
# Sauvegarde des métadonnées complètes (incluant les params RLStrategyModel)
|
||||
meta = self.metadata.copy()
|
||||
meta.update({
|
||||
'rl_config': {
|
||||
'sl_atr_mult': self.sl_atr_mult,
|
||||
'tp_atr_mult': self.tp_atr_mult,
|
||||
'min_confidence': self.min_confidence,
|
||||
'train_ratio': self.train_ratio,
|
||||
'initial_capital': self.initial_capital,
|
||||
'total_timesteps': self.total_timesteps,
|
||||
}
|
||||
})
|
||||
|
||||
meta_path = self.MODELS_DIR / f"{self.symbol}_{self.timeframe}_meta.json"
|
||||
with open(meta_path, 'w') as f:
|
||||
json.dump(meta, f, indent=2, default=str)
|
||||
|
||||
return model_path
|
||||
|
||||
@classmethod
|
||||
def load(cls, symbol: str, timeframe: str) -> 'RLStrategyModel':
|
||||
"""
|
||||
Charge un modèle RL existant depuis le disque.
|
||||
|
||||
Args:
|
||||
symbol: Paire (ex: 'EURUSD')
|
||||
timeframe: Timeframe (ex: '1h')
|
||||
|
||||
Returns:
|
||||
Instance RLStrategyModel prête à prédire
|
||||
|
||||
Raises:
|
||||
RuntimeError si PyTorch ou gymnasium non disponible
|
||||
FileNotFoundError si le modèle n'existe pas
|
||||
"""
|
||||
if not TORCH_AVAILABLE:
|
||||
raise RuntimeError("PyTorch non disponible")
|
||||
|
||||
meta_path = MODELS_DIR / f"{symbol}_{timeframe}_meta.json"
|
||||
rl_config = {}
|
||||
if meta_path.exists():
|
||||
try:
|
||||
with open(meta_path) as f:
|
||||
saved_meta = json.load(f)
|
||||
rl_config = saved_meta.get('rl_config', {})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
instance = cls(
|
||||
symbol = symbol,
|
||||
timeframe = timeframe,
|
||||
sl_atr_mult = rl_config.get('sl_atr_mult', 1.0),
|
||||
tp_atr_mult = rl_config.get('tp_atr_mult', 2.0),
|
||||
min_confidence = rl_config.get('min_confidence', 0.50),
|
||||
train_ratio = rl_config.get('train_ratio', 0.80),
|
||||
initial_capital = rl_config.get('initial_capital', 10_000.0),
|
||||
total_timesteps = rl_config.get('total_timesteps', 100_000),
|
||||
)
|
||||
|
||||
# Chargement du modèle PPO
|
||||
instance.ppo_model = PPOModel.load(symbol, timeframe)
|
||||
instance.is_trained = True
|
||||
|
||||
if meta_path.exists():
|
||||
try:
|
||||
with open(meta_path) as f:
|
||||
instance.metadata = json.load(f)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info(f"Modèle RL chargé depuis {MODELS_DIR}/{symbol}_{timeframe}.pt")
|
||||
return instance
|
||||
|
||||
@staticmethod
|
||||
def list_trained_models() -> List[Dict]:
|
||||
"""
|
||||
Retourne la liste des modèles RL entraînés disponibles.
|
||||
|
||||
Returns:
|
||||
Liste de dicts avec symbol, timeframe, trained_at, n_samples,
|
||||
mean_ep_return, eval_return_pct
|
||||
"""
|
||||
if not MODELS_DIR.exists():
|
||||
return []
|
||||
|
||||
models = []
|
||||
for f in MODELS_DIR.glob("*_meta.json"):
|
||||
try:
|
||||
with open(f) as fp:
|
||||
meta = json.load(fp)
|
||||
stem = f.stem.replace('_meta', '') # ex: EURUSD_1h
|
||||
parts = stem.split('_', 1)
|
||||
models.append({
|
||||
'symbol': meta.get('symbol', parts[0] if parts else '?'),
|
||||
'timeframe': meta.get('timeframe', parts[1] if len(parts) > 1 else '?'),
|
||||
'trained_at': meta.get('trained_at', '?'),
|
||||
'n_samples': meta.get('n_samples', 0),
|
||||
'total_timesteps': meta.get('total_timesteps', 0),
|
||||
'n_episodes': meta.get('n_episodes', 0),
|
||||
'mean_ep_return': meta.get('mean_ep_return', 0.0),
|
||||
'eval_return_pct': meta.get('eval_return_pct', 0.0),
|
||||
'eval_sharpe': meta.get('eval_sharpe', 0.0),
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return models
|
||||
|
||||
def get_feature_importance(self) -> List[Dict]:
|
||||
"""
|
||||
Retourne l'importance des features.
|
||||
|
||||
Note : Les agents RL (PPO) sont des boîtes noires — pas d'importance
|
||||
de feature interprétable. Retourne les noms des 20 features pour
|
||||
information, sans score numérique.
|
||||
|
||||
Returns:
|
||||
Liste de dicts avec feature_name et description
|
||||
"""
|
||||
feature_names = [
|
||||
{'rank': 1, 'feature': 'body_ratio', 'description': '(close-open)/ATR — corps de la bougie'},
|
||||
{'rank': 2, 'feature': 'amplitude', 'description': '(high-low)/ATR — amplitude de la bougie'},
|
||||
{'rank': 3, 'feature': 'momentum_1', 'description': '(close-close[-1])/ATR — momentum 1 barre'},
|
||||
{'rank': 4, 'feature': 'momentum_5', 'description': '(close-close[-5])/ATR — momentum 5 barres'},
|
||||
{'rank': 5, 'feature': 'rsi_norm', 'description': 'RSI(14)/100 — RSI normalisé'},
|
||||
{'rank': 6, 'feature': 'bb_position', 'description': '(close-BB_upper)/ATR — position vs Bollinger'},
|
||||
{'rank': 7, 'feature': 'bb_width', 'description': '(BB_upper-BB_lower)/ATR — largeur bandes'},
|
||||
{'rank': 8, 'feature': 'macd_norm', 'description': 'MACD/ATR — MACD normalisé'},
|
||||
{'rank': 9, 'feature': 'macd_signal_norm', 'description': 'Signal MACD/ATR — signal MACD normalisé'},
|
||||
{'rank': 10, 'feature': 'vol_ratio', 'description': 'Volume/MA20(Volume) — ratio volume'},
|
||||
{'rank': 11, 'feature': 'position', 'description': 'Position courante (-1=short, 0=flat, 1=long)'},
|
||||
{'rank': 12, 'feature': 'unrealized_pnl', 'description': 'PnL_non_réalisé/ATR — gain/perte courant'},
|
||||
{'rank': 13, 'feature': 'bars_in_trade', 'description': 'Durée du trade normalisée'},
|
||||
{'rank': 14, 'feature': 'hour', 'description': 'Heure de la journée normalisée'},
|
||||
{'rank': 15, 'feature': 'day_of_week', 'description': 'Jour de semaine normalisé'},
|
||||
{'rank': 16, 'feature': 'ema9_dist', 'description': '(close-EMA9)/ATR — distance EMA9'},
|
||||
{'rank': 17, 'feature': 'ema21_dist', 'description': '(close-EMA21)/ATR — distance EMA21'},
|
||||
{'rank': 18, 'feature': 'ema9_ema21_cross', 'description': '(EMA9-EMA21)/ATR — croisement EMA court'},
|
||||
{'rank': 19, 'feature': 'ema50_dist', 'description': '(close-EMA50)/ATR — distance EMA50'},
|
||||
{'rank': 20, 'feature': 'ema200_dist', 'description': '(close-EMA200)/ATR — distance EMA200'},
|
||||
]
|
||||
return feature_names
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Évaluation interne
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _evaluate_on_holdout(self, df_holdout: pd.DataFrame) -> Dict:
|
||||
"""
|
||||
Évalue l'agent entraîné sur le jeu de holdout.
|
||||
|
||||
Lance un épisode complet sur les données de holdout en mode déterministe
|
||||
(argmax des logits, pas de sampling). Retourne les métriques de performance.
|
||||
|
||||
Args:
|
||||
df_holdout: DataFrame OHLCV du jeu de holdout
|
||||
|
||||
Returns:
|
||||
Dict avec eval_total_pnl, eval_return_pct, eval_sharpe,
|
||||
eval_n_trades, eval_win_rate
|
||||
"""
|
||||
try:
|
||||
eval_env = TradingEnv(
|
||||
df = df_holdout,
|
||||
sl_atr_mult = self.sl_atr_mult,
|
||||
tp_atr_mult = self.tp_atr_mult,
|
||||
initial_capital = self.initial_capital,
|
||||
)
|
||||
|
||||
obs, _ = eval_env.reset()
|
||||
done = False
|
||||
rewards_hist = []
|
||||
n_trades = 0
|
||||
win_trades = 0
|
||||
|
||||
while not done:
|
||||
action = self.ppo_model.predict_deterministic(obs)
|
||||
obs, reward, terminated, truncated, info = eval_env.step(action)
|
||||
done = terminated or truncated
|
||||
|
||||
if abs(reward) > 0.001: # Fermeture de position
|
||||
n_trades += 1
|
||||
if reward > 0:
|
||||
win_trades += 1
|
||||
|
||||
rewards_hist.append(reward)
|
||||
|
||||
stats = eval_env.get_episode_stats()
|
||||
win_rate = win_trades / max(n_trades, 1)
|
||||
rewards_arr = np.array(rewards_hist)
|
||||
sharpe = (
|
||||
float(rewards_arr.mean() / (rewards_arr.std() + 1e-8) * np.sqrt(252))
|
||||
if len(rewards_arr) > 1 else 0.0
|
||||
)
|
||||
|
||||
return {
|
||||
'eval_total_pnl': float(stats.get('total_pnl', 0.0)),
|
||||
'eval_return_pct': float(stats.get('return_pct', 0.0)),
|
||||
'eval_sharpe': float(sharpe),
|
||||
'eval_n_trades': n_trades,
|
||||
'eval_win_rate': float(win_rate),
|
||||
'eval_n_steps': int(stats.get('n_steps', 0)),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Évaluation holdout échouée : {e}")
|
||||
return {
|
||||
'eval_total_pnl': 0.0,
|
||||
'eval_return_pct': 0.0,
|
||||
'eval_sharpe': 0.0,
|
||||
'eval_n_trades': 0,
|
||||
'eval_win_rate': 0.0,
|
||||
}
|
||||
524
src/ml/rl/trading_env.py
Normal file
524
src/ml/rl/trading_env.py
Normal file
@@ -0,0 +1,524 @@
|
||||
"""
|
||||
TradingEnv — Environnement Gymnasium pour l'agent RL de trading.
|
||||
|
||||
Conforme à l'API gymnasium (reset/step) et adapté aux marchés forex.
|
||||
|
||||
Espace d'observation : 20 features normalisées (fenêtre glissante de 20 barres)
|
||||
Espace d'action : Discrete(3) → {0=HOLD, 1=LONG, 2=SHORT}
|
||||
Récompense : PnL réalisé + pénalité drawdown + bonus Sharpe
|
||||
|
||||
L'environnement gère :
|
||||
- Les transitions de position (flat → long, flat → short, inversions)
|
||||
- Le stop-loss / take-profit automatiques basés sur l'ATR
|
||||
- La normalisation des observations par l'ATR
|
||||
- La fenêtre glissante de 20 barres (lookback)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Imports conditionnels gymnasium / gym
|
||||
try:
|
||||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
GYM_AVAILABLE = True
|
||||
GYM_MODULE = 'gymnasium'
|
||||
except ImportError:
|
||||
try:
|
||||
import gym
|
||||
from gym import spaces
|
||||
GYM_AVAILABLE = True
|
||||
GYM_MODULE = 'gym'
|
||||
except ImportError:
|
||||
gym = None
|
||||
spaces = None
|
||||
GYM_AVAILABLE = False
|
||||
GYM_MODULE = None
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Actions
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
ACTION_HOLD = 0
|
||||
ACTION_LONG = 1
|
||||
ACTION_SHORT = 2
|
||||
|
||||
# Nombre de features dans le vecteur d'observation
|
||||
N_FEATURES = 20
|
||||
|
||||
# Fenêtre lookback (nombre de barres passées incluses dans l'observation)
|
||||
LOOKBACK = 20
|
||||
|
||||
|
||||
class TradingEnv:
|
||||
"""
|
||||
Environnement de trading conforme à l'API gymnasium pour l'agent PPO.
|
||||
|
||||
Chaque step() représente la décision de trading à la fermeture d'une bougie.
|
||||
L'agent reçoit 20 features normalisées décrivant le contexte de marché et
|
||||
sa position courante, puis choisit HOLD / LONG / SHORT.
|
||||
|
||||
Gestion des positions :
|
||||
- Une seule position à la fois (pas de pyramiding)
|
||||
- Inversion directe possible (SHORT → LONG sans passer par FLAT)
|
||||
- SL/TP ATR-based (sl_atr_mult×ATR et tp_atr_mult×ATR)
|
||||
|
||||
Récompense :
|
||||
- PnL réalisé à la clôture de position (en multiple d'ATR)
|
||||
- Pénalité proportionnelle au drawdown courant
|
||||
- Pas de bonus pour maintenir une position (évite le surtrading)
|
||||
|
||||
Args:
|
||||
df: DataFrame OHLCV (colonnes minuscules : open/high/low/close/volume)
|
||||
sl_atr_mult: Multiplicateur ATR pour le stop-loss (défaut: 1.0)
|
||||
tp_atr_mult: Multiplicateur ATR pour le take-profit (défaut: 2.0)
|
||||
atr_period: Période de calcul de l'ATR (défaut: 14)
|
||||
initial_capital: Capital initial (pour calcul drawdown) (défaut: 10000)
|
||||
drawdown_penalty: Coefficient de pénalité drawdown (défaut: 0.1)
|
||||
"""
|
||||
|
||||
metadata = {'render_modes': []}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
df: pd.DataFrame,
|
||||
sl_atr_mult: float = 1.0,
|
||||
tp_atr_mult: float = 2.0,
|
||||
atr_period: int = 14,
|
||||
initial_capital: float = 10_000.0,
|
||||
drawdown_penalty: float = 0.1,
|
||||
):
|
||||
if not GYM_AVAILABLE:
|
||||
raise RuntimeError(
|
||||
"gymnasium ou gym requis — installer gymnasium>=0.26 ou gym>=0.21"
|
||||
)
|
||||
|
||||
self.df = df.copy().reset_index(drop=True)
|
||||
self.df.columns = [c.lower() for c in self.df.columns]
|
||||
self.sl_atr_mult = sl_atr_mult
|
||||
self.tp_atr_mult = tp_atr_mult
|
||||
self.atr_period = atr_period
|
||||
self.initial_capital = initial_capital
|
||||
self.drawdown_penalty = drawdown_penalty
|
||||
|
||||
# Calcul de l'ATR sur toute la série (optimisation : évite le recalcul dans step)
|
||||
self._atr_series = self._compute_atr_series()
|
||||
|
||||
# Calcul des EMAs sur toute la série
|
||||
self._ema9 = self.df['close'].ewm(span=9, adjust=False).mean()
|
||||
self._ema21 = self.df['close'].ewm(span=21, adjust=False).mean()
|
||||
self._ema50 = self.df['close'].ewm(span=50, adjust=False).mean()
|
||||
self._ema200 = self.df['close'].ewm(span=200, adjust=False).mean()
|
||||
|
||||
# Calcul des Bandes de Bollinger (20, 2σ)
|
||||
bb_ma = self.df['close'].rolling(20).mean()
|
||||
bb_std = self.df['close'].rolling(20).std()
|
||||
self._bb_upper = bb_ma + 2 * bb_std
|
||||
self._bb_lower = bb_ma - 2 * bb_std
|
||||
|
||||
# Calcul du RSI(14)
|
||||
self._rsi = self._compute_rsi(period=14)
|
||||
|
||||
# Calcul du MACD (12, 26, 9)
|
||||
ema12 = self.df['close'].ewm(span=12, adjust=False).mean()
|
||||
ema26 = self.df['close'].ewm(span=26, adjust=False).mean()
|
||||
self._macd = ema12 - ema26
|
||||
self._macd_sig = self._macd.ewm(span=9, adjust=False).mean()
|
||||
|
||||
# Volume MA(20) pour normalisation
|
||||
self._vol_ma20 = self.df['volume'].rolling(20).mean().fillna(
|
||||
self.df['volume'].mean()
|
||||
)
|
||||
|
||||
# Espaces gymnasium
|
||||
self.observation_space = spaces.Box(
|
||||
low = -10.0,
|
||||
high = 10.0,
|
||||
shape = (N_FEATURES,),
|
||||
dtype = np.float32,
|
||||
)
|
||||
self.action_space = spaces.Discrete(3)
|
||||
|
||||
# État interne (initialisé dans reset())
|
||||
self._current_step = LOOKBACK
|
||||
self._position = 0 # -1=short, 0=flat, 1=long
|
||||
self._entry_price = 0.0
|
||||
self._entry_atr = 0.0
|
||||
self._bars_in_trade = 0
|
||||
self._capital = initial_capital
|
||||
self._peak_capital = initial_capital
|
||||
self._total_pnl = 0.0
|
||||
self._pnl_history = [] # Pour calcul Sharpe en fin d'épisode
|
||||
self._done = False
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Interface gymnasium
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
"""
|
||||
Réinitialise l'environnement au début d'un épisode.
|
||||
|
||||
Args:
|
||||
seed: Graine aléatoire (ignorée ici, données déterministes)
|
||||
options: Options additionnelles (non utilisées)
|
||||
|
||||
Returns:
|
||||
Tuple (observation, info) conforme gymnasium
|
||||
"""
|
||||
if seed is not None:
|
||||
np.random.seed(seed)
|
||||
|
||||
self._current_step = LOOKBACK
|
||||
self._position = 0
|
||||
self._entry_price = 0.0
|
||||
self._entry_atr = 0.0
|
||||
self._bars_in_trade = 0
|
||||
self._capital = self.initial_capital
|
||||
self._peak_capital = self.initial_capital
|
||||
self._total_pnl = 0.0
|
||||
self._pnl_history = []
|
||||
self._done = False
|
||||
|
||||
obs = self._get_observation()
|
||||
info = {}
|
||||
return obs, info
|
||||
|
||||
def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, dict]:
|
||||
"""
|
||||
Exécute une action et retourne la transition.
|
||||
|
||||
Args:
|
||||
action: 0=HOLD, 1=LONG, 2=SHORT
|
||||
|
||||
Returns:
|
||||
Tuple (observation, reward, terminated, truncated, info) conforme gymnasium
|
||||
"""
|
||||
if self._done:
|
||||
obs, info = self.reset()
|
||||
return obs, 0.0, True, False, info
|
||||
|
||||
reward = 0.0
|
||||
info = {}
|
||||
|
||||
current_close = float(self.df['close'].iloc[self._current_step])
|
||||
current_atr = float(self._atr_series.iloc[self._current_step])
|
||||
if current_atr <= 0:
|
||||
current_atr = float(self.df['close'].iloc[self._current_step]) * 0.001
|
||||
|
||||
# ── Vérification SL/TP avant d'appliquer la nouvelle action ──────────
|
||||
if self._position != 0:
|
||||
reward += self._check_sl_tp(current_close, current_atr)
|
||||
|
||||
# ── Application de la nouvelle action ─────────────────────────────────
|
||||
desired_position = self._action_to_position(action)
|
||||
|
||||
if desired_position != self._position:
|
||||
# Fermeture de la position courante (si ouverte)
|
||||
if self._position != 0:
|
||||
close_reward = self._close_position(current_close, current_atr)
|
||||
reward += close_reward
|
||||
|
||||
# Ouverture d'une nouvelle position (si pas HOLD)
|
||||
if desired_position != 0:
|
||||
self._open_position(desired_position, current_close, current_atr)
|
||||
else:
|
||||
# Même position : incrémenter le compteur de barres
|
||||
if self._position != 0:
|
||||
self._bars_in_trade += 1
|
||||
|
||||
# ── Pénalité drawdown ─────────────────────────────────────────────────
|
||||
if self._capital > self._peak_capital:
|
||||
self._peak_capital = self._capital
|
||||
drawdown = (self._peak_capital - self._capital) / max(self._peak_capital, 1.0)
|
||||
if drawdown > 0:
|
||||
reward -= self.drawdown_penalty * drawdown
|
||||
|
||||
# ── Avance d'une barre ────────────────────────────────────────────────
|
||||
self._current_step += 1
|
||||
terminated = self._current_step >= len(self.df) - 1
|
||||
truncated = False
|
||||
|
||||
if terminated:
|
||||
# Fermeture forcée de la position à la fin de l'épisode
|
||||
if self._position != 0:
|
||||
final_close = float(self.df['close'].iloc[-1])
|
||||
final_atr = float(self._atr_series.iloc[-1])
|
||||
reward += self._close_position(final_close, final_atr)
|
||||
self._done = True
|
||||
|
||||
obs = self._get_observation()
|
||||
|
||||
# Sauvegarde de la récompense pour calcul Sharpe
|
||||
self._pnl_history.append(reward)
|
||||
|
||||
info = {
|
||||
'position': self._position,
|
||||
'capital': self._capital,
|
||||
'drawdown': drawdown,
|
||||
'bars_in_trade': self._bars_in_trade,
|
||||
'step': self._current_step,
|
||||
}
|
||||
|
||||
return obs, float(reward), terminated, truncated, info
|
||||
|
||||
def render(self):
|
||||
"""Affichage minimal de l'état courant."""
|
||||
pos_str = {0: 'FLAT', 1: 'LONG', -1: 'SHORT'}.get(self._position, '?')
|
||||
logger.debug(
|
||||
f"Step={self._current_step} | Pos={pos_str} | "
|
||||
f"Capital={self._capital:.2f} | PnL={self._total_pnl:.4f}"
|
||||
)
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Observation
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _get_observation(self) -> np.ndarray:
|
||||
"""
|
||||
Construit le vecteur d'observation de 20 features normalisées.
|
||||
|
||||
Toutes les features de prix sont normalisées par l'ATR courant pour
|
||||
rendre l'observation invariante à l'échelle du prix.
|
||||
|
||||
Returns:
|
||||
np.ndarray de forme (20,) avec dtype float32
|
||||
"""
|
||||
i = self._current_step
|
||||
# Sécurité : ne pas dépasser les bornes
|
||||
i = max(LOOKBACK, min(i, len(self.df) - 1))
|
||||
|
||||
close = float(self.df['close'].iloc[i])
|
||||
open_ = float(self.df['open'].iloc[i])
|
||||
high = float(self.df['high'].iloc[i])
|
||||
low = float(self.df['low'].iloc[i])
|
||||
vol = float(self.df['volume'].iloc[i])
|
||||
|
||||
atr = float(self._atr_series.iloc[i])
|
||||
if atr <= 0:
|
||||
atr = close * 0.001
|
||||
|
||||
# Close i-1 et i-5 pour le momentum
|
||||
close_1 = float(self.df['close'].iloc[max(0, i - 1)])
|
||||
close_5 = float(self.df['close'].iloc[max(0, i - 5)])
|
||||
|
||||
rsi = float(self._rsi.iloc[i]) if not np.isnan(self._rsi.iloc[i]) else 50.0
|
||||
bb_upper = float(self._bb_upper.iloc[i]) if not np.isnan(self._bb_upper.iloc[i]) else close
|
||||
bb_lower = float(self._bb_lower.iloc[i]) if not np.isnan(self._bb_lower.iloc[i]) else close
|
||||
macd = float(self._macd.iloc[i]) if not np.isnan(self._macd.iloc[i]) else 0.0
|
||||
macd_signal = float(self._macd_sig.iloc[i]) if not np.isnan(self._macd_sig.iloc[i]) else 0.0
|
||||
vol_ma20 = float(self._vol_ma20.iloc[i])
|
||||
ema9 = float(self._ema9.iloc[i])
|
||||
ema21 = float(self._ema21.iloc[i])
|
||||
ema50 = float(self._ema50.iloc[i])
|
||||
ema200 = float(self._ema200.iloc[i])
|
||||
|
||||
vol_ratio = vol / max(vol_ma20, 1.0)
|
||||
|
||||
# PnL non-réalisé courant
|
||||
if self._position != 0 and self._entry_price > 0:
|
||||
unrealized = self._position * (close - self._entry_price)
|
||||
else:
|
||||
unrealized = 0.0
|
||||
|
||||
features = np.array([
|
||||
# Prix relatifs normalisés par ATR
|
||||
(close - open_) / atr, # 0 : corps de la bougie
|
||||
(high - low) / atr, # 1 : amplitude bougie (volatilité relative)
|
||||
(close - close_1) / atr, # 2 : momentum 1 barre
|
||||
(close - close_5) / atr, # 3 : momentum 5 barres
|
||||
# Indicateurs techniques normalisés
|
||||
rsi / 100.0, # 4 : RSI [0..1]
|
||||
(close - bb_upper) / atr, # 5 : position vs BB haute
|
||||
(bb_upper - bb_lower) / atr, # 6 : largeur des bandes de Bollinger
|
||||
macd / atr, # 7 : MACD normalisé
|
||||
macd_signal / atr, # 8 : signal MACD normalisé
|
||||
# Volume
|
||||
np.clip(vol_ratio, 0.0, 5.0), # 9 : ratio volume / MA20 (plafonné à 5)
|
||||
# Position et état du trade
|
||||
float(self._position), # 10: position courante (-1, 0, 1)
|
||||
unrealized / atr, # 11: PnL non-réalisé en ATR
|
||||
min(self._bars_in_trade / 50.0, 1.0), # 12: durée du trade (normalisée)
|
||||
# Temporel (si index datetime)
|
||||
self._get_hour(i), # 13: heure normalisée [0..1]
|
||||
self._get_dow(i), # 14: jour de semaine [0..1]
|
||||
# Moyennes mobiles (distance close - EMA, normalisée par ATR)
|
||||
(close - ema9) / atr, # 15: distance EMA9
|
||||
(close - ema21) / atr, # 16: distance EMA21
|
||||
(ema9 - ema21) / atr, # 17: croisement EMA9/21
|
||||
(close - ema50) / atr, # 18: distance EMA50
|
||||
(close - ema200) / atr, # 19: distance EMA200
|
||||
], dtype=np.float32)
|
||||
|
||||
# Clip pour éviter les valeurs extrêmes (protection robustesse)
|
||||
features = np.clip(features, -10.0, 10.0)
|
||||
# Remplacement des NaN résiduels
|
||||
features = np.nan_to_num(features, nan=0.0, posinf=10.0, neginf=-10.0)
|
||||
|
||||
return features
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Gestion des positions
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _open_position(self, direction: int, price: float, atr: float) -> None:
|
||||
"""
|
||||
Ouvre une position dans la direction donnée.
|
||||
|
||||
Args:
|
||||
direction: 1=LONG, -1=SHORT
|
||||
price: Prix d'entrée (close de la barre courante)
|
||||
atr: ATR courant pour le calcul SL/TP
|
||||
"""
|
||||
self._position = direction
|
||||
self._entry_price = price
|
||||
self._entry_atr = atr
|
||||
self._bars_in_trade = 0
|
||||
|
||||
def _close_position(self, price: float, atr: float) -> float:
|
||||
"""
|
||||
Ferme la position courante et calcule la récompense.
|
||||
|
||||
La récompense est le PnL en multiple d'ATR (normalisé et sans unité).
|
||||
Cette normalisation rend la récompense comparable entre différents
|
||||
symboles et périodes de volatilité.
|
||||
|
||||
Args:
|
||||
price: Prix de sortie
|
||||
atr: ATR courant (pour normalisation)
|
||||
|
||||
Returns:
|
||||
Récompense (PnL normalisé par ATR)
|
||||
"""
|
||||
if self._position == 0 or self._entry_price <= 0:
|
||||
self._position = 0
|
||||
self._bars_in_trade = 0
|
||||
return 0.0
|
||||
|
||||
raw_pnl = self._position * (price - self._entry_price)
|
||||
|
||||
# Normalisation par l'ATR d'entrée pour une récompense sans unité
|
||||
entry_atr = max(self._entry_atr, price * 0.0001)
|
||||
reward = raw_pnl / entry_atr
|
||||
|
||||
# Mise à jour du capital (simulation simplifiée avec 1 lot fixe)
|
||||
self._capital += raw_pnl
|
||||
self._total_pnl += raw_pnl
|
||||
|
||||
self._position = 0
|
||||
self._entry_price = 0.0
|
||||
self._bars_in_trade = 0
|
||||
|
||||
return float(reward)
|
||||
|
||||
def _check_sl_tp(self, current_close: float, current_atr: float) -> float:
|
||||
"""
|
||||
Vérifie si le SL ou TP est atteint pour la position courante.
|
||||
|
||||
Utilise l'ATR d'entrée pour les niveaux SL/TP (stabilité).
|
||||
Si le SL ou TP est touché, la position est fermée.
|
||||
|
||||
Args:
|
||||
current_close: Prix de clôture de la barre courante
|
||||
current_atr: ATR courant
|
||||
|
||||
Returns:
|
||||
Récompense si fermeture forcée, 0.0 sinon
|
||||
"""
|
||||
if self._position == 0 or self._entry_price <= 0:
|
||||
return 0.0
|
||||
|
||||
entry_atr = max(self._entry_atr, self._entry_price * 0.0001)
|
||||
sl_dist = self.sl_atr_mult * entry_atr
|
||||
tp_dist = self.tp_atr_mult * entry_atr
|
||||
|
||||
if self._position == 1: # LONG
|
||||
sl_level = self._entry_price - sl_dist
|
||||
tp_level = self._entry_price + tp_dist
|
||||
if current_close <= sl_level or current_close >= tp_level:
|
||||
return self._close_position(current_close, current_atr)
|
||||
|
||||
elif self._position == -1: # SHORT
|
||||
sl_level = self._entry_price + sl_dist
|
||||
tp_level = self._entry_price - tp_dist
|
||||
if current_close >= sl_level or current_close <= tp_level:
|
||||
return self._close_position(current_close, current_atr)
|
||||
|
||||
return 0.0
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Helpers
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _action_to_position(action: int) -> int:
|
||||
"""Convertit l'action discrète en direction de position."""
|
||||
return {ACTION_HOLD: 0, ACTION_LONG: 1, ACTION_SHORT: -1}.get(action, 0)
|
||||
|
||||
def _get_hour(self, i: int) -> float:
|
||||
"""Retourne l'heure normalisée [0..1] depuis l'index, ou 0.5 si non datetime."""
|
||||
try:
|
||||
idx = self.df.index[i]
|
||||
if hasattr(idx, 'hour'):
|
||||
return float(idx.hour) / 24.0
|
||||
except Exception:
|
||||
pass
|
||||
return 0.5
|
||||
|
||||
def _get_dow(self, i: int) -> float:
|
||||
"""Retourne le jour de semaine normalisé [0..1] depuis l'index, ou 0.5 si non datetime."""
|
||||
try:
|
||||
idx = self.df.index[i]
|
||||
if hasattr(idx, 'dayofweek'):
|
||||
return float(idx.dayofweek) / 5.0
|
||||
except Exception:
|
||||
pass
|
||||
return 0.5
|
||||
|
||||
def _compute_atr_series(self) -> pd.Series:
|
||||
"""Calcule l'ATR(14) sur toute la série OHLCV."""
|
||||
h = self.df['high']
|
||||
l = self.df['low']
|
||||
pc = self.df['close'].shift(1)
|
||||
tr = pd.concat([h - l, (h - pc).abs(), (l - pc).abs()], axis=1).max(axis=1)
|
||||
atr = tr.ewm(span=self.atr_period, adjust=False).mean()
|
||||
# Remplir les premières valeurs NaN par la plage high-low
|
||||
atr = atr.fillna(h - l)
|
||||
return atr
|
||||
|
||||
def _compute_rsi(self, period: int = 14) -> pd.Series:
|
||||
"""Calcule le RSI sur toute la série de closes."""
|
||||
delta = self.df['close'].diff()
|
||||
gain = delta.clip(lower=0).ewm(span=period, adjust=False).mean()
|
||||
loss = (-delta.clip(upper=0)).ewm(span=period, adjust=False).mean()
|
||||
rs = gain / loss.replace(0, np.nan)
|
||||
rsi = 100.0 - (100.0 / (1.0 + rs))
|
||||
return rsi.fillna(50.0)
|
||||
|
||||
def get_episode_stats(self) -> dict:
|
||||
"""
|
||||
Retourne les statistiques de l'épisode courant.
|
||||
|
||||
Utile pour le logging et l'évaluation du modèle PPO.
|
||||
|
||||
Returns:
|
||||
Dict avec total_pnl, n_steps, capital, Sharpe approximatif
|
||||
"""
|
||||
rewards = np.array(self._pnl_history)
|
||||
if len(rewards) > 1 and rewards.std() > 0:
|
||||
sharpe = float(rewards.mean() / rewards.std() * np.sqrt(252))
|
||||
else:
|
||||
sharpe = 0.0
|
||||
|
||||
return {
|
||||
'total_pnl': self._total_pnl,
|
||||
'capital': self._capital,
|
||||
'return_pct': (self._capital - self.initial_capital) / self.initial_capital,
|
||||
'n_steps': self._current_step,
|
||||
'sharpe_approx': sharpe,
|
||||
}
|
||||
10
src/strategies/rl_driven/__init__.py
Normal file
10
src/strategies/rl_driven/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
Module RL-Driven Strategy — Stratégie de trading pilotée par un agent PPO.
|
||||
|
||||
L'agent PPO apprend à trader par renforcement : il maximise directement
|
||||
le PnL sans supervision explicite (pas de labels LONG/SHORT/NEUTRAL).
|
||||
"""
|
||||
|
||||
from src.strategies.rl_driven.rl_strategy import RLDrivenStrategy, RL_AVAILABLE
|
||||
|
||||
__all__ = ['RLDrivenStrategy', 'RL_AVAILABLE']
|
||||
259
src/strategies/rl_driven/rl_strategy.py
Normal file
259
src/strategies/rl_driven/rl_strategy.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""
|
||||
RL-Driven Strategy — Stratégie pilotée par un agent de Reinforcement Learning (PPO).
|
||||
|
||||
Contrairement aux stratégies supervisées (XGBoost, CNN) qui apprennent depuis des
|
||||
labels pré-calculés, cette stratégie utilise un agent PPO (Proximal Policy
|
||||
Optimization) entraîné via interaction directe avec un environnement de trading
|
||||
(TradingEnv). L'agent apprend à maximiser la récompense cumulative (profit ajusté
|
||||
par le risque) sans avoir besoin de labels explicites.
|
||||
|
||||
Fonctionnement :
|
||||
1. Le modèle RLStrategyModel est chargé depuis le disque
|
||||
(entraîné via POST /trading/train-rl)
|
||||
2. À chaque barre, les seq_len dernières bougies sont fournies à l'agent
|
||||
3. L'agent PPO retourne une action : LONG (1) / SHORT (-1) / NEUTRAL (0)
|
||||
avec un score de confiance basé sur les probabilités de la politique
|
||||
4. Si confidence >= min_confidence, un signal est émis avec SL/TP basés sur ATR
|
||||
|
||||
Intégration :
|
||||
- Compatible avec StrategyEngine (même interface que CNNImageDrivenStrategy)
|
||||
- Chargé automatiquement si un modèle entraîné existe pour le symbole/timeframe
|
||||
- Requiert PyTorch (CPU ou CUDA)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from src.strategies.base_strategy import BaseStrategy, Signal, StrategyConfig
|
||||
|
||||
# Import conditionnel — RL_AVAILABLE reste False si PyTorch ou stable-baselines3
|
||||
# ne sont pas installés dans le container
|
||||
try:
|
||||
from src.ml.rl.rl_strategy_model import RLStrategyModel
|
||||
RL_AVAILABLE = True
|
||||
except ImportError:
|
||||
RLStrategyModel = None
|
||||
RL_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RLDrivenStrategy(BaseStrategy):
|
||||
"""
|
||||
Stratégie de trading pilotée par un agent PPO (Reinforcement Learning).
|
||||
|
||||
L'agent apprend à maximiser le profit cumulatif en interagissant avec un
|
||||
environnement de trading simulé (TradingEnv). Aucun label supervisé n'est
|
||||
nécessaire : la récompense est définie par le P&L réalisé et les pénalités
|
||||
de risque (drawdown, over-trading).
|
||||
|
||||
Args:
|
||||
config: Dict de configuration (timeframe, risk_per_trade, symbol, etc.)
|
||||
|
||||
Config keys supplémentaires (optionnelles) :
|
||||
min_confidence: Seuil de confiance minimum [0..1] (défaut: 0.55)
|
||||
tp_atr_mult: Multiplicateur ATR pour TP (défaut: 2.0)
|
||||
sl_atr_mult: Multiplicateur ATR pour SL (défaut: 1.0)
|
||||
seq_len: Fenêtre d'observation RL (barres) (défaut: 20)
|
||||
auto_load: Charger automatiquement le modèle existant (défaut: True)
|
||||
"""
|
||||
|
||||
STRATEGY_NAME = 'rl_driven'
|
||||
|
||||
def __init__(self, config: Dict):
|
||||
super().__init__(config)
|
||||
|
||||
self.symbol = config.get('symbol', 'EURUSD')
|
||||
self.min_confidence = config.get('min_confidence', 0.55)
|
||||
self.tp_atr_mult = config.get('tp_atr_mult', 2.0)
|
||||
self.sl_atr_mult = config.get('sl_atr_mult', 1.0)
|
||||
self.seq_len = config.get('seq_len', 20)
|
||||
|
||||
self.rl_model: Optional['RLStrategyModel'] = None
|
||||
|
||||
if not RL_AVAILABLE:
|
||||
logger.warning("RL Strategy non disponible (PyTorch / stable-baselines3 requis)")
|
||||
return
|
||||
|
||||
# Tentative de chargement automatique du modèle existant
|
||||
if config.get('auto_load', True):
|
||||
self._try_load_model()
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Interface BaseStrategy
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def analyze(self, market_data: pd.DataFrame) -> Optional[Signal]:
|
||||
"""
|
||||
Génère un signal de trading via l'agent PPO.
|
||||
|
||||
Args:
|
||||
market_data: DataFrame OHLCV (minimum seq_len barres)
|
||||
|
||||
Returns:
|
||||
Signal si l'agent est confiant, None sinon
|
||||
"""
|
||||
if not RL_AVAILABLE:
|
||||
logger.debug("RL Strategy : PyTorch / stable-baselines3 non disponible, aucun signal")
|
||||
return None
|
||||
|
||||
if self.rl_model is None or not self.rl_model.is_trained:
|
||||
logger.debug("RL Strategy : modèle non chargé, aucun signal")
|
||||
return None
|
||||
|
||||
if len(market_data) < self.seq_len:
|
||||
return None
|
||||
|
||||
try:
|
||||
result = self.rl_model.predict(market_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"RL Strategy predict error : {e}")
|
||||
return None
|
||||
|
||||
if not result.get('tradeable', False):
|
||||
return None
|
||||
|
||||
signal_dir = result['signal'] # 1 = LONG, -1 = SHORT
|
||||
confidence = result['confidence']
|
||||
|
||||
# Prix et ATR pour le calcul des niveaux SL/TP
|
||||
last_close = float(market_data['close'].iloc[-1])
|
||||
atr = self._compute_atr(market_data)
|
||||
if atr <= 0:
|
||||
return None
|
||||
|
||||
if signal_dir == 1:
|
||||
direction = 'LONG'
|
||||
stop_loss = last_close - self.sl_atr_mult * atr
|
||||
take_profit = last_close + self.tp_atr_mult * atr
|
||||
elif signal_dir == -1:
|
||||
direction = 'SHORT'
|
||||
stop_loss = last_close + self.sl_atr_mult * atr
|
||||
take_profit = last_close - self.tp_atr_mult * atr
|
||||
else:
|
||||
return None
|
||||
|
||||
signal = Signal(
|
||||
symbol = self.symbol,
|
||||
direction = direction,
|
||||
entry_price = last_close,
|
||||
stop_loss = stop_loss,
|
||||
take_profit = take_profit,
|
||||
confidence = confidence,
|
||||
timestamp = datetime.now(timezone.utc),
|
||||
strategy = self.STRATEGY_NAME,
|
||||
metadata = {
|
||||
'probas': result.get('probas', {}),
|
||||
'seq_len': self.seq_len,
|
||||
'atr': atr,
|
||||
'tp_atr_mult': self.tp_atr_mult,
|
||||
'sl_atr_mult': self.sl_atr_mult,
|
||||
'avg_reward': result.get('avg_reward'),
|
||||
'total_timesteps': result.get('total_timesteps'),
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"RL Signal : {direction} {self.symbol} | "
|
||||
f"entry={last_close:.5f} SL={stop_loss:.5f} TP={take_profit:.5f} | "
|
||||
f"confidence={confidence:.2%}"
|
||||
)
|
||||
return signal
|
||||
|
||||
def calculate_indicators(self, data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Retourne les données telles quelles — l'agent RL travaille sur les OHLCV bruts."""
|
||||
return data
|
||||
|
||||
def update_params(self, params: Dict) -> None:
|
||||
"""Mise à jour dynamique des paramètres (depuis API ou Optuna)."""
|
||||
if 'min_confidence' in params:
|
||||
self.min_confidence = params['min_confidence']
|
||||
if self.rl_model:
|
||||
self.rl_model.min_confidence = params['min_confidence']
|
||||
if 'tp_atr_mult' in params:
|
||||
self.tp_atr_mult = params['tp_atr_mult']
|
||||
if 'sl_atr_mult' in params:
|
||||
self.sl_atr_mult = params['sl_atr_mult']
|
||||
if 'seq_len' in params:
|
||||
self.seq_len = params['seq_len']
|
||||
logger.info(f"RL Strategy params mis à jour : {params}")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Gestion du modèle
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def load_model(self, symbol: Optional[str] = None, timeframe: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Charge un modèle RL depuis le disque.
|
||||
|
||||
Args:
|
||||
symbol: Paire (défaut: self.symbol)
|
||||
timeframe: Timeframe (défaut: self.config.timeframe)
|
||||
|
||||
Returns:
|
||||
True si chargement réussi
|
||||
"""
|
||||
if not RL_AVAILABLE:
|
||||
logger.warning("RL Strategy non disponible (PyTorch / stable-baselines3 requis)")
|
||||
return False
|
||||
|
||||
sym = symbol or self.symbol
|
||||
tf = timeframe or self.config.timeframe
|
||||
try:
|
||||
self.rl_model = RLStrategyModel.load(sym, tf)
|
||||
logger.info(f"Modèle RL chargé : {sym}/{tf}")
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
logger.info(f"Aucun modèle RL trouvé pour {sym}/{tf}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur chargement modèle RL : {e}")
|
||||
return False
|
||||
|
||||
def attach_model(self, model: 'RLStrategyModel') -> None:
|
||||
"""Attache directement un modèle RL (après entraînement via API)."""
|
||||
self.rl_model = model
|
||||
self.symbol = model.symbol
|
||||
logger.info(f"Modèle RL attaché : {model.symbol}/{model.timeframe}")
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""Retourne True si le modèle RL est chargé et entraîné."""
|
||||
if not RL_AVAILABLE:
|
||||
return False
|
||||
return self.rl_model is not None and self.rl_model.is_trained
|
||||
|
||||
def get_model_info(self) -> Dict:
|
||||
"""Retourne les métadonnées du modèle RL actif."""
|
||||
if not RL_AVAILABLE:
|
||||
return {'status': 'PyTorch / stable-baselines3 non disponible'}
|
||||
if not self.is_ready():
|
||||
return {'status': 'non entraîné'}
|
||||
meta = self.rl_model.metadata.copy()
|
||||
meta['is_ready'] = True
|
||||
meta['seq_len'] = self.seq_len
|
||||
return meta
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def _try_load_model(self) -> None:
|
||||
"""Tente un chargement silencieux du modèle au démarrage."""
|
||||
try:
|
||||
self.load_model()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _compute_atr(df: pd.DataFrame, period: int = 14) -> float:
|
||||
"""Calcule l'ATR moyen sur les dernières barres."""
|
||||
if len(df) < period + 1:
|
||||
return float(df['high'].iloc[-1] - df['low'].iloc[-1])
|
||||
h, l, pc = df['high'], df['low'], df['close'].shift(1)
|
||||
tr = pd.concat([h - l, (h - pc).abs(), (l - pc).abs()], axis=1).max(axis=1)
|
||||
atr = tr.rolling(period).mean().iloc[-1]
|
||||
return float(atr) if not np.isnan(atr) else float(df['high'].iloc[-1] - df['low'].iloc[-1])
|
||||
Reference in New Issue
Block a user