Compare commits
7 Commits
8732acf3d0
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ff8d58e1aa | ||
|
|
80e1308a1e | ||
|
|
e9d4c440d9 | ||
|
|
6fd68af47a | ||
|
|
7af7248b4d | ||
|
|
d245d7d8f4 | ||
|
|
acc3338213 |
@@ -20,3 +20,16 @@ prometheus-client==0.19.0
|
|||||||
|
|
||||||
# Notifications
|
# Notifications
|
||||||
python-telegram-bot==20.7
|
python-telegram-bot==20.7
|
||||||
|
|
||||||
|
# ML — requis pour MLDrivenStrategy (entraînement et prédiction dans l'API)
|
||||||
|
scikit-learn==1.3.2
|
||||||
|
xgboost==2.0.3
|
||||||
|
lightgbm==4.1.0
|
||||||
|
joblib>=1.3.0
|
||||||
|
|
||||||
|
# ML — Deep Learning (CNN pour patterns chandeliers)
|
||||||
|
torch>=2.0.0
|
||||||
|
|
||||||
|
# Visualisation — graphiques financiers et traitement d'images
|
||||||
|
mplfinance>=0.12.10b0
|
||||||
|
Pillow>=10.0.0
|
||||||
|
|||||||
182
docs/CNN_IMAGE_PLAN.md
Normal file
182
docs/CNN_IMAGE_PLAN.md
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
# Phase 4c-bis : CNN Image-Based — Analyse Visuelle de Graphiques
|
||||||
|
|
||||||
|
**Date** : 2026-03-10
|
||||||
|
**Statut** : 🟡 En développement
|
||||||
|
**Prérequis** : Phase 4c (CNN 1D + Ensemble) ✅
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Concept
|
||||||
|
|
||||||
|
Contrairement au CNN 1D qui analyse des séquences numériques, le CNN image-based
|
||||||
|
convertit les bougies OHLCV en **vraies images de graphiques** (rendu matplotlib),
|
||||||
|
puis utilise un réseau de neurones de **vision par ordinateur** (Conv2D) pour
|
||||||
|
reconnaître les patterns visuels — exactement comme un trader devant TradingView.
|
||||||
|
|
||||||
|
### Ce que le modèle apprend sans qu'on le programme
|
||||||
|
- Bougies marteau, étoile filante, doji
|
||||||
|
- Double top, double bottom
|
||||||
|
- Rebonds sur support/résistance (zones de consolidation visibles)
|
||||||
|
- Momentum : grande bougie pleine après consolidation
|
||||||
|
- Divergences visibles (mouvement de prix vs volume en bas de l'image)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
OHLCV DataFrame (64 bougies)
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
CandlestickImageRenderer
|
||||||
|
- Rendu mplfinance (sans axes, sans texte)
|
||||||
|
- Bougies vertes (hausse) / rouges (baisse)
|
||||||
|
- Volume en bas (20% de l'image)
|
||||||
|
- Image 128×128 pixels, RGB
|
||||||
|
- Normalisation [0..1]
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
CandlestickCNN (Conv2D — vision)
|
||||||
|
Input : (batch, 3, 128, 128)
|
||||||
|
|
||||||
|
Bloc 1 : Conv2d(3→32, 3×3) + BatchNorm + ReLU + MaxPool(2,2) → (32, 64, 64)
|
||||||
|
Bloc 2 : Conv2d(32→64, 3×3) + BatchNorm + ReLU + MaxPool(2,2) → (64, 32, 32)
|
||||||
|
Bloc 3 : Conv2d(64→128, 3×3) + BatchNorm + ReLU + MaxPool(2,2) → (128, 16, 16)
|
||||||
|
Bloc 4 : Conv2d(128→256, 3×3) + BatchNorm + ReLU + AdaptiveAvgPool(1) → (256,)
|
||||||
|
|
||||||
|
Classifieur :
|
||||||
|
Linear(256→128) + Dropout(0.4) + ReLU
|
||||||
|
Linear(128→3) → softmax → [SHORT, NEUTRAL, LONG]
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
Signal : 1 (LONG) / -1 (SHORT) / 0 (NEUTRAL)
|
||||||
|
Confidence : max(P_LONG, P_SHORT)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Fichiers à créer
|
||||||
|
|
||||||
|
### src/ml/cnn_image/
|
||||||
|
|
||||||
|
| Fichier | Rôle |
|
||||||
|
|---|---|
|
||||||
|
| `__init__.py` | Export CNNImageStrategyModel |
|
||||||
|
| `chart_renderer.py` | CandlestickImageRenderer : encode(df, seq_len=64) → (N, 3, 128, 128) |
|
||||||
|
| `cnn_image_model.py` | CandlestickCNN(nn.Module) : Conv2D 4-blocs + Dense |
|
||||||
|
| `cnn_image_strategy_model.py` | CNNImageStrategyModel : même interface que MLStrategyModel |
|
||||||
|
|
||||||
|
### src/strategies/cnn_image_driven/
|
||||||
|
|
||||||
|
| Fichier | Rôle |
|
||||||
|
|---|---|
|
||||||
|
| `__init__.py` | Export CNNImageDrivenStrategy |
|
||||||
|
| `cnn_image_strategy.py` | CNNImageDrivenStrategy(BaseStrategy), SL/TP ATR-based |
|
||||||
|
|
||||||
|
### docker/requirements/api.txt
|
||||||
|
- Ajouter `mplfinance>=0.12.10b0`
|
||||||
|
- `Pillow>=10.0.0` (probablement déjà présent)
|
||||||
|
- Rebuild trading-api
|
||||||
|
|
||||||
|
### src/api/routers/trading.py
|
||||||
|
- POST `/trading/train-cnn-image`
|
||||||
|
- GET `/trading/train-cnn-image/{job_id}`
|
||||||
|
- GET `/trading/cnn-image-models`
|
||||||
|
|
||||||
|
### src/ml/ensemble/ensemble_model.py
|
||||||
|
- Ajouter `attach_cnn_image(model)` comme 3ème slot
|
||||||
|
- Mettre à jour `DEFAULT_WEIGHTS = {xgboost: 0.30, cnn: 0.30, cnn_image: 0.40, rl: 0.00}`
|
||||||
|
- Mettre à jour `predict()` pour inclure le 3ème modèle
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## CandlestickImageRenderer — Détails
|
||||||
|
|
||||||
|
```python
|
||||||
|
class CandlestickImageRenderer:
|
||||||
|
"""
|
||||||
|
Convertit des données OHLCV en images de graphiques en chandeliers.
|
||||||
|
|
||||||
|
Paramètres d'image :
|
||||||
|
- Taille : 128×128 pixels
|
||||||
|
- Canaux : RGB (3)
|
||||||
|
- Fond : noir (#0d1117)
|
||||||
|
- Hausse : vert (#26a69a), Baisse : rouge (#ef5350)
|
||||||
|
- Volume : dégradé alpha en bas (20% de l'image)
|
||||||
|
- Pas d'axes, pas de labels, pas de titre
|
||||||
|
"""
|
||||||
|
|
||||||
|
def encode(df, seq_len=64) -> np.ndarray:
|
||||||
|
# Retourne (N, 3, 128, 128), float32, normalisé [0,1]
|
||||||
|
# N = len(df) - seq_len + 1 fenêtres glissantes
|
||||||
|
|
||||||
|
def encode_last(df, seq_len=64) -> np.ndarray:
|
||||||
|
# Retourne (1, 3, 128, 128) — dernière fenêtre uniquement
|
||||||
|
# Utilisé pour la prédiction en temps réel
|
||||||
|
|
||||||
|
def _render_single(df_window) -> PIL.Image:
|
||||||
|
# Rendu mplfinance en mémoire (BytesIO)
|
||||||
|
# style custom : fond noir, pas d'axes
|
||||||
|
# Retourne PIL.Image 128×128 RGB
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## CNNImageStrategyModel — Interface
|
||||||
|
|
||||||
|
Identique à `MLStrategyModel` et `CNNStrategyModel` :
|
||||||
|
|
||||||
|
```python
|
||||||
|
model = CNNImageStrategyModel(symbol='EURUSD', timeframe='1h')
|
||||||
|
result = model.train(df_ohlcv) # walk-forward 2 folds
|
||||||
|
signal = model.predict(df) # {signal, confidence, probas, tradeable}
|
||||||
|
model.save() # models/cnn_image_strategy/EURUSD_1h.pt + .json
|
||||||
|
model.load(symbol, timeframe) # chargement depuis disque
|
||||||
|
model.list_trained_models() # liste des modèles disponibles
|
||||||
|
model.get_feature_importance() # retourne [] (CNN = boîte noire)
|
||||||
|
```
|
||||||
|
|
||||||
|
Paramètres entraînement :
|
||||||
|
- `seq_len` : 64 bougies par image
|
||||||
|
- `epochs` : 50 max, early stopping patience=7
|
||||||
|
- `batch_size` : 32
|
||||||
|
- `lr` : 1e-3 (Adam)
|
||||||
|
- Labels : via LabelGenerator partagé (ATR-based, même que XGBoost et CNN 1D)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Intégration Ensemble
|
||||||
|
|
||||||
|
Après cette phase, l'EnsembleModel aura 3 composants :
|
||||||
|
|
||||||
|
```
|
||||||
|
Signal Ensemble =
|
||||||
|
0.30 × XGBoost(TA features)
|
||||||
|
+ 0.30 × CNN 1D(séquences OHLCV)
|
||||||
|
+ 0.40 × CNN Image(graphiques visuels)
|
||||||
|
|
||||||
|
Condition de trade :
|
||||||
|
- Score pondéré ≥ min_confidence (défaut 0.55)
|
||||||
|
- Au moins 2 modèles en accord sur la direction
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## TODOs
|
||||||
|
|
||||||
|
- [ ] chart_renderer.py — CandlestickImageRenderer avec mplfinance
|
||||||
|
- [ ] cnn_image_model.py — CandlestickCNN Conv2D 4-blocs
|
||||||
|
- [ ] cnn_image_strategy_model.py — train/predict/save/load
|
||||||
|
- [ ] cnn_image_strategy.py — CNNImageDrivenStrategy(BaseStrategy)
|
||||||
|
- [ ] trading.py — routes /train-cnn-image, /cnn-image-models
|
||||||
|
- [ ] ensemble_model.py — attach_cnn_image(), poids mis à jour
|
||||||
|
- [ ] requirements — mplfinance + rebuild Docker
|
||||||
|
- [ ] Entraînement validé sur EURUSD/1h
|
||||||
|
- [ ] Intégration EnsembleStrategy avec les 3 modèles
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Phase 4d — RL (après validation CNN image)
|
||||||
|
|
||||||
|
Agent PPO (Proximal Policy Optimization) via `gymnasium` + `stable-baselines3`.
|
||||||
|
Voir docs/CNN_ENSEMBLE_PLAN.md section Phase 4d.
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
# État d'Avancement du Projet — Trading AI Secure
|
# État d'Avancement du Projet — Trading AI Secure
|
||||||
|
|
||||||
**Dernière mise à jour** : 2026-03-08
|
**Dernière mise à jour** : 2026-03-10
|
||||||
**Version** : 0.5.0-beta
|
**Version** : 0.6.0-beta
|
||||||
**Statut Global** : 🟡 En Développement Actif
|
**Statut Global** : 🟡 En Développement Actif
|
||||||
|
|
||||||
---
|
---
|
||||||
@@ -15,6 +15,8 @@
|
|||||||
| Phase 3 : Stratégies & Backtesting | ✅ Terminé | 100% |
|
| Phase 3 : Stratégies & Backtesting | ✅ Terminé | 100% |
|
||||||
| Phase 4 : Interface & Dashboard | ✅ Terminé | 100% |
|
| Phase 4 : Interface & Dashboard | ✅ Terminé | 100% |
|
||||||
| Phase 4b : ML-Driven Strategy | ✅ Terminé | 100% |
|
| Phase 4b : ML-Driven Strategy | ✅ Terminé | 100% |
|
||||||
|
| Phase 4c : CNN + Ensemble | ✅ Terminé | 100% |
|
||||||
|
| Phase 4d : RL (PPO) | ⚪ Planifié | 0% |
|
||||||
| Phase 5 : IG Markets (Live) | ⚪ Planifié | 0% |
|
| Phase 5 : IG Markets (Live) | ⚪ Planifié | 0% |
|
||||||
|
|
||||||
---
|
---
|
||||||
@@ -121,22 +123,42 @@ Voir [docs/ML_STRATEGY_GUIDE.md](ML_STRATEGY_GUIDE.md) pour la documentation com
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Phase 4c — CNN + Ensemble 🟡 (En cours)
|
## Phase 4c — CNN + Ensemble ✅ (2026-03-10)
|
||||||
|
|
||||||
CNN 1D sur séquences brutes OHLCV + combinaison pondérée avec XGBoost.
|
CNN 1D sur séquences brutes OHLCV + combinaison pondérée avec XGBoost.
|
||||||
Voir [docs/CNN_ENSEMBLE_PLAN.md](CNN_ENSEMBLE_PLAN.md) pour l'architecture complète.
|
Voir [docs/CNN_ENSEMBLE_PLAN.md](CNN_ENSEMBLE_PLAN.md) pour l'architecture complète.
|
||||||
|
|
||||||
| Composant | Fichier | Statut |
|
| Composant | Fichier | Statut |
|
||||||
|---|---|---|
|
|---|---|---|
|
||||||
| PyTorch CPU dans requirements | `docker/requirements/api.txt` | 🟡 |
|
| PyTorch 2.10 dans requirements | `docker/requirements/api.txt` | ✅ Installé |
|
||||||
| CandlestickEncoder (normalisation séquences) | `src/ml/cnn/candlestick_encoder.py` | 🟡 |
|
| CandlestickEncoder (normalisation séquences) | `src/ml/cnn/candlestick_encoder.py` | ✅ |
|
||||||
| CNNModel (1D Conv PyTorch) | `src/ml/cnn/cnn_model.py` | 🟡 |
|
| CNNModel (1D Conv PyTorch) | `src/ml/cnn/cnn_model.py` | ✅ |
|
||||||
| CNNStrategyModel (train/predict/save/load) | `src/ml/cnn/cnn_strategy_model.py` | 🟡 |
|
| CNNStrategyModel (train/predict/save/load) | `src/ml/cnn/cnn_strategy_model.py` | ✅ |
|
||||||
| CNNDrivenStrategy (hérite BaseStrategy) | `src/strategies/cnn_driven/cnn_strategy.py` | 🟡 |
|
| CNNDrivenStrategy (hérite BaseStrategy) | `src/strategies/cnn_driven/cnn_strategy.py` | ✅ |
|
||||||
| Routes API CNN (train, status, list) | `src/api/routers/trading.py` | 🟡 |
|
| Routes API CNN (POST /train-cnn, GET /cnn-models) | `src/api/routers/trading.py` | ✅ |
|
||||||
| EnsembleModel (XGBoost + CNN pondérés) | `src/ml/ensemble/ensemble_model.py` | 🟡 |
|
| EnsembleModel (XGBoost + CNN pondérés) | `src/ml/ensemble/ensemble_model.py` | ✅ |
|
||||||
| EnsembleStrategy (hérite BaseStrategy) | `src/strategies/ensemble/ensemble_strategy.py` | 🟡 |
|
| EnsembleStrategy (hérite BaseStrategy) | `src/strategies/ensemble/ensemble_strategy.py` | ✅ |
|
||||||
| Routes API Ensemble (configure, signal) | `src/api/routers/trading.py` | 🟡 |
|
| Routes API Ensemble (POST /configure, GET /status) | `src/api/routers/trading.py` | ✅ |
|
||||||
|
| Entraînement XGBoost validé (EURUSD/1h) | — | ✅ wf_prec=21.7% |
|
||||||
|
| Entraînement CNN validé (EURUSD/1h) | — | ✅ wf_prec=32.7% |
|
||||||
|
| Backtest comparatif (Scalping vs XGBoost vs CNN vs Ensemble) | — | ⏸️ À faire |
|
||||||
|
|
||||||
|
### Résultats comparatifs des modèles (EURUSD/1h, 2 ans, ~12k barres)
|
||||||
|
| Modèle | WF Accuracy | WF Precision | Labels (L/S/N) |
|
||||||
|
|---|---|---|---|
|
||||||
|
| XGBoost (features TA) | 33.6% | 21.7% | 3992/3943/4338 ✅ équilibré |
|
||||||
|
| CNN 1D (séquences OHLCV) | 31.9% | **32.7%** | 3971/3936/4303 ✅ équilibré |
|
||||||
|
| Ensemble (XGB+CNN 0.40/0.60) | — | — | Non encore testé |
|
||||||
|
|
||||||
|
**Observation** : CNN > XGBoost sur la précision directionnelle (32.7% vs 21.7%). Les deux modèles sont complémentaires pour l'ensemble.
|
||||||
|
|
||||||
|
### Bugs corrigés (2026-03-10 — session agents)
|
||||||
|
- `trading.py` : `_get_data_service()` inexistant → instanciation directe DataService
|
||||||
|
- `trading.py` : `logger` non défini dans except handler → ajout import logging
|
||||||
|
- `trading.py` : `period` string mal converti → period_map identique à _run_optimize_task
|
||||||
|
- `strategy_engine.py` : `ml_driven` non supporté dans `load_strategy()` → cas ajouté
|
||||||
|
- `docker/requirements/api.txt` : dépendances ML (scikit-learn, xgboost, lightgbm) manquantes dans trading-api
|
||||||
|
- `ml_strategy_model.py` : labels [-1,0,1] → encodage +1 pour XGBoost ≥ 2.x
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
236
docs/RL_STRATEGY_GUIDE.md
Normal file
236
docs/RL_STRATEGY_GUIDE.md
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
# RL Strategy Guide — Phase 4d (Reinforcement Learning)
|
||||||
|
|
||||||
|
## Vue d'ensemble
|
||||||
|
|
||||||
|
La Phase 4d introduit une stratégie pilotée par un agent de **Reinforcement Learning (RL)**
|
||||||
|
basé sur l'algorithme **PPO** (Proximal Policy Optimization). Contrairement aux stratégies
|
||||||
|
supervisées (XGBoost, CNN) qui s'entraînent sur des labels pré-calculés, l'agent RL apprend
|
||||||
|
directement par interaction avec un environnement de trading simulé (`TradingEnv`), sans avoir
|
||||||
|
besoin d'annotations explicites LONG/SHORT/NEUTRAL.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
### Composants (implémentés par l'agent ml/rl/)
|
||||||
|
|
||||||
|
```
|
||||||
|
src/ml/rl/
|
||||||
|
├── trading_env.py # TradingEnv (Gymnasium) — environnement de simulation
|
||||||
|
├── ppo_model.py # PPOModel — réseau Actor-Critic (MLP)
|
||||||
|
└── rl_strategy_model.py # RLStrategyModel — interface train/predict/save/load
|
||||||
|
```
|
||||||
|
|
||||||
|
### Composants (cette phase)
|
||||||
|
|
||||||
|
```
|
||||||
|
src/strategies/rl_driven/
|
||||||
|
├── __init__.py # Export RLDrivenStrategy
|
||||||
|
└── rl_strategy.py # RLDrivenStrategy (hérite BaseStrategy)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pipeline d'entraînement PPO
|
||||||
|
|
||||||
|
```
|
||||||
|
Données OHLCV
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
TradingEnv (Gymnasium)
|
||||||
|
├── Observation space : fenêtre glissante seq_len=20 barres (OHLCV normalisé)
|
||||||
|
├── Action space : {0=HOLD/NEUTRAL, 1=LONG, 2=SHORT}
|
||||||
|
└── Reward : P&L réalisé − pénalité drawdown − pénalité over-trading
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
PPO Actor-Critic (MLP)
|
||||||
|
├── Actor : policy π(a|s) → distribution sur les actions
|
||||||
|
└── Critic : value function V(s) → estimation de la récompense future
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
Optimisation sur total_timesteps pas
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
RLStrategyModel.save()
|
||||||
|
├── models/rl_strategy/EURUSD_1h.zip (politique PPO)
|
||||||
|
└── models/rl_strategy/EURUSD_1h_meta.json
|
||||||
|
```
|
||||||
|
|
||||||
|
### Architecture Actor-Critic
|
||||||
|
|
||||||
|
L'agent PPO utilise un réseau MLP (Multi-Layer Perceptron) à deux têtes :
|
||||||
|
|
||||||
|
- **Actor** : prédit la distribution de probabilité sur les 3 actions (LONG/SHORT/NEUTRAL).
|
||||||
|
La confiance du signal correspond à `max(probas)`.
|
||||||
|
- **Critic** : estime la valeur d'état V(s) pour calculer l'avantage (advantage) utilisé
|
||||||
|
lors de la mise à jour de la politique.
|
||||||
|
|
||||||
|
Hyperparamètres PPO typiques :
|
||||||
|
| Paramètre | Valeur par défaut | Description |
|
||||||
|
|---------------|-------------------|----------------------------------------------|
|
||||||
|
| `gamma` | 0.99 | Facteur d'actualisation des récompenses |
|
||||||
|
| `clip_range` | 0.2 | Clipping ratio PPO (stabilité) |
|
||||||
|
| `n_steps` | 2048 | Nombre de pas par rollout |
|
||||||
|
| `batch_size` | 64 | Taille des mini-batches SGD |
|
||||||
|
| `n_epochs` | 10 | Passes sur chaque rollout |
|
||||||
|
| `ent_coef` | 0.01 | Coefficient d'entropie (exploration) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Lancer l'entraînement
|
||||||
|
|
||||||
|
### Via curl
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Lancer l'entraînement (tâche de fond)
|
||||||
|
curl -X POST http://localhost:8100/trading/train-rl \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"symbol": "EURUSD",
|
||||||
|
"timeframe": "1h",
|
||||||
|
"period": "2y",
|
||||||
|
"total_timesteps": 50000
|
||||||
|
}'
|
||||||
|
|
||||||
|
# Réponse
|
||||||
|
{
|
||||||
|
"job_id": "a1b2c3d4-...",
|
||||||
|
"status": "pending",
|
||||||
|
"symbol": "EURUSD",
|
||||||
|
"timeframe": "1h"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Suivre l'avancement
|
||||||
|
curl http://localhost:8100/trading/train-rl/a1b2c3d4-...
|
||||||
|
|
||||||
|
# Réponse quand terminé
|
||||||
|
{
|
||||||
|
"job_id": "a1b2c3d4-...",
|
||||||
|
"status": "completed",
|
||||||
|
"symbol": "EURUSD",
|
||||||
|
"timeframe": "1h",
|
||||||
|
"avg_reward": 0.042,
|
||||||
|
"sharpe_env": 1.23,
|
||||||
|
"total_timesteps": 50000,
|
||||||
|
"trained_at": "2026-03-10T14:32:00"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Lister les modèles disponibles
|
||||||
|
curl http://localhost:8100/trading/rl-models
|
||||||
|
|
||||||
|
# Réponse
|
||||||
|
{
|
||||||
|
"models": [
|
||||||
|
{
|
||||||
|
"symbol": "EURUSD",
|
||||||
|
"timeframe": "1h",
|
||||||
|
"avg_reward": 0.042,
|
||||||
|
"sharpe_env": 1.23,
|
||||||
|
"total_timesteps": 50000,
|
||||||
|
"trained_at": "2026-03-10T14:32:00"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"count": 1
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Paramètres de la requête
|
||||||
|
|
||||||
|
| Champ | Type | Défaut | Description |
|
||||||
|
|-------------------|-------|----------|---------------------------------------------------------|
|
||||||
|
| `symbol` | str | EURUSD | Paire de trading (ex: EURUSD, BTCUSDT) |
|
||||||
|
| `timeframe` | str | 1h | Timeframe (1m, 5m, 15m, 1h, 4h, 1d) |
|
||||||
|
| `period` | str | 2y | Historique d'entraînement (ex: 6m, 1y, 2y) |
|
||||||
|
| `total_timesteps` | int | 50000 | Nombre total de pas de simulation PPO |
|
||||||
|
|
||||||
|
**Recommandations `total_timesteps` :**
|
||||||
|
- 20 000 : entraînement rapide (test, ~5 min CPU)
|
||||||
|
- 50 000 : entraînement standard (défaut, ~15 min CPU)
|
||||||
|
- 200 000 : entraînement long, meilleure convergence (~1h CPU)
|
||||||
|
- 1 000 000 : entraînement poussé si GPU disponible
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Interpréter les métriques
|
||||||
|
|
||||||
|
### `avg_reward`
|
||||||
|
Récompense moyenne par pas de simulation sur les derniers rollouts d'évaluation.
|
||||||
|
|
||||||
|
- **< 0** : l'agent perd de l'argent en simulation → entraîner plus longtemps ou revoir la fonction de récompense
|
||||||
|
- **0 à 0.02** : agent neutre, légèrement profitable
|
||||||
|
- **> 0.05** : bon signal → tester en backtest réel
|
||||||
|
- **> 0.1** : excellent (attention au sur-apprentissage, vérifier out-of-sample)
|
||||||
|
|
||||||
|
### `sharpe_env`
|
||||||
|
Ratio de Sharpe calculé sur les épisodes de simulation (récompenses / écart-type des récompenses).
|
||||||
|
|
||||||
|
- **< 0.5** : insuffisant pour paper trading
|
||||||
|
- **0.5 – 1.0** : acceptable, à valider en backtest
|
||||||
|
- **> 1.5** : cible pour activation paper trading (conforme aux seuils du projet)
|
||||||
|
|
||||||
|
### Interprétation combinée
|
||||||
|
|
||||||
|
| `avg_reward` | `sharpe_env` | Interprétation |
|
||||||
|
|--------------|--------------|--------------------------------------------------------|
|
||||||
|
| négatif | quelconque | Agent non convergé — relancer avec plus de timesteps |
|
||||||
|
| 0 – 0.02 | < 1.0 | Apprentissage partiel — augmenter total_timesteps |
|
||||||
|
| > 0.03 | > 1.0 | Bon candidat — valider via POST /trading/backtest |
|
||||||
|
| > 0.05 | > 1.5 | Prêt pour paper trading (30 jours minimum) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Intégration avec le paper trading
|
||||||
|
|
||||||
|
Après l'entraînement, si un paper trading avec la stratégie `rl_driven` est actif,
|
||||||
|
le modèle est **automatiquement attaché** sans redémarrage.
|
||||||
|
|
||||||
|
Pour démarrer un paper trading RL :
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:8100/trading/paper/start \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"strategy": "rl_driven", "symbol": "EURUSD"}'
|
||||||
|
```
|
||||||
|
|
||||||
|
La stratégie `RLDrivenStrategy` :
|
||||||
|
1. Charge le dernier modèle entraîné pour le symbole/timeframe
|
||||||
|
2. À chaque barre, fournit les 20 dernières bougies à l'agent (`seq_len=20`)
|
||||||
|
3. Si la confiance de l'agent >= `min_confidence` (défaut: 0.55), émet un signal
|
||||||
|
4. SL = `sl_atr_mult × ATR` (défaut: 1×ATR), TP = `tp_atr_mult × ATR` (défaut: 2×ATR)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Notes techniques
|
||||||
|
|
||||||
|
### Threads PyTorch
|
||||||
|
L'entraînement fixe `torch.set_num_threads(4)` pour éviter la contention CPU dans
|
||||||
|
le container Docker. Adapter dans `docker-compose.yml` si le container dispose de plus de cœurs.
|
||||||
|
|
||||||
|
### Sauvegarde des modèles
|
||||||
|
Les modèles sont sauvegardés dans `models/rl_strategy/` (volume Docker monté) :
|
||||||
|
- `EURUSD_1h.zip` : politique PPO (format stable-baselines3)
|
||||||
|
- `EURUSD_1h_meta.json` : métadonnées (métriques, hyperparamètres, date)
|
||||||
|
|
||||||
|
### Import conditionnel
|
||||||
|
Le flag `RL_AVAILABLE` est `False` si PyTorch ou stable-baselines3 ne sont pas
|
||||||
|
installés. La stratégie dégrade gracieusement (aucun signal, aucune exception).
|
||||||
|
Pour activer : ajouter `stable-baselines3` et `gymnasium` dans `docker/requirements/api.txt`
|
||||||
|
puis reconstruire le container (`docker compose build --no-cache trading-api`).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Seuils de validation (conforme au projet)
|
||||||
|
|
||||||
|
Avant activation du live trading, la stratégie RL doit satisfaire :
|
||||||
|
|
||||||
|
| Métrique | Seuil minimum |
|
||||||
|
|-----------------|---------------|
|
||||||
|
| Sharpe Ratio | ≥ 1.5 |
|
||||||
|
| Max Drawdown | ≤ 10% |
|
||||||
|
| Win Rate | ≥ 55% |
|
||||||
|
| Paper Trading | ≥ 30 jours |
|
||||||
|
|
||||||
|
Ces seuils s'appliquent au **backtest out-of-sample** et au **paper trading**, pas
|
||||||
|
aux métriques de simulation RL (`sharpe_env`) qui sont indicatives uniquement.
|
||||||
365
scripts/compare_strategies.py
Executable file
365
scripts/compare_strategies.py
Executable file
@@ -0,0 +1,365 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
compare_strategies.py — Backtest comparatif automatisé des stratégies de trading.
|
||||||
|
|
||||||
|
Lance POST /trading/backtest pour chaque stratégie (scalping, ml_driven, cnn_driven)
|
||||||
|
sur le même dataset (EURUSD / 1h / 1y), poll jusqu'à completion, affiche un tableau
|
||||||
|
comparatif et sauvegarde le rapport JSON dans reports/.
|
||||||
|
|
||||||
|
Usage :
|
||||||
|
python scripts/compare_strategies.py
|
||||||
|
python scripts/compare_strategies.py --symbol GBPUSD --period 2y --capital 20000
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Vérification de la dépendance httpx (ou fallback vers requests)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ---------------------------------------------------------------------------
# HTTP backend selection: prefer httpx, fall back to requests.
# Exits with an actionable message when neither library is installed.
# ---------------------------------------------------------------------------
_HTTP_BACKEND = None

try:
    import httpx
    _HTTP_BACKEND = "httpx"
except ImportError:
    pass

if _HTTP_BACKEND is None:
    try:
        import requests as _requests_module
        _HTTP_BACKEND = "requests"
    except ImportError:
        print("ERREUR : Ni httpx ni requests ne sont installés. Exécutez : pip install httpx")
        sys.exit(1)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
API_BASE_URL = "http://localhost:8100"
POLL_INTERVAL_SEC = 5    # seconds between two status polls
TIMEOUT_SEC = 300        # hard cap per backtest job (5 minutes)
REPORTS_DIR = Path(__file__).parent.parent / "reports"

# Strategies to compare, in display order.
# NOTE: the API (POST /trading/backtest) accepts: scalping | intraday | swing | ml_driven.
# cnn_driven is not yet wired into the backtest pipeline (paper trading only);
# it will be picked up automatically once the API supports it.
STRATEGIES = [
    "scalping",
    "ml_driven",
    "cnn_driven",  # may return HTTP 400 if unsupported — handled gracefully downstream
]

# Metric columns to collect and display: (response key, column label, format spec).
METRICS_COLS = [
    ("total_return", "Retour total", "{:.2%}"),
    ("sharpe_ratio", "Sharpe ratio", "{:.3f}"),
    ("max_drawdown", "Max drawdown", "{:.2%}"),
    ("win_rate", "Win rate", "{:.2%}"),
    ("total_trades", "Trades totaux", "{:d}"),
    ("profit_factor", "Profit factor", "{:.2f}"),
]

# Validation thresholds (mirror the API's own limits).
SEUIL_SHARPE = 1.5
SEUIL_DRAWDOWN_MAX = 0.10
SEUIL_WIN_RATE = 0.55
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Couche HTTP (httpx ou requests selon disponibilité)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _post(url: str, payload: dict, timeout: int = 30) -> dict:
    """POST *payload* as JSON to *url* and return the parsed response body.

    Dispatches to whichever HTTP library was detected at import time
    (httpx preferred, requests as fallback); raises on HTTP error status.
    """
    backend = httpx if _HTTP_BACKEND == "httpx" else _requests_module
    resp = backend.post(url, json=payload, timeout=timeout)
    resp.raise_for_status()
    return resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
def _get(url: str, timeout: int = 15) -> dict:
    """GET *url* and return the parsed JSON response.

    Same backend dispatch as _post (httpx or requests); raises on HTTP
    error status via raise_for_status().
    """
    backend = httpx if _HTTP_BACKEND == "httpx" else _requests_module
    resp = backend.get(url, timeout=timeout)
    resp.raise_for_status()
    return resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fonctions backtest
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def lancer_backtest(strategy: str, symbol: str, period: str, capital: float) -> str:
    """Launch a backtest job via POST /trading/backtest.

    Returns the server-assigned job_id; propagates any HTTP error raised
    by the transport layer.
    """
    data = _post(
        f"{API_BASE_URL}/trading/backtest",
        {
            "strategy": strategy,
            "symbol": symbol,
            "period": period,
            "initial_capital": capital,
        },
    )
    return data["job_id"]
|
||||||
|
|
||||||
|
|
||||||
|
def attendre_resultat(job_id: str, strategy: str) -> dict:
    """Poll GET /trading/backtest/{job_id} until the job finishes.

    Checks every POLL_INTERVAL_SEC seconds. Returns the result payload on
    'completed'; raises RuntimeError on 'failed' and TimeoutError once the
    total wait exceeds TIMEOUT_SEC.
    """
    endpoint = f"{API_BASE_URL}/trading/backtest/{job_id}"
    started = time.time()

    while True:
        waited = time.time() - started
        if waited > TIMEOUT_SEC:
            raise TimeoutError(
                f"[{strategy}] Timeout atteint ({TIMEOUT_SEC}s). Job {job_id} toujours en cours."
            )

        reponse = _get(endpoint)
        etat = reponse.get("status", "?")

        print(f"  [{strategy}] Statut : {etat} ({waited:.0f}s)")

        if etat == "completed":
            return reponse
        if etat == "failed":
            raise RuntimeError(
                f"[{strategy}] Job échoué : {reponse.get('error', 'raison inconnue')}"
            )

        # Job still running — wait before the next poll.
        time.sleep(POLL_INTERVAL_SEC)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Affichage du tableau comparatif
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def afficher_tableau(resultats: list[dict]) -> None:
    """Print a comparison table of backtest results to the terminal.

    Uses 'tabulate' for nicer rendering when installed; otherwise falls
    back to a plain column-aligned layout. Also reports the best strategy
    by Sharpe ratio and recalls the validation thresholds.
    """
    headers = ["Stratégie"] + [label for _, label, _ in METRICS_COLS] + ["Valide ?"]

    rows = []
    for res in resultats:
        ligne = [res["strategy"]]
        for key, _, fmt in METRICS_COLS:
            val = res.get(key)
            if val is None:
                ligne.append("N/A")
                continue
            try:
                # total_trades renders as an int; everything else as a float.
                rendu = fmt.format(int(val)) if key == "total_trades" else fmt.format(float(val))
            except (ValueError, TypeError):
                rendu = str(val)
            ligne.append(rendu)
        # Validation flag column.
        verdict = res.get("is_valid_for_paper")
        ligne.append("OUI" if verdict is True else ("NON" if verdict is False else "?"))
        rows.append(ligne)

    print()
    print("=" * 70)
    print("RAPPORT COMPARATIF DES STRATÉGIES")
    print("=" * 70)

    try:
        from tabulate import tabulate
        print(tabulate(rows, headers=headers, tablefmt="rounded_outline"))
    except ImportError:
        # Plain-text fallback when tabulate is unavailable.
        col_widths = [
            max(len(str(h)), max((len(str(r[i])) for r in rows), default=0))
            for i, h in enumerate(headers)
        ]
        print(" ".join(str(h).ljust(w) for h, w in zip(headers, col_widths)))
        print(" ".join("-" * w for w in col_widths))
        for row in rows:
            print(" ".join(str(c).ljust(w) for c, w in zip(row, col_widths)))

    print()

    # Best strategy by Sharpe ratio (among results that report one).
    valides = [r for r in resultats if r.get("sharpe_ratio") is not None]
    if valides:
        meilleure = max(valides, key=lambda r: r.get("sharpe_ratio") or -999)
        print(f"Meilleure stratégie (Sharpe) : {meilleure['strategy']} "
              f"(Sharpe={meilleure.get('sharpe_ratio', 'N/A'):.3f})")

    # Recall the validation thresholds.
    print()
    print(f"Seuils de validation : Sharpe >= {SEUIL_SHARPE} | "
          f"Max drawdown <= {SEUIL_DRAWDOWN_MAX:.0%} | "
          f"Win rate >= {SEUIL_WIN_RATE:.0%}")
    print("=" * 70)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Sauvegarde du rapport JSON
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def sauvegarder_rapport(resultats: list[dict], symbol: str, period: str) -> Path:
    """Persist the comparison report as JSON under reports/.

    Returns the path of the file that was written.
    """
    REPORTS_DIR.mkdir(parents=True, exist_ok=True)

    # Timestamped filename so successive runs never overwrite each other.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    chemin = REPORTS_DIR / f"backtest_comparison_{timestamp}.json"

    contenu = {
        "generated_at": datetime.now().isoformat(),
        "parametres": {
            "symbol": symbol,
            "period": period,
            "api_base_url": API_BASE_URL,
        },
        "seuils_validation": {
            "sharpe_ratio_min": SEUIL_SHARPE,
            "max_drawdown_max": SEUIL_DRAWDOWN_MAX,
            "win_rate_min": SEUIL_WIN_RATE,
        },
        "resultats": resultats,
    }

    # default=str stringifies any non-serializable value (datetimes, etc.).
    chemin.write_text(
        json.dumps(contenu, indent=2, ensure_ascii=False, default=str),
        encoding="utf-8",
    )

    return chemin
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Point d'entrée principal
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the comparison script."""
    cli = argparse.ArgumentParser(
        description="Compare les stratégies de trading via backtest API."
    )
    cli.add_argument("--symbol", default="EURUSD", help="Paire de trading (défaut: EURUSD)")
    cli.add_argument("--period", default="1y", help="Période historique : 6m | 1y | 2y (défaut: 1y)")
    cli.add_argument("--capital", type=float, default=10000.0, help="Capital initial (défaut: 10000)")
    cli.add_argument(
        "--strategies",
        nargs="+",
        default=STRATEGIES,
        help="Liste des stratégies à comparer (défaut: scalping ml_driven cnn_driven)",
    )
    return cli.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Entry point: launch one backtest per strategy via the API, wait for
    all results, then print the comparison table and save the JSON report.

    Exits with status 1 when the API is unreachable or no job could be
    launched.
    """
    args = parse_args()

    print()
    print("=" * 70)
    print("BACKTEST COMPARATIF DES STRATÉGIES")
    print(f"Symbol: {args.symbol} | Période: {args.period} | Capital: {args.capital:,.0f} EUR")
    print(f"Stratégies: {', '.join(args.strategies)}")
    print(f"API: {API_BASE_URL}")
    print("=" * 70)

    # Fail fast if the API is down rather than timing out on every job.
    try:
        _get(f"{API_BASE_URL}/health")
        print("API disponible.")
    except Exception as e:
        print(f"ERREUR : Impossible de joindre l'API ({e})")
        print("Vérifiez que le container trading-api tourne sur le port 8100.")
        sys.exit(1)

    print()

    # -----------------------------------------------------------------------
    # Phase 1: launch every backtest up front so they run concurrently
    # server-side; a launch failure skips that strategy, not the whole run.
    # -----------------------------------------------------------------------
    jobs: dict[str, str] = {}  # {strategy: job_id}
    stratégies_échouées: list[str] = []

    for strat in args.strategies:
        print(f"Lancement backtest [{strat}]...")
        try:
            job_id = lancer_backtest(strat, args.symbol, args.period, args.capital)
            jobs[strat] = job_id
            print(f" -> Job ID : {job_id}")
        except Exception as e:
            print(f" ERREUR lancement [{strat}] : {e}")
            stratégies_échouées.append(strat)

    if not jobs:
        print("Aucun job lancé. Arrêt.")
        sys.exit(1)

    print()

    # -----------------------------------------------------------------------
    # Phase 2: poll each job sequentially; failures become a partial entry
    # so they still show up in the final report.
    # -----------------------------------------------------------------------
    resultats: list[dict] = []

    for strat, job_id in jobs.items():
        print(f"Attente résultat [{strat}] (job={job_id})...")
        try:
            res = attendre_resultat(job_id, strat)
            resultats.append(res)
            print(f" -> Terminé. Sharpe={res.get('sharpe_ratio', 'N/A')}, "
                  f"Retour={res.get('total_return', 'N/A')}")
        except Exception as e:
            print(f" ERREUR [{strat}] : {e}")
            # Keep a failed placeholder row so the report stays complete.
            resultats.append({
                "strategy": strat,
                "symbol": args.symbol,
                "status": "failed",
                "error": str(e),
            })

    print()

    # -----------------------------------------------------------------------
    # Phase 3: display the comparison table and persist the JSON report.
    # -----------------------------------------------------------------------
    afficher_tableau(resultats)

    chemin_rapport = sauvegarder_rapport(resultats, args.symbol, args.period)
    print(f"Rapport JSON sauvegardé : {chemin_rapport}")
    print()

    # Summarize strategies whose launch failed (phase 1).
    if stratégies_échouées:
        print(f"Stratégies non lancées (erreur API) : {', '.join(stratégies_échouées)}")
|
||||||
|
|
||||||
|
|
||||||
|
# Run only when executed as a script; importing the module has no side effects.
if __name__ == "__main__":
    main()
|
||||||
327
scripts/quick_benchmark.py
Executable file
327
scripts/quick_benchmark.py
Executable file
@@ -0,0 +1,327 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
quick_benchmark.py — Benchmark rapide des modèles ML entraînés.
|
||||||
|
|
||||||
|
Lit les fichiers *_meta.json dans les répertoires de modèles :
|
||||||
|
- models/ml_strategy/ (XGBoost / LightGBM / RandomForest)
|
||||||
|
- models/cnn_strategy/ (CNN 1D — CandlestickEncoder)
|
||||||
|
- models/cnn_image_strategy/ (CNN Image — chandeliers en image)
|
||||||
|
|
||||||
|
Affiche un tableau comparatif : wf_accuracy, wf_precision, n_samples, etc.
|
||||||
|
Indique quel modèle obtient la meilleure précision walk-forward.
|
||||||
|
|
||||||
|
Usage :
|
||||||
|
python scripts/quick_benchmark.py
|
||||||
|
python scripts/quick_benchmark.py --models-root /chemin/vers/models
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Répertoires de modèles (relatifs à la racine du projet)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# This file lives in scripts/, so two levels up is the repository root.
PROJECT_ROOT = Path(__file__).parent.parent

# Display label -> directory expected to hold that family's *_meta.json files.
MODEL_DIRS = {
    "ML-Strategy (XGBoost/LightGBM)": PROJECT_ROOT / "models" / "ml_strategy",
    "CNN-Strategy (1D)": PROJECT_ROOT / "models" / "cnn_strategy",
    "CNN-Image-Strategy": PROJECT_ROOT / "models" / "cnn_image_strategy",
}

# Table columns: (metadata key, column header, format spec).
COLONNES = [
    ("type", "Type", "{:<28}"),
    ("symbol", "Symbol", "{:<8}"),
    ("timeframe", "TF", "{:<5}"),
    ("model_type", "Modèle", "{:<12}"),
    ("n_samples", "Échantillons", "{:>12}"),
    ("wf_accuracy", "WF Accuracy", "{:>12}"),
    ("wf_precision", "WF Precision", "{:>13}"),
    ("trained_at", "Entraîné le", "{:<20}"),
]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Lecture des méta-fichiers JSON
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def charger_modeles(model_dirs: dict[str, Path]) -> list[dict]:
    """Scan the model directories and load every *_meta.json found.

    Returns a list of flat dicts ready for table display. Missing
    directories are skipped silently; unreadable files produce a warning
    on stderr and are skipped.
    """
    charges: list[dict] = []

    for type_label, dossier in model_dirs.items():
        if not dossier.exists():
            # Directory absent: no models of this type trained yet.
            continue

        for fichier in sorted(dossier.glob("*_meta.json")):
            try:
                with open(fichier, encoding="utf-8") as fp:
                    meta = json.load(fp)
            except (json.JSONDecodeError, OSError) as e:
                print(f"[WARN] Impossible de lire {fichier} : {e}", file=sys.stderr)
                continue

            wf = meta.get("wf_metrics", {})

            # The filename (e.g. EURUSD_1h_xgboost_meta.json) is the fallback
            # source for symbol/timeframe/model_type when absent from the JSON.
            morceaux = fichier.stem.replace("_meta", "").split("_")
            symbol = meta.get("symbol", morceaux[0] if morceaux else "?")
            timeframe = meta.get("timeframe", morceaux[1] if len(morceaux) > 1 else "?")
            model_type = meta.get(
                "model_type",
                morceaux[2] if len(morceaux) > 2 else dossier.name
            )

            trained_at = meta.get("trained_at", "?")
            # Keep only "YYYY-MM-DDTHH:MM:SS" from a full ISO timestamp.
            if isinstance(trained_at, str) and len(trained_at) > 19:
                trained_at = trained_at[:19]

            charges.append({
                "type": type_label,
                "symbol": symbol,
                "timeframe": timeframe,
                "model_type": model_type,
                "n_samples": meta.get("n_samples", 0),
                "wf_accuracy": wf.get("avg_accuracy", None),
                "wf_precision": wf.get("avg_precision", None),
                "trained_at": trained_at,
                "_meta_path": str(fichier),  # for a detailed report if needed
                "_raw_meta": meta,
            })

    return charges
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Affichage du tableau
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def formater_val(val, fmt: str) -> str:
    """Format *val* using *fmt*; None -> 'N/A', non-convertible -> str(val).

    A format spec containing '%' renders as a fixed two-decimal percentage;
    one containing 'f' coerces the value to float before formatting.
    """
    if val is None:
        return "N/A"
    try:
        if "%" in fmt:
            return f"{float(val):.2%}"
        if "f" in fmt:
            return fmt.format(float(val))
        return fmt.format(val)
    except (ValueError, TypeError):
        return str(val)
|
||||||
|
|
||||||
|
|
||||||
|
def afficher_tableau(modeles: list[dict]) -> None:
    """Print the model-comparison table to the terminal.

    Uses 'tabulate' when installed, otherwise a plain aligned fallback.
    """
    if not modeles:
        print("Aucun modèle entraîné trouvé.")
        print("Entraînez d'abord un modèle via POST /trading/train (ou /train-cnn).")
        return

    headers = [label for _, label, _ in COLONNES]

    def _cellule(cle: str, brut) -> str:
        # Per-column formatting; conversion failures degrade gracefully.
        if brut is None:
            return "N/A"
        if cle in ("wf_accuracy", "wf_precision"):
            # Percent display.
            try:
                return f"{float(brut):.2%}"
            except (ValueError, TypeError):
                return "N/A"
        if cle == "n_samples":
            try:
                return f"{int(brut):,}"
            except (ValueError, TypeError):
                return str(brut)
        return str(brut)

    rows = [[_cellule(key, m.get(key)) for key, _, _ in COLONNES] for m in modeles]

    print()
    print("=" * 80)
    print("BENCHMARK DES MODÈLES ML ENTRAÎNÉS")
    print("=" * 80)

    try:
        from tabulate import tabulate
        print(tabulate(rows, headers=headers, tablefmt="rounded_outline"))
    except ImportError:
        # Plain fallback when tabulate is not installed.
        col_widths = [
            max(len(str(h)), max((len(str(r[i])) for r in rows), default=0))
            for i, h in enumerate(headers)
        ]
        print(" ".join(str(h).ljust(w) for h, w in zip(headers, col_widths)))
        print(" ".join("-" * w for w in col_widths))
        for row in rows:
            print(" ".join(str(c).ljust(w) for c, w in zip(row, col_widths)))

    print()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Identification du meilleur modèle
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def identifier_meilleur(modeles: list[dict]) -> None:
    """Print which model is best according to walk-forward accuracy.

    Models without a wf_accuracy metric are excluded from the ranking.
    Ranking uses wf_accuracy first and wf_precision as tie-breaker.
    A warning is printed when accuracy is below 40%, which the script
    considers insufficient for standalone trading.
    """
    candidats = [m for m in modeles if m.get("wf_accuracy") is not None]
    if not candidats:
        print("Impossible de déterminer le meilleur modèle (aucune métrique wf_accuracy).")
        return

    # Primary criterion: wf_accuracy; secondary: wf_precision.
    meilleur = max(
        candidats,
        key=lambda m: (
            m.get("wf_accuracy") or 0.0,
            m.get("wf_precision") or 0.0,
        ),
    )

    # Fix: plain string instead of an f-string with no placeholder (ruff F541).
    print("Meilleur modèle (WF Accuracy) :")
    print(f" Type : {meilleur['type']}")
    print(f" Symbol : {meilleur['symbol']} / {meilleur['timeframe']}")
    print(f" Modèle : {meilleur['model_type']}")

    acc = meilleur.get("wf_accuracy")
    prec = meilleur.get("wf_precision")
    print(f" WF Accuracy : {acc:.2%}" if acc is not None else " WF Accuracy : N/A")
    print(f" WF Precision : {prec:.2%}" if prec is not None else " WF Precision : N/A")
    print(f" Entraîné le : {meilleur.get('trained_at', '?')}")
    print(f" Fichier meta : {meilleur['_meta_path']}")

    # Warn when the accuracy is too low for live/paper trading.
    if acc is not None and acc < 0.40:
        print()
        print("[AVIS] WF Accuracy < 40% — ce modèle est insuffisant pour trader seul.")
        print(" Envisagez un re-entraînement avec plus de données ou de features.")

    print()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Statistiques globales par type de modèle
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def afficher_stats_globales(modeles: list[dict]) -> None:
    """Print aggregate statistics (mean WF metrics, sample totals) per model type."""
    if not modeles:
        return

    # Group models by their display type label.
    groupes: dict[str, list] = {}
    for modele in modeles:
        groupes.setdefault(modele["type"], []).append(modele)

    print("Résumé par type de modèle :")
    print("-" * 50)
    for type_label, groupe in groupes.items():
        accs = [g["wf_accuracy"] for g in groupe if g.get("wf_accuracy") is not None]
        precs = [g["wf_precision"] for g in groupe if g.get("wf_precision") is not None]
        total_echantillons = sum(g.get("n_samples", 0) for g in groupe)

        # A mean is undefined (N/A) when no model in the group reports the metric.
        acc_str = f"{sum(accs) / len(accs):.2%}" if accs else "N/A"
        prec_str = f"{sum(precs) / len(precs):.2%}" if precs else "N/A"

        print(f" {type_label}")
        print(f" Modèles entraînés : {len(groupe)}")
        print(f" WF Accuracy moy. : {acc_str}")
        print(f" WF Precision moy. : {prec_str}")
        print(f" Échantillons tot. : {total_echantillons:,}")
        print()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Point d'entrée principal
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the benchmark script."""
    cli = argparse.ArgumentParser(
        description="Benchmark rapide des modèles ML entraînés (lecture des méta-fichiers JSON)."
    )
    # Single option: where the per-family model subdirectories live.
    cli.add_argument(
        "--models-root",
        type=Path,
        default=PROJECT_ROOT / "models",
        help=f"Répertoire racine des modèles (défaut: {PROJECT_ROOT / 'models'})",
    )
    return cli.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Entry point: locate trained models, print the benchmark tables and
    the per-type summary, then highlight the best model.

    Exits with status 0 (after printing training instructions) when no
    model metadata is found.
    """
    args = parse_args()

    # Re-anchor the per-family directories under --models-root while keeping
    # the subdirectory names declared in MODEL_DIRS.
    dirs_effectifs = {
        label: args.models_root / dossier.name
        for label, dossier in MODEL_DIRS.items()
    }

    print()
    print("Lecture des modèles dans :")
    for label, chemin in dirs_effectifs.items():
        existe = "OK" if chemin.exists() else "absent"
        print(f" [{existe}] {chemin}")
    print()

    # Load metadata for every trained model found on disk.
    modeles = charger_modeles(dirs_effectifs)

    if not modeles:
        # Nothing trained yet: show how to launch a training job, exit cleanly.
        print("Aucun modèle trouvé dans les répertoires indiqués.")
        print()
        print("Pour entraîner un modèle XGBoost/LightGBM :")
        print(" POST http://localhost:8100/trading/train")
        print(" {\"symbol\": \"EURUSD\", \"timeframe\": \"1h\", \"model_type\": \"xgboost\"}")
        print()
        print("Pour entraîner un modèle CNN 1D :")
        print(" POST http://localhost:8100/trading/train-cnn")
        print(" {\"symbol\": \"EURUSD\", \"timeframe\": \"1h\"}")
        sys.exit(0)

    # Comparison table of every model.
    afficher_tableau(modeles)

    # Aggregates per model family.
    afficher_stats_globales(modeles)

    # Single best model by walk-forward accuracy.
    identifier_meilleur(modeles)

    print(f"Total : {len(modeles)} modèle(s) trouvé(s).")
    print()
|
||||||
|
|
||||||
|
|
||||||
|
# Run only when executed as a script; importing the module has no side effects.
if __name__ == "__main__":
    main()
|
||||||
@@ -4,11 +4,14 @@ Routes de trading : risk, positions, signaux, backtesting, paper trading.
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -259,17 +262,46 @@ def get_ml_status(symbol: str = "EURUSD"):
|
|||||||
return cached["result"]
|
return cached["result"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
from src.ml.regime_detector import RegimeDetector
|
||||||
|
|
||||||
config = ConfigLoader.load_all()
|
config = ConfigLoader.load_all()
|
||||||
data_service = DataService(config)
|
data_service = DataService(config)
|
||||||
|
|
||||||
|
timeframe = "1h"
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
start = now - timedelta(days=30)
|
start = now - timedelta(days=30)
|
||||||
|
|
||||||
# Récupérer données synchrones via asyncio.run
|
# ------------------------------------------------------------------
|
||||||
|
# Tentative de chargement du modèle HMM persisté (évite le re-train)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
detector = RegimeDetector(n_regimes=4)
|
||||||
|
hmm_loaded = detector.load(symbol, timeframe)
|
||||||
|
|
||||||
|
if hmm_loaded and not detector.needs_retrain(max_age_hours=24):
|
||||||
|
# Modèle récent disponible sur disque — pas besoin de ré-entraîner
|
||||||
|
logger.info(
|
||||||
|
f"Modèle HMM chargé depuis le disque pour {symbol}/{timeframe} "
|
||||||
|
f"(entraîné le {detector._trained_at})"
|
||||||
|
)
|
||||||
|
need_fit = False
|
||||||
|
else:
|
||||||
|
if hmm_loaded:
|
||||||
|
logger.info(
|
||||||
|
f"Modèle HMM trop ancien pour {symbol}/{timeframe} — ré-entraînement nécessaire"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
f"Aucun modèle HMM persisté pour {symbol}/{timeframe} — entraînement initial"
|
||||||
|
)
|
||||||
|
need_fit = True
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Récupération des données (toujours nécessaire pour la prédiction)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
df = asyncio.run(
|
df = asyncio.run(
|
||||||
data_service.get_historical_data(
|
data_service.get_historical_data(
|
||||||
symbol=symbol,
|
symbol=symbol,
|
||||||
timeframe="1h",
|
timeframe=timeframe,
|
||||||
start_date=start,
|
start_date=start,
|
||||||
end_date=now,
|
end_date=now,
|
||||||
)
|
)
|
||||||
@@ -288,11 +320,23 @@ def get_ml_status(symbol: str = "EURUSD"):
|
|||||||
|
|
||||||
df.columns = [c.lower() for c in df.columns]
|
df.columns = [c.lower() for c in df.columns]
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Entraînement si nécessaire puis sauvegarde du modèle
|
||||||
|
# ------------------------------------------------------------------
|
||||||
ml = MLEngine(config=config.get("ml", {}))
|
ml = MLEngine(config=config.get("ml", {}))
|
||||||
ml.initialize(df)
|
|
||||||
|
if need_fit:
|
||||||
|
# Entraîner le détecteur et l'injecter dans MLEngine
|
||||||
|
detector.fit(df)
|
||||||
|
# Sauvegarder le nouveau modèle sur disque pour les prochains appels
|
||||||
|
detector.save(symbol, timeframe)
|
||||||
|
|
||||||
|
# Injecter le détecteur (chargé ou fraîchement entraîné) dans MLEngine
|
||||||
|
ml.regime_detector = detector
|
||||||
|
ml.current_regime = detector.predict_current_regime(df)
|
||||||
|
|
||||||
regime_info = ml.get_regime_info()
|
regime_info = ml.get_regime_info()
|
||||||
regime_stats = ml.regime_detector.get_regime_statistics(df)
|
regime_stats = detector.get_regime_statistics(df)
|
||||||
|
|
||||||
strategy_advice = {
|
strategy_advice = {
|
||||||
s: ml.should_trade(s)
|
s: ml.should_trade(s)
|
||||||
@@ -412,8 +456,8 @@ async def run_backtest(request: BacktestRequest, background_tasks: BackgroundTas
|
|||||||
Lance un backtest en arrière-plan et retourne un `job_id`.
|
Lance un backtest en arrière-plan et retourne un `job_id`.
|
||||||
Interroger `/trading/backtest/{job_id}` pour le résultat.
|
Interroger `/trading/backtest/{job_id}` pour le résultat.
|
||||||
"""
|
"""
|
||||||
if request.strategy not in ("scalping", "intraday", "swing"):
|
if request.strategy not in ("scalping", "intraday", "swing", "ml_driven"):
|
||||||
raise HTTPException(400, detail="strategy doit être : scalping | intraday | swing")
|
raise HTTPException(400, detail="strategy doit être : scalping | intraday | swing | ml_driven")
|
||||||
|
|
||||||
job_id = str(uuid.uuid4())
|
job_id = str(uuid.uuid4())
|
||||||
_backtest_jobs[job_id] = {
|
_backtest_jobs[job_id] = {
|
||||||
@@ -776,11 +820,23 @@ async def _run_train_task(job_id: str, request: TrainRequest) -> None:
|
|||||||
_train_jobs[job_id]["status"] = "running"
|
_train_jobs[job_id]["status"] = "running"
|
||||||
try:
|
try:
|
||||||
# Récupération des données historiques
|
# Récupération des données historiques
|
||||||
data_service = _get_data_service()
|
from src.data.data_service import DataService
|
||||||
|
from src.utils.config_loader import ConfigLoader
|
||||||
|
from datetime import timedelta
|
||||||
|
config = ConfigLoader.load_all()
|
||||||
|
data_service = DataService(config)
|
||||||
|
|
||||||
|
end_date = datetime.now()
|
||||||
|
period_map = {'y': 365, 'm': 30, 'd': 1}
|
||||||
|
unit = request.period[-1]
|
||||||
|
value = int(request.period[:-1])
|
||||||
|
start_date = end_date - timedelta(days=value * period_map.get(unit, 1))
|
||||||
|
|
||||||
df = await data_service.get_historical_data(
|
df = await data_service.get_historical_data(
|
||||||
symbol = request.symbol,
|
symbol = request.symbol,
|
||||||
timeframe = request.timeframe,
|
timeframe = request.timeframe,
|
||||||
period = request.period,
|
start_date = start_date,
|
||||||
|
end_date = end_date,
|
||||||
)
|
)
|
||||||
if df is None or len(df) < 200:
|
if df is None or len(df) < 200:
|
||||||
raise ValueError(f"Données insuffisantes : {len(df) if df is not None else 0} barres (min 200)")
|
raise ValueError(f"Données insuffisantes : {len(df) if df is not None else 0} barres (min 200)")
|
||||||
@@ -917,3 +973,876 @@ def get_feature_importance(symbol: str, timeframe: str, model_type: str = "xgboo
|
|||||||
raise HTTPException(404, detail=f"Modèle non trouvé pour {symbol}/{timeframe}/{model_type}")
|
raise HTTPException(404, detail=f"Modèle non trouvé pour {symbol}/{timeframe}/{model_type}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(500, detail=str(e))
|
raise HTTPException(500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# CNN STRATEGY — Entraînement et gestion des modèles CNN-Driven
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# Optional-dependency guards: the CNN and Ensemble stacks may be absent from
# some deployments; the flags let route handlers degrade gracefully.
try:
    from src.ml.cnn import CNNStrategyModel
    CNN_AVAILABLE = True
except ImportError:
    CNN_AVAILABLE = False

try:
    from src.strategies.cnn_driven import CNNDrivenStrategy
    from src.strategies.ensemble import EnsembleStrategy
    from src.ml.ensemble import EnsembleModel
    ENSEMBLE_AVAILABLE = True
except ImportError:
    ENSEMBLE_AVAILABLE = False

# In-memory store of CNN training jobs, keyed by job_id.
# NOTE(review): in-process only — jobs are lost on restart and not shared
# across workers; confirm the API runs single-process.
_cnn_train_jobs: Dict[str, dict] = {}
|
||||||
|
|
||||||
|
|
||||||
|
class CNNTrainRequest(BaseModel):
    """Parameters of a CNN model training request."""
    # Trading pair and candle timeframe of the training data.
    symbol: str = "EURUSD"
    timeframe: str = "1h"
    # History depth, parsed as <int><unit> with unit y/m/d (e.g. "2y").
    period: str = "2y"
    # Number of candles per input sequence fed to the CNN.
    seq_len: int = 64
    # Take-profit / stop-loss distances in ATR multiples — presumably used
    # for labeling outcomes; TODO confirm against CNNStrategyModel.
    tp_atr_mult: float = 2.0
    sl_atr_mult: float = 1.0
    # Look-ahead window in bars — assumed to bound label resolution; verify.
    horizon: int = 30
    # Minimum model confidence below which no signal is emitted (assumed).
    min_confidence: float = 0.55
|
||||||
|
|
||||||
|
|
||||||
|
class CNNTrainResponse(BaseModel):
    """Status/result payload of a CNN training job."""
    job_id: str
    # One of: pending | running | completed | failed (as set by the task).
    status: str
    symbol: str
    timeframe: str
    # Walk-forward metrics, populated only once the job completes.
    wf_accuracy: Optional[float] = None
    wf_precision: Optional[float] = None
    # Label distribution reported by training (shape defined by the model).
    label_dist: Optional[dict] = None
    n_samples: Optional[int] = None
    trained_at: Optional[str] = None
    # Error message, set only when status == "failed".
    error: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_cnn_train_task(job_id: str, request: CNNTrainRequest) -> None:
    """Background task: fetch history, train the CNN model, record job status.

    Mutates `_cnn_train_jobs[job_id]` in place: "running" while working,
    then "completed" with walk-forward metrics on success, or "failed"
    with the error message on any exception.
    """
    _cnn_train_jobs[job_id]["status"] = "running"
    try:
        from src.data.data_service import DataService
        from src.utils.config_loader import ConfigLoader
        from datetime import timedelta

        config = ConfigLoader.load_all()
        data_service = DataService(config)

        # Translate a period string like "2y" / "6m" / "30d" into a date range;
        # unknown units fall back to days.
        end_date = datetime.now()
        period_map = {'y': 365, 'm': 30, 'd': 1}
        unit = request.period[-1]
        value = int(request.period[:-1])
        start_date = end_date - timedelta(days=value * period_map.get(unit, 1))

        df = await data_service.get_historical_data(
            symbol=request.symbol,
            timeframe=request.timeframe,
            start_date=start_date,
            end_date=end_date,
        )
        if df is None or len(df) < 200:
            raise ValueError(f"Données insuffisantes : {len(df) if df is not None else 0} barres (min 200)")

        # Training is CPU-bound: run it in the default executor so the event
        # loop stays responsive. Fix: asyncio.get_running_loop() replaces
        # get_event_loop(), which is deprecated inside coroutines (3.10+).
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, _sync_cnn_train, df, request)

        _cnn_train_jobs[job_id].update({
            "status": "completed",
            "symbol": request.symbol,
            "timeframe": request.timeframe,
            "n_samples": result.get("n_samples"),
            "wf_accuracy": result.get("wf_metrics", {}).get("avg_accuracy"),
            "wf_precision": result.get("wf_metrics", {}).get("avg_precision"),
            "label_dist": result.get("label_dist"),
            "trained_at": result.get("trained_at"),
        })

        # Best-effort hot-attach of the fresh model to an active cnn_driven
        # paper-trading strategy, if one is running.
        _attach_cnn_model_to_strategy(request)

    except Exception as exc:
        logger.error(f"Erreur entraînement CNN job {job_id} : {exc}", exc_info=True)
        _cnn_train_jobs[job_id]["status"] = "failed"
        _cnn_train_jobs[job_id]["error"] = str(exc)
|
||||||
|
|
||||||
|
|
||||||
|
def _sync_cnn_train(df, request: CNNTrainRequest) -> dict:
    """Synchronous wrapper around CNNStrategyModel.train() (run in an executor)."""
    from src.ml.cnn import CNNStrategyModel

    # Mirror every tunable of the request onto the model constructor.
    params = dict(
        symbol=request.symbol,
        timeframe=request.timeframe,
        seq_len=request.seq_len,
        tp_atr_mult=request.tp_atr_mult,
        sl_atr_mult=request.sl_atr_mult,
        horizon=request.horizon,
        min_confidence=request.min_confidence,
    )
    return CNNStrategyModel(**params).train(df)
|
||||||
|
|
||||||
|
|
||||||
|
def _attach_cnn_model_to_strategy(request: CNNTrainRequest) -> None:
    """Best-effort: load the freshly trained CNN model and hand it to the
    live ``cnn_driven`` strategy of the running paper-trading engine.

    Silently skipped (debug log only) when no engine is running, the
    strategy is absent, or loading fails — training success must not
    depend on this attachment.
    """
    try:
        from src.ml.cnn import CNNStrategyModel
        from src.strategies.cnn_driven import CNNDrivenStrategy

        engine = _paper_state.get("engine")
        if not engine or not hasattr(engine, 'strategy_engine'):
            return
        strat = engine.strategy_engine.strategies.get('cnn_driven')
        if not isinstance(strat, CNNDrivenStrategy):
            return
        strat.attach_model(CNNStrategyModel.load(request.symbol, request.timeframe))
        logger.info("Modèle CNN attaché à la stratégie cnn_driven active")
    except Exception as e:
        logger.debug(f"Auto-attach modèle CNN ignoré : {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/train-cnn", response_model=CNNTrainResponse, summary="Entraîner le modèle CNN")
async def train_cnn_model(request: CNNTrainRequest, background_tasks: BackgroundTasks):
    """
    Lance l'entraînement du modèle CNN en arrière-plan.

    Le CNN 1D apprend directement les patterns visuels dans les séquences
    OHLCV brutes (double bottom, squeeze Bollinger, alignements...).

    - Retourne un `job_id` à interroger via `GET /trading/train-cnn/{job_id}`
    - Le modèle est sauvegardé sur disque après entraînement
    - Si un paper trading CNN est actif, le modèle lui est automatiquement attaché
    """
    if not CNN_AVAILABLE:
        raise HTTPException(503, detail="PyTorch requis — rebuilder le container trading-api")

    job_id = str(uuid.uuid4())
    # Register the job before scheduling so an immediate status poll
    # never sees an unknown job id.
    _cnn_train_jobs[job_id] = {
        "status": "pending",
        "symbol": request.symbol,
        "timeframe": request.timeframe,
    }

    background_tasks.add_task(_run_cnn_train_task, job_id, request)

    return CNNTrainResponse(
        job_id=job_id,
        status="pending",
        symbol=request.symbol,
        timeframe=request.timeframe,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/train-cnn/{job_id}", response_model=CNNTrainResponse, summary="Résultat entraînement CNN")
def get_cnn_train_status(job_id: str):
    """Retourne l'état d'un job d'entraînement CNN."""
    try:
        job = _cnn_train_jobs[job_id]
    except KeyError:
        raise HTTPException(404, detail=f"Job {job_id} introuvable")
    return CNNTrainResponse(job_id=job_id, **job)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/cnn-models", summary="Liste des modèles CNN entraînés")
def list_cnn_models():
    """
    Retourne la liste de tous les modèles CNN disponibles sur disque,
    avec leurs métriques (accuracy, date d'entraînement, nombre de samples...).
    """
    if not CNN_AVAILABLE:
        return {"error": "PyTorch requis — rebuilder le container trading-api", "models": [], "count": 0}
    # Guard the import/listing like the /rl-models endpoint does: a broken
    # install or unreadable model store should degrade to an error payload,
    # not an unhandled 500.
    try:
        from src.ml.cnn import CNNStrategyModel
        models = CNNStrategyModel.list_trained_models()
    except Exception as e:
        return {"error": str(e), "models": [], "count": 0}
    return {"models": models, "count": len(models)}
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Ensemble — Configuration et statut du modèle d'ensemble (ML + CNN)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class EnsembleConfigRequest(BaseModel):
    """Configuration de l'ensemble ML + CNN."""
    # Relative weight of each component in the blended prediction.
    # (Pydantic copies field defaults per instance, so the mutable dict
    # default is safe here, unlike a plain function default.)
    weights: dict = {"xgboost": 0.40, "cnn": 0.60}
    # Minimum blended confidence required for a signal.
    min_confidence: float = 0.60
    # When True, both models must agree on direction.
    require_agreement: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
# Global ensemble configuration, held in process memory only (lost on
# restart). Values mirror the defaults of EnsembleConfigRequest.
_ensemble_config: Dict = {
    "weights": {"xgboost": 0.40, "cnn": 0.60},
    "min_confidence": 0.60,
    "require_agreement": True,
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/ensemble/configure", summary="Configurer l'ensemble ML + CNN")
async def configure_ensemble(request: EnsembleConfigRequest):
    """
    Configure les poids et paramètres de l'ensemble ML + CNN.

    - weights: poids relatifs de chaque composant (ex: {"xgboost": 0.40, "cnn": 0.60})
    - min_confidence: seuil minimum de confiance de l'ensemble
    - require_agreement: si True, les deux modèles doivent être d'accord sur la direction
    """
    if not ENSEMBLE_AVAILABLE:
        raise HTTPException(503, detail="Ensemble non disponible — modules manquants")

    # One params dict reused for both the global config and the live strategy.
    new_params = {
        "weights": request.weights,
        "min_confidence": request.min_confidence,
        "require_agreement": request.require_agreement,
    }
    _ensemble_config.update(new_params)

    # Propagate to the active ensemble strategy, if one is running.
    engine = _paper_state.get("engine")
    if engine and hasattr(engine, 'strategy_engine'):
        strat = engine.strategy_engine.strategies.get('ensemble')
        if strat and isinstance(strat, EnsembleStrategy):
            strat.update_params(dict(new_params))
            logger.info("Configuration ensemble appliquée à la stratégie active")

    return {"status": "configured", "config": _ensemble_config}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/ensemble/status", summary="Statut de l'ensemble ML + CNN")
async def get_ensemble_status():
    """
    Retourne le statut de chaque composant de l'ensemble :
    - Modèles ML (XGBoost/LightGBM) disponibles
    - Modèles CNN disponibles
    - Configuration active (poids, seuil, agreement)
    - État de la stratégie ensemble si active en paper trading
    """
    # Assembled incrementally; each component probe is isolated so one
    # broken component cannot hide the status of the others.
    status = {
        "config": _ensemble_config,
        "components": {},
        "paper_trading_active": False,
    }

    # Probe ML models on disk (degrades to "unavailable" on any error).
    try:
        from src.ml.ml_strategy_model import MLStrategyModel
        ml_models = MLStrategyModel.list_trained_models()
        status["components"]["ml"] = {
            "available": True,
            "models_count": len(ml_models),
            "models": ml_models,
        }
    except Exception:
        status["components"]["ml"] = {"available": False, "models_count": 0, "models": []}

    # Probe CNN models on disk (only if PyTorch is importable).
    if CNN_AVAILABLE:
        try:
            from src.ml.cnn import CNNStrategyModel
            cnn_models = CNNStrategyModel.list_trained_models()
            status["components"]["cnn"] = {
                "available": True,
                "models_count": len(cnn_models),
                "models": cnn_models,
            }
        except Exception:
            status["components"]["cnn"] = {"available": False, "models_count": 0, "models": []}
    else:
        status["components"]["cnn"] = {
            "available": False,
            "error": "PyTorch requis — rebuilder le container trading-api",
        }

    # Report whether an ensemble strategy is live in paper trading.
    engine = _paper_state.get("engine")
    if engine and hasattr(engine, 'strategy_engine'):
        strat = engine.strategy_engine.strategies.get('ensemble')
        if strat and ENSEMBLE_AVAILABLE and isinstance(strat, EnsembleStrategy):
            status["paper_trading_active"] = True
            try:
                status["ensemble_info"] = strat.get_status()
            except Exception:
                status["ensemble_info"] = {"error": "Impossible de récupérer le statut"}

    return status
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# CNN IMAGE STRATEGY — Entraînement et gestion des modèles CNN-Image
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# Feature flag: CNN-Image support requires torch + mplfinance; when the
# import fails the endpoints below answer 503 / error payloads instead.
try:
    from src.ml.cnn_image import CNNImageStrategyModel
    CNN_IMAGE_AVAILABLE = True
except ImportError:
    CNN_IMAGE_AVAILABLE = False

# In-memory registry of CNN-Image training jobs (job_id -> job dict).
# NOTE: process-local only — job history is lost on restart.
_cnn_image_train_jobs: Dict[str, dict] = {}
|
||||||
|
|
||||||
|
|
||||||
|
class CNNImageTrainRequest(BaseModel):
    """Requête d'entraînement du modèle CNN Image (Conv2D vision)."""
    symbol: str = "EURUSD"
    timeframe: str = "1h"
    # History window, "<n><unit>" with unit in {y, m, d} (e.g. "2y").
    period: str = "2y"
    # Candles rendered per chart image.
    seq_len: int = 64
    # Take-profit / stop-loss distances in ATR multiples (labeling).
    tp_atr_mult: float = 2.0
    sl_atr_mult: float = 1.0
    # Max bars ahead when resolving the TP/SL label outcome.
    horizon: int = 30
    # Minimum predicted confidence for the model to emit a signal.
    min_confidence: float = 0.55
|
||||||
|
|
||||||
|
|
||||||
|
class CNNImageTrainResponse(BaseModel):
    """Réponse d'un job d'entraînement CNN Image."""
    job_id: str
    # One of: pending | running | completed | failed.
    status: str
    symbol: str
    timeframe: str
    # Walk-forward metrics, filled only when status == "completed".
    wf_accuracy: Optional[float] = None
    wf_precision: Optional[float] = None
    label_dist: Optional[dict] = None
    n_samples: Optional[int] = None
    trained_at: Optional[str] = None
    # Error message, filled only when status == "failed".
    error: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_cnn_image_train_task(job_id: str, request: CNNImageTrainRequest) -> None:
    """Background task: fetch history, train the CNN-Image model, record the outcome.

    Mutates ``_cnn_image_train_jobs[job_id]`` in place: status moves
    pending -> running -> completed|failed; walk-forward metrics are
    stored on success, the error message on failure. Never raises.
    """
    _cnn_image_train_jobs[job_id]["status"] = "running"
    try:
        from src.data.data_service import DataService
        from src.utils.config_loader import ConfigLoader
        from datetime import timedelta

        config = ConfigLoader.load_all()
        data_service = DataService(config)

        # Translate a "<n><unit>" period spec (y/m/d) into a date range;
        # unknown units fall back to days.
        end_date = datetime.now()
        period_map = {'y': 365, 'm': 30, 'd': 1}
        unit = request.period[-1]
        value = int(request.period[:-1])
        start_date = end_date - timedelta(days=value * period_map.get(unit, 1))

        df = await data_service.get_historical_data(
            symbol=request.symbol,
            timeframe=request.timeframe,
            start_date=start_date,
            end_date=end_date,
        )
        if df is None or len(df) < 200:
            raise ValueError(
                f"Données insuffisantes : {len(df) if df is not None else 0} barres (min 200)"
            )

        # Train in a worker thread (CPU/GPU-bound).
        # Fix: asyncio.get_running_loop() is the correct call inside a
        # coroutine; get_event_loop() there is deprecated since Python 3.10.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, _sync_cnn_image_train, df, request)

        _cnn_image_train_jobs[job_id].update({
            "status": "completed",
            "symbol": request.symbol,
            "timeframe": request.timeframe,
            "n_samples": result.get("n_samples"),
            "wf_accuracy": result.get("wf_metrics", {}).get("avg_accuracy"),
            "wf_precision": result.get("wf_metrics", {}).get("avg_precision"),
            "label_dist": result.get("label_dist"),
            "trained_at": result.get("trained_at"),
        })

        # Best-effort auto-attach to an active cnn_image_driven strategy.
        _attach_cnn_image_model_to_strategy(request)

    except Exception as exc:
        logger.error(f"Erreur entraînement CNN Image job {job_id} : {exc}", exc_info=True)
        _cnn_image_train_jobs[job_id]["status"] = "failed"
        _cnn_image_train_jobs[job_id]["error"] = str(exc)
|
||||||
|
|
||||||
|
|
||||||
|
def _sync_cnn_image_train(df, request: CNNImageTrainRequest) -> dict:
    """Synchronous wrapper around ``CNNImageStrategyModel.train()``.

    Runs inside an executor thread; builds the model from the request
    hyper-parameters and returns the training-report dict.
    """
    from src.ml.cnn_image import CNNImageStrategyModel

    hyperparams = dict(
        symbol=request.symbol,
        timeframe=request.timeframe,
        seq_len=request.seq_len,
        tp_atr_mult=request.tp_atr_mult,
        sl_atr_mult=request.sl_atr_mult,
        horizon=request.horizon,
        min_confidence=request.min_confidence,
    )
    return CNNImageStrategyModel(**hyperparams).train(df)
|
||||||
|
|
||||||
|
|
||||||
|
def _attach_cnn_image_model_to_strategy(request: CNNImageTrainRequest) -> None:
    """Best-effort: load the trained CNN-Image model and attach it to the
    live ``cnn_image_driven`` strategy of the running paper-trading engine.

    All failures are swallowed (debug log only): training success must
    not depend on this attachment.
    """
    try:
        from src.ml.cnn_image import CNNImageStrategyModel
        from src.strategies.cnn_image_driven import CNNImageDrivenStrategy

        engine = _paper_state.get("engine")
        if not engine or not hasattr(engine, 'strategy_engine'):
            return
        strat = engine.strategy_engine.strategies.get('cnn_image_driven')
        if not isinstance(strat, CNNImageDrivenStrategy):
            return
        strat.attach_model(CNNImageStrategyModel.load(request.symbol, request.timeframe))
        logger.info("Modèle CNN Image attaché à la stratégie cnn_image_driven active")
    except Exception as e:
        logger.debug(f"Auto-attach modèle CNN Image ignoré : {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/train-cnn-image", response_model=CNNImageTrainResponse,
             summary="Entraîner le modèle CNN Image (Conv2D vision)")
async def train_cnn_image(request: CNNImageTrainRequest, background_tasks: BackgroundTasks):
    """
    Lance l'entraînement du modèle CNN Image en arrière-plan.

    Le CNN Image Conv2D apprend les patterns visuels des graphiques de chandeliers
    (bougies marteau, double top/bottom, rebonds S/R, momentum...) depuis des
    images 128×128 RGB générées par mplfinance — aucune feature pré-calculée.

    - Retourne un `job_id` à interroger via `GET /trading/train-cnn-image/{job_id}`
    - Le modèle est sauvegardé sur disque après entraînement
    - Si un paper trading cnn_image_driven est actif, le modèle lui est automatiquement attaché
    """
    if not CNN_IMAGE_AVAILABLE:
        raise HTTPException(
            503,
            detail="PyTorch + mplfinance requis — rebuilder le container trading-api"
        )

    job_id = str(uuid.uuid4())
    # Register before scheduling so an immediate poll finds the job.
    _cnn_image_train_jobs[job_id] = {
        "status": "pending",
        "symbol": request.symbol,
        "timeframe": request.timeframe,
    }

    # Scheduled with asyncio.create_task() rather than background_tasks so
    # WatchFiles hot-reloads are not blocked at server shutdown.
    asyncio.create_task(_run_cnn_image_train_task(job_id, request))

    return CNNImageTrainResponse(
        job_id=job_id,
        status="pending",
        symbol=request.symbol,
        timeframe=request.timeframe,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/train-cnn-image/{job_id}", response_model=CNNImageTrainResponse,
            summary="Résultat entraînement CNN Image")
def get_cnn_image_train_status(job_id: str):
    """Retourne l'état d'un job d'entraînement CNN Image."""
    try:
        job = _cnn_image_train_jobs[job_id]
    except KeyError:
        raise HTTPException(404, detail=f"Job {job_id} introuvable")
    return CNNImageTrainResponse(job_id=job_id, **job)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/cnn-image-models", summary="Liste des modèles CNN Image entraînés")
def list_cnn_image_models():
    """
    Retourne la liste de tous les modèles CNN Image disponibles sur disque,
    avec leurs métriques (accuracy, date d'entraînement, nombre de samples...).
    """
    if not CNN_IMAGE_AVAILABLE:
        return {
            "error": "PyTorch + mplfinance requis — rebuilder le container trading-api",
            "models": [],
            "count": 0,
        }
    # Guard the import/listing like the /rl-models endpoint does: a broken
    # install or unreadable model store should degrade to an error payload,
    # not an unhandled 500.
    try:
        from src.ml.cnn_image import CNNImageStrategyModel
        models = CNNImageStrategyModel.list_trained_models()
    except Exception as e:
        return {"error": str(e), "models": [], "count": 0}
    return {"models": models, "count": len(models)}
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# RL STRATEGY — Entraînement et gestion des modèles RL (PPO)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# Feature flag: RL support requires torch (checked inside src.ml.rl).
# NOTE(review): a second "RL STRATEGY" section later in this module
# redefines RL_AVAILABLE, _rl_train_jobs and every RL class/handler —
# those later definitions shadow the ones declared here.
try:
    from src.ml.rl import RLStrategyModel, RL_AVAILABLE as _RL_AVAILABLE
    RL_AVAILABLE = _RL_AVAILABLE
except ImportError:
    RL_AVAILABLE = False

# In-memory registry of RL training jobs (job_id -> job dict).
_rl_train_jobs: Dict[str, dict] = {}
|
||||||
|
|
||||||
|
|
||||||
|
class RLTrainRequest(BaseModel):
    """Requête d'entraînement du modèle RL (PPO).

    NOTE(review): redefined later in this module with fewer fields;
    the later definition shadows this one.
    """
    symbol: str = "EURUSD"
    timeframe: str = "1h"
    period: str = "2y"                      # Historical data window
    total_timesteps: int = 100_000          # Number of PPO timesteps
    sl_atr_mult: float = 1.0                # ATR multiplier for SL
    tp_atr_mult: float = 2.0                # ATR multiplier for TP
    min_confidence: float = 0.50            # Minimum confidence threshold
    train_ratio: float = 0.80               # Fraction of data used for training
    initial_capital: float = 10_000.0       # Starting capital for the simulation
|
||||||
|
|
||||||
|
|
||||||
|
class RLTrainResponse(BaseModel):
    """Réponse d'un job d'entraînement RL.

    NOTE(review): redefined later in this module; the later definition
    shadows this one.
    """
    job_id: str
    status: str                             # pending | running | completed | failed
    symbol: Optional[str] = None
    timeframe: Optional[str] = None
    n_samples: Optional[int] = None
    total_timesteps: Optional[int] = None
    n_episodes: Optional[int] = None
    mean_ep_return: Optional[float] = None  # Average reward per episode
    eval_return_pct: Optional[float] = None # Return on the holdout set
    eval_sharpe: Optional[float] = None     # Approx. Sharpe on the holdout set
    eval_win_rate: Optional[float] = None   # Win rate on the holdout set
    trained_at: Optional[str] = None
    error: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_rl_train_task(job_id: str, request: RLTrainRequest) -> None:
    """Background task: fetch history, train the PPO agent, record the outcome.

    Mutates ``_rl_train_jobs[job_id]`` in place (pending -> running ->
    completed|failed). Never raises.

    NOTE(review): this function is redefined later in this module; at call
    time the later (shadowing) definition is the one actually invoked.
    """
    _rl_train_jobs[job_id]["status"] = "running"
    try:
        from src.data.data_service import DataService
        from src.utils.config_loader import ConfigLoader
        from datetime import timedelta

        config = ConfigLoader.load_all()
        data_service = DataService(config)

        # Translate a "<n><unit>" period spec (y/m/d) into a date range;
        # unknown units fall back to days.
        end_date = datetime.now()
        period_map = {'y': 365, 'm': 30, 'd': 1}
        unit = request.period[-1]
        value = int(request.period[:-1])
        start_date = end_date - timedelta(days=value * period_map.get(unit, 1))

        df = await data_service.get_historical_data(
            symbol=request.symbol,
            timeframe=request.timeframe,
            start_date=start_date,
            end_date=end_date,
        )
        if df is None or len(df) < 200:
            raise ValueError(
                f"Données insuffisantes : {len(df) if df is not None else 0} barres (min 200)"
            )

        # Train in a worker thread (CPU/GPU-bound).
        # Fix: asyncio.get_running_loop() is the correct call inside a
        # coroutine; get_event_loop() there is deprecated since Python 3.10.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, _sync_rl_train, df, request)

        _rl_train_jobs[job_id].update({
            "status": "completed",
            "symbol": request.symbol,
            "timeframe": request.timeframe,
            "n_samples": result.get("n_samples"),
            "total_timesteps": result.get("total_timesteps"),
            "n_episodes": result.get("n_episodes"),
            "mean_ep_return": result.get("mean_ep_return"),
            "eval_return_pct": result.get("eval_return_pct"),
            "eval_sharpe": result.get("eval_sharpe"),
            "eval_win_rate": result.get("eval_win_rate"),
            "trained_at": result.get("trained_at"),
        })

        # Best-effort auto-attach to an active rl_driven strategy.
        _attach_rl_model_to_strategy(request)

    except Exception as exc:
        logger.error(f"Erreur entraînement RL job {job_id} : {exc}", exc_info=True)
        _rl_train_jobs[job_id]["status"] = "failed"
        _rl_train_jobs[job_id]["error"] = str(exc)
|
||||||
|
|
||||||
|
|
||||||
|
def _sync_rl_train(df, request: RLTrainRequest) -> dict:
    """Synchronous wrapper around ``RLStrategyModel.train()``.

    Runs inside an executor thread; builds the PPO model from the request
    hyper-parameters and returns the training-report dict.

    NOTE(review): this helper is redefined later in this module; the later
    definition shadows this one.
    """
    from src.ml.rl import RLStrategyModel

    hyperparams = dict(
        symbol=request.symbol,
        timeframe=request.timeframe,
        total_timesteps=request.total_timesteps,
        sl_atr_mult=request.sl_atr_mult,
        tp_atr_mult=request.tp_atr_mult,
        min_confidence=request.min_confidence,
        train_ratio=request.train_ratio,
        initial_capital=request.initial_capital,
    )
    return RLStrategyModel(**hyperparams).train(df)
|
||||||
|
|
||||||
|
|
||||||
|
def _attach_rl_model_to_strategy(request: RLTrainRequest) -> None:
    """Best-effort: load the trained RL model and attach it to the live
    ``rl_driven`` strategy of the running paper-trading engine.

    All failures are swallowed (debug log only): training success must
    not depend on this attachment.

    NOTE(review): this helper is redefined later in this module; the later
    definition shadows this one.
    """
    try:
        from src.ml.rl import RLStrategyModel
        from src.strategies.rl_driven import RLDrivenStrategy

        engine = _paper_state.get("engine")
        if not engine or not hasattr(engine, 'strategy_engine'):
            return
        strat = engine.strategy_engine.strategies.get('rl_driven')
        if not isinstance(strat, RLDrivenStrategy):
            return
        strat.attach_model(RLStrategyModel.load(request.symbol, request.timeframe))
        logger.info("Modèle RL attaché à la stratégie rl_driven active")
    except Exception as e:
        logger.debug(f"Auto-attach modèle RL ignoré : {e}")
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(review): the route path "/train-rl" is registered a second time
# further down in this module (handler `train_rl`); keep only one.
@router.post("/train-rl", response_model=RLTrainResponse,
             summary="Entraîner le modèle RL (PPO)")
async def train_rl_model(request: RLTrainRequest, background_tasks: BackgroundTasks):
    """
    Lance l'entraînement de l'agent PPO en arrière-plan.

    L'agent RL (Proximal Policy Optimization) apprend à trader par interaction
    directe avec un environnement de trading simulé — sans labels supervisés.
    La récompense est définie par le PnL réalisé et les pénalités de drawdown.

    - Retourne un `job_id` à interroger via `GET /trading/train-rl/{job_id}`
    - Le modèle est sauvegardé sur disque après entraînement
    - Si un paper trading rl_driven est actif, le modèle lui est automatiquement attaché
    """
    if not RL_AVAILABLE:
        raise HTTPException(
            503,
            detail="PyTorch requis — rebuilder le container trading-api"
        )

    job_id = str(uuid.uuid4())
    # Register the job before scheduling so an immediate poll finds it.
    _rl_train_jobs[job_id] = {
        "status": "pending",
        "symbol": request.symbol,
        "timeframe": request.timeframe,
    }

    background_tasks.add_task(_run_rl_train_task, job_id, request)

    return RLTrainResponse(
        job_id = job_id,
        status = "pending",
        symbol = request.symbol,
        timeframe = request.timeframe,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(review): the route path "/train-rl/{job_id}" is registered a second
# time further down in this module; keep only one.
@router.get("/train-rl/{job_id}", response_model=RLTrainResponse,
            summary="Résultat entraînement RL")
def get_rl_train_status(job_id: str):
    """Retourne l'état d'un job d'entraînement RL (PPO)."""
    job = _rl_train_jobs.get(job_id)
    if job is None:
        raise HTTPException(404, detail=f"Job {job_id} introuvable")
    return RLTrainResponse(job_id=job_id, **job)
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(review): the route path "/rl-models" is registered a second time
# further down in this module; keep only one.
@router.get("/rl-models", summary="Liste des modèles RL entraînés")
def list_rl_models():
    """
    Retourne la liste de tous les modèles RL (PPO) disponibles sur disque,
    avec leurs métriques (return holdout, Sharpe, date d'entraînement...).
    """
    if not RL_AVAILABLE:
        return {
            "error": "PyTorch requis — rebuilder le container trading-api",
            "models": [],
            "count": 0,
        }
    # Listing is guarded so a broken model store degrades to an error
    # payload instead of a 500.
    try:
        from src.ml.rl import RLStrategyModel
        models = RLStrategyModel.list_trained_models()
        return {"models": models, "count": len(models)}
    except Exception as e:
        return {"error": str(e), "models": [], "count": 0}
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
# RL STRATEGY — Entraînement et gestion des modèles PPO (Reinforcement Learning)
# NOTE(review): this whole section duplicates the RL section defined above —
# every name and route below (RLTrainRequest, RLTrainResponse, _rl_train_jobs,
# _run_rl_train_task, /train-rl, /train-rl/{job_id}, /rl-models, ...) redefines
# one already declared earlier in this module; the later definitions shadow
# the earlier ones. Keep exactly one RL section.
# =============================================================================
|
||||||
|
|
||||||
|
# NOTE(review): duplicate RL feature-flag block — rebinds RL_AVAILABLE and
# resets _rl_train_jobs to a fresh dict, shadowing the earlier definitions.
try:
    from src.ml.rl.rl_strategy_model import RLStrategyModel
    RL_AVAILABLE = True
except ImportError:
    RLStrategyModel = None
    RL_AVAILABLE = False

# In-memory registry of RL training jobs (job_id -> job dict).
_rl_train_jobs: Dict[str, dict] = {}
|
||||||
|
|
||||||
|
|
||||||
|
class RLTrainRequest(BaseModel):
    """Requête d'entraînement du modèle RL (agent PPO).

    NOTE(review): shadows the earlier RLTrainRequest and drops its
    sl/tp/min_confidence/train_ratio/initial_capital fields.
    """
    symbol: str = "EURUSD"
    timeframe: str = "1h"
    # Historical data window, "<n><unit>" with unit in {y, m, d}.
    period: str = "2y"
    # Number of PPO simulation timesteps.
    total_timesteps: int = 50000
|
||||||
|
|
||||||
|
class RLTrainResponse(BaseModel):
    """Réponse d'un job d'entraînement RL.

    NOTE(review): shadows the earlier RLTrainResponse; here symbol and
    timeframe are required rather than Optional.
    """
    job_id: str
    # One of: pending | running | completed | failed.
    status: str
    symbol: str
    timeframe: str
    avg_reward: Optional[float] = None
    sharpe_env: Optional[float] = None
    total_timesteps: Optional[int] = None
    trained_at: Optional[str] = None
    error: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_rl_train_task(job_id: str, request: RLTrainRequest) -> None:
    """Background task: fetch history, train the PPO agent, record the outcome.

    Mutates ``_rl_train_jobs[job_id]`` in place (pending -> running ->
    completed|failed). Never raises.

    NOTE(review): this redefinition shadows the earlier _run_rl_train_task;
    since names are resolved at call time, this is the version the /train-rl
    handlers actually run.
    """
    _rl_train_jobs[job_id]["status"] = "running"
    try:
        from src.data.data_service import DataService
        from src.utils.config_loader import ConfigLoader
        from datetime import timedelta

        config = ConfigLoader.load_all()
        data_service = DataService(config)

        # Translate a "<n><unit>" period spec (y/m/d) into a date range;
        # unknown units fall back to days.
        end_date = datetime.now()
        period_map = {'y': 365, 'm': 30, 'd': 1}
        unit = request.period[-1]
        value = int(request.period[:-1])
        start_date = end_date - timedelta(days=value * period_map.get(unit, 1))

        df = await data_service.get_historical_data(
            symbol=request.symbol,
            timeframe=request.timeframe,
            start_date=start_date,
            end_date=end_date,
        )
        if df is None or len(df) < 200:
            raise ValueError(
                f"Données insuffisantes : {len(df) if df is not None else 0} barres (min 200)"
            )

        # Cap intra-op threading before handing off to the executor.
        import torch
        torch.set_num_threads(4)

        # Train in a worker thread (CPU-bound).
        # Fix: asyncio.get_running_loop() is the correct call inside a
        # coroutine; get_event_loop() there is deprecated since Python 3.10.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, _sync_rl_train, df, request)

        _rl_train_jobs[job_id].update({
            "status": "completed",
            "symbol": request.symbol,
            "timeframe": request.timeframe,
            "avg_reward": result.get("avg_reward"),
            "sharpe_env": result.get("sharpe_env"),
            "total_timesteps": result.get("total_timesteps"),
            "trained_at": result.get("trained_at"),
        })

        # Best-effort auto-attach to an active rl_driven strategy.
        _attach_rl_model_to_strategy(request)

    except Exception as exc:
        logger.error(f"Erreur entraînement RL job {job_id} : {exc}", exc_info=True)
        _rl_train_jobs[job_id]["status"] = "failed"
        _rl_train_jobs[job_id]["error"] = str(exc)
|
||||||
|
|
||||||
|
|
||||||
|
def _sync_rl_train(df, request: RLTrainRequest) -> dict:
    """Wrapper synchrone pour RLStrategyModel.train() (exécuté dans un thread).

    NOTE(review): shadows the earlier _sync_rl_train, which forwarded many
    more hyper-parameters; this version only passes total_timesteps.
    """
    # Cap intra-op threading in the worker thread as well.
    import torch
    torch.set_num_threads(4)

    from src.ml.rl.rl_strategy_model import RLStrategyModel
    model = RLStrategyModel(
        symbol = request.symbol,
        timeframe = request.timeframe,
        total_timesteps = request.total_timesteps,
    )
    return model.train(df)
|
||||||
|
|
||||||
|
|
||||||
|
def _attach_rl_model_to_strategy(request: RLTrainRequest) -> None:
    """Attache le modèle RL entraîné à la stratégie rl_driven active (paper trading).

    Best-effort: every failure is swallowed (debug log only).
    NOTE(review): shadows the earlier _attach_rl_model_to_strategy (which
    imported from src.ml.rl instead of src.ml.rl.rl_strategy_model).
    """
    try:
        from src.ml.rl.rl_strategy_model import RLStrategyModel
        from src.strategies.rl_driven import RLDrivenStrategy

        engine = _paper_state.get("engine")
        if engine and hasattr(engine, 'strategy_engine'):
            strat = engine.strategy_engine.strategies.get('rl_driven')
            if strat and isinstance(strat, RLDrivenStrategy):
                model = RLStrategyModel.load(request.symbol, request.timeframe)
                strat.attach_model(model)
                logger.info("Modèle RL attaché à la stratégie rl_driven active")
    except Exception as e:
        logger.debug(f"Auto-attach modèle RL ignoré : {e}")
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(review): duplicate registration of the "/train-rl" route (already
# registered above by train_rl_model); FastAPI matches routes in
# registration order, so the earlier handler receives the requests.
@router.post("/train-rl", response_model=RLTrainResponse,
             summary="Entraîner le modèle RL (agent PPO)")
async def train_rl(request: RLTrainRequest, background_tasks: BackgroundTasks):
    """
    Lance l'entraînement de l'agent PPO en arrière-plan.

    L'agent Reinforcement Learning apprend à maximiser le profit cumulatif en
    interagissant avec un environnement de trading simulé (TradingEnv) — sans
    labels supervisés. La politique Actor-Critic PPO est optimisée sur
    `total_timesteps` pas de simulation.

    - Retourne un `job_id` à interroger via `GET /trading/train-rl/{job_id}`
    - Le modèle est sauvegardé sur disque après entraînement
    - Si un paper trading rl_driven est actif, le modèle lui est automatiquement attaché
    """
    if not RL_AVAILABLE:
        raise HTTPException(
            503,
            detail="PyTorch / stable-baselines3 requis — rebuilder le container trading-api"
        )

    job_id = str(uuid.uuid4())
    # Register the job before scheduling so an immediate poll finds it.
    _rl_train_jobs[job_id] = {
        "status": "pending",
        "symbol": request.symbol,
        "timeframe": request.timeframe,
    }

    background_tasks.add_task(_run_rl_train_task, job_id, request)

    return RLTrainResponse(
        job_id = job_id,
        status = "pending",
        symbol = request.symbol,
        timeframe = request.timeframe,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/train-rl/{job_id}", response_model=RLTrainResponse,
            summary="Résultat entraînement RL")
def get_rl_train_status(job_id: str):
    """Return the state of an RL (PPO) training job."""
    try:
        job = _rl_train_jobs[job_id]
    except KeyError:
        raise HTTPException(404, detail=f"Job {job_id} introuvable") from None
    return RLTrainResponse(job_id=job_id, **job)
@router.get("/rl-models", summary="Liste des modèles RL entraînés")
def list_rl_models():
    """
    List every RL model available on disk, together with its metrics
    (avg_reward, sharpe_env, training date...).
    """
    if RL_AVAILABLE:
        # Imported lazily so this module stays importable without the RL stack.
        from src.ml.rl.rl_strategy_model import RLStrategyModel

        available = RLStrategyModel.list_trained_models()
        return {"models": available, "count": len(available)}

    return {
        "error": "PyTorch / stable-baselines3 requis — rebuilder le container trading-api",
        "models": [],
        "count": 0,
    }
|||||||
@@ -84,6 +84,9 @@ class StrategyEngine:
|
|||||||
elif strategy_name == 'swing':
|
elif strategy_name == 'swing':
|
||||||
from src.strategies.swing.swing_strategy import SwingStrategy
|
from src.strategies.swing.swing_strategy import SwingStrategy
|
||||||
strategy_class = SwingStrategy
|
strategy_class = SwingStrategy
|
||||||
|
elif strategy_name == 'ml_driven':
|
||||||
|
from src.strategies.ml_driven.ml_strategy import MLDrivenStrategy
|
||||||
|
strategy_class = MLDrivenStrategy
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown strategy: {strategy_name}")
|
raise ValueError(f"Unknown strategy: {strategy_name}")
|
||||||
|
|
||||||
|
|||||||
3
src/ml/cnn/__init__.py
Normal file
3
src/ml/cnn/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .cnn_strategy_model import CNNStrategyModel
|
||||||
|
|
||||||
|
__all__ = ['CNNStrategyModel']
|
||||||
129
src/ml/cnn/candlestick_encoder.py
Normal file
129
src/ml/cnn/candlestick_encoder.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
"""
Encode OHLCV candles into normalized sequences for a 1D CNN.

Transforms an OHLCV DataFrame into (N, seq_len, 5) tensors ready for the CNN.
Each sequence is normalized independently (sliding z-score) so that the model
learns relative patterns rather than absolute price levels.
"""

import logging
from typing import Optional

import numpy as np
import pandas as pd

logger = logging.getLogger(__name__)


class CandlestickEncoder:
    """
    Encodes raw OHLCV data into normalized sequences for the CNN.

    Per-window normalization guarantees that the CNN sees shape patterns
    (doji, engulfing, etc.) independently of the absolute price level.
    Volume is normalized separately (ratio vs. the window mean).

    Args:
        seq_len: Sequence length (number of candles per sample)
    """

    def __init__(self, seq_len: int = 64):
        self.seq_len = seq_len

    def encode(self, df_ohlcv: pd.DataFrame, seq_len: Optional[int] = None) -> np.ndarray:
        """
        Encode every sliding window from the OHLCV DataFrame.

        Args:
            df_ohlcv: DataFrame with columns open, high, low, close, volume
            seq_len: Optional override of the sequence length

        Returns:
            np.ndarray of shape (N_samples, seq_len, 5)
            Channels: [open, high, low, close, volume]
        """
        seq_len = seq_len or self.seq_len
        df = self._prepare_df(df_ohlcv)

        # One extra bar beyond seq_len is required here (unlike encode_last):
        # training labels need a bar following each full window.
        if len(df) < seq_len + 1:
            logger.warning(f"Pas assez de données : {len(df)} barres < {seq_len + 1} minimum")
            return np.empty((0, seq_len, 5))

        n_samples = len(df) - seq_len + 1
        sequences = np.zeros((n_samples, seq_len, 5), dtype=np.float32)

        for i in range(n_samples):
            window = df.iloc[i:i + seq_len]
            sequences[i] = self._normalize_window(window)

        # Drop any sequence containing NaN (bad source data survives z-scoring)
        valid_mask = ~np.isnan(sequences).any(axis=(1, 2))
        sequences = sequences[valid_mask]

        logger.info(f"Encodage : {len(sequences)} séquences de {seq_len} bougies")
        return sequences

    def encode_last(self, df_ohlcv: pd.DataFrame, seq_len: Optional[int] = None) -> np.ndarray:
        """
        Encode only the most recent window (for real-time prediction).

        Args:
            df_ohlcv: OHLCV DataFrame (at least seq_len bars)
            seq_len: Optional override of the sequence length

        Returns:
            np.ndarray of shape (1, seq_len, 5)
        """
        seq_len = seq_len or self.seq_len
        df = self._prepare_df(df_ohlcv)

        if len(df) < seq_len:
            logger.warning(f"Pas assez de données pour encode_last : {len(df)} < {seq_len}")
            return np.empty((0, seq_len, 5))

        window = df.iloc[-seq_len:]
        normalized = self._normalize_window(window)
        return normalized.reshape(1, seq_len, 5)

    def _prepare_df(self, df: pd.DataFrame) -> pd.DataFrame:
        """Prepare the DataFrame: lowercase column names, keep OHLCV only.

        Raises:
            ValueError: if any of the five required columns is missing.
        """
        df = df.copy()
        df.columns = [c.lower() for c in df.columns]

        required = ['open', 'high', 'low', 'close', 'volume']
        missing = [c for c in required if c not in df.columns]
        if missing:
            raise ValueError(f"Colonnes manquantes : {missing}")

        return df[required].reset_index(drop=True)

    def _normalize_window(self, window: pd.DataFrame) -> np.ndarray:
        """
        Normalize one OHLCV window.

        Prices (OHLC): z-score over the window (mean and std of the close)
        Volume: ratio to the mean volume over the window
        """
        result = np.zeros((len(window), 5), dtype=np.float32)

        # Price normalization: z-score anchored on the close series
        close_values = window['close'].values.astype(np.float64)
        mean_price = close_values.mean()
        std_price = close_values.std()

        if std_price < 1e-10:
            # Flat close: fall back to unit scale so OHLC channels become
            # plain offsets from the mean instead of dividing by ~zero.
            std_price = 1.0

        result[:, 0] = (window['open'].values - mean_price) / std_price
        result[:, 1] = (window['high'].values - mean_price) / std_price
        result[:, 2] = (window['low'].values - mean_price) / std_price
        result[:, 3] = (window['close'].values - mean_price) / std_price

        # Volume normalization: ratio vs. the window mean (0 when no volume)
        vol_values = window['volume'].values.astype(np.float64)
        mean_vol = vol_values.mean()
        if mean_vol < 1e-10:
            result[:, 4] = 0.0
        else:
            result[:, 4] = vol_values / mean_vol

        return result
|
||||||
113
src/ml/cnn/cnn_model.py
Normal file
113
src/ml/cnn/cnn_model.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
"""
1D CNN for pattern detection in OHLCV candles.

A stack of three Conv1d → BatchNorm → ReLU → Pool blocks, followed by a
linear classifier predicting LONG / SHORT / NEUTRAL.

Designed to capture visual patterns (doji, engulfing, head & shoulders...)
that the network learns directly from the normalized raw data, without
pre-computed features (unlike XGBoost).
"""

import logging

logger = logging.getLogger(__name__)

# Torch is optional at import time: the rest of the API must still start
# without it, so the class is only defined when the import succeeds.
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    logger.warning("PyTorch non disponible — le CNN ne peut pas être utilisé")


if TORCH_AVAILABLE:
    class TradingCNN(nn.Module):
        """
        1D CNN for OHLCV sequence classification.

        Architecture:
            Input (batch, seq_len=64, 5)
            → Permute → (batch, 5, 64)
            → Conv1d(5→32, k=3) + BN + ReLU + MaxPool(2) → (batch, 32, 32)
            → Conv1d(32→64, k=3) + BN + ReLU + MaxPool(2) → (batch, 64, 16)
            → Conv1d(64→128, k=3) + BN + ReLU + AdaptiveAvgPool(1) → (batch, 128, 1)
            → Flatten → Linear(128→64) + ReLU + Dropout(0.3)
            → Linear(64→3): logits [LONG, SHORT, NEUTRAL]

        Args:
            n_features: Number of input channels (5 = OHLCV)
            n_classes: Number of output classes (3)
            dropout: Dropout rate in the classifier head
        """

        def __init__(self, n_features: int = 5, n_classes: int = 3, dropout: float = 0.3):
            super().__init__()

            # Convolutional layers.
            # NOTE: attribute names (conv1, bn1, ...) are state_dict keys used
            # by saved checkpoints — renaming them would break loading.
            self.conv1 = nn.Conv1d(n_features, 32, kernel_size=3, padding=1)
            self.bn1 = nn.BatchNorm1d(32)
            self.pool1 = nn.MaxPool1d(2)

            self.conv2 = nn.Conv1d(32, 64, kernel_size=3, padding=1)
            self.bn2 = nn.BatchNorm1d(64)
            self.pool2 = nn.MaxPool1d(2)

            self.conv3 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
            self.bn3 = nn.BatchNorm1d(128)
            self.adaptive_pool = nn.AdaptiveAvgPool1d(1)

            # Classifier head
            self.fc1 = nn.Linear(128, 64)
            self.dropout = nn.Dropout(dropout)
            self.fc2 = nn.Linear(64, n_classes)

        def forward(self, x: 'torch.Tensor') -> 'torch.Tensor':
            """
            Forward pass.

            Args:
                x: Tensor of shape (batch, seq_len, n_features)

            Returns:
                Logits of shape (batch, n_classes)
            """
            # Permute for Conv1d: (batch, n_features, seq_len)
            x = x.permute(0, 2, 1)

            # Block 1
            x = self.pool1(F.relu(self.bn1(self.conv1(x))))
            # Block 2
            x = self.pool2(F.relu(self.bn2(self.conv2(x))))
            # Block 3
            x = self.adaptive_pool(F.relu(self.bn3(self.conv3(x))))

            # Flatten + classifier
            x = x.squeeze(-1)  # (batch, 128)
            x = F.relu(self.fc1(x))
            x = self.dropout(x)
            x = self.fc2(x)

            return x

        def predict_proba(self, x: 'torch.Tensor') -> 'torch.Tensor':
            """
            Return per-class probabilities via Softmax.

            Args:
                x: Tensor of shape (batch, seq_len, n_features)

            Returns:
                Probabilities of shape (batch, n_classes)
            """
            logits = self.forward(x)
            return F.softmax(logits, dim=1)

else:
    # Placeholder when PyTorch is unavailable: constructing it fails loudly
    # instead of producing an unusable half-initialized model.
    class TradingCNN:
        """Placeholder — PyTorch not available."""
        def __init__(self, *args, **kwargs):
            raise RuntimeError("PyTorch non disponible — impossible de créer TradingCNN")
|
||||||
585
src/ml/cnn/cnn_strategy_model.py
Normal file
585
src/ml/cnn/cnn_strategy_model.py
Normal file
@@ -0,0 +1,585 @@
|
|||||||
|
"""
|
||||||
|
CNN Strategy Model — Modèle CNN 1D qui apprend des patterns de bougies.
|
||||||
|
|
||||||
|
Contrairement à MLStrategyModel (XGBoost sur features TA pré-calculées),
|
||||||
|
ce modèle travaille directement sur les séquences OHLCV brutes normalisées.
|
||||||
|
Le CNN détecte lui-même les patterns visuels pertinents (doji, engulfing, etc.).
|
||||||
|
|
||||||
|
Pipeline :
|
||||||
|
1. Chargement données OHLCV
|
||||||
|
2. Encodage séquences (CandlestickEncoder : z-score glissant)
|
||||||
|
3. Génération labels (LabelGenerator — partagé avec MLStrategyModel)
|
||||||
|
4. Entraînement CNN (PyTorch, Adam, CrossEntropy, early stopping)
|
||||||
|
5. Walk-forward validation (2 folds temporels)
|
||||||
|
6. Sauvegarde state_dict + métadonnées JSON
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
model = CNNStrategyModel(symbol='EURUSD', timeframe='1h')
|
||||||
|
result = model.train(df_ohlcv)
|
||||||
|
signal = model.predict(df_recent)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from src.ml.features.label_generator import LabelGenerator
|
||||||
|
from src.ml.cnn.candlestick_encoder import CandlestickEncoder
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Répertoire de sauvegarde des modèles CNN
|
||||||
|
MODELS_DIR = Path(__file__).parent.parent.parent.parent / "models" / "cnn_strategy"
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from torch.utils.data import TensorDataset, DataLoader
|
||||||
|
from src.ml.cnn.cnn_model import TradingCNN
|
||||||
|
TORCH_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
TORCH_AVAILABLE = False
|
||||||
|
logger.warning("PyTorch non disponible — CNNStrategyModel ne peut pas fonctionner")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from sklearn.metrics import precision_recall_fscore_support
|
||||||
|
SKLEARN_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
SKLEARN_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
# Mapping from CNN class indices to trading signals
# CNN: 0=LONG, 1=SHORT, 2=NEUTRAL
# Trading: 1=LONG, -1=SHORT, 0=NEUTRAL
CLASS_MAP = {0: 1, 1: -1, 2: 0}

# Inverse mapping: LabelGenerator labels -> CNN class indices
# LabelGenerator: 1=LONG, -1=SHORT, 0=NEUTRAL
LABEL_TO_INDEX = {1: 0, -1: 1, 0: 2}
|
||||||
|
|
||||||
|
|
||||||
|
class CNNStrategyModel:
|
||||||
|
"""
|
||||||
|
Modèle CNN qui apprend les patterns visuels des bougies.
|
||||||
|
|
||||||
|
Le modèle :
|
||||||
|
- Travaille sur les 64 dernières bougies OHLCV (données brutes normalisées)
|
||||||
|
- Prédit LONG (1) / SHORT (-1) / NEUTRAL (0)
|
||||||
|
- Donne un score de confiance [0..1] par prédiction
|
||||||
|
- Se sauvegarde sur disque (state_dict PyTorch + métadonnées JSON)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Paire tradée (ex: 'EURUSD')
|
||||||
|
timeframe: Timeframe (ex: '1h', '15m')
|
||||||
|
seq_len: Longueur des séquences d'entrée
|
||||||
|
min_confidence: Seuil de confiance pour signal tradeable
|
||||||
|
tp_atr_mult: Multiplicateur ATR pour TP (labels)
|
||||||
|
sl_atr_mult: Multiplicateur ATR pour SL (labels)
|
||||||
|
horizon: Nombre de barres pour évaluer TP/SL (labels)
|
||||||
|
"""
|
||||||
|
|
||||||
|
    def __init__(
        self,
        symbol: str = 'EURUSD',
        timeframe: str = '1h',
        seq_len: int = 64,
        min_confidence: float = 0.55,
        tp_atr_mult: float = 2.0,
        sl_atr_mult: float = 1.0,
        horizon: int = 30,
    ):
        """
        Initialize the CNN strategy model (untrained).

        Args:
            symbol: Traded pair (e.g. 'EURUSD')
            timeframe: Bar timeframe (e.g. '1h', '15m')
            seq_len: Input sequence length in candles
            min_confidence: Confidence threshold for a tradeable signal
            tp_atr_mult: ATR multiplier for take-profit (labeling)
            sl_atr_mult: ATR multiplier for stop-loss (labeling)
            horizon: Number of bars used to evaluate TP/SL (labeling)
        """
        self.symbol = symbol
        self.timeframe = timeframe
        self.model_type = 'cnn'  # identifies this model family in metadata/listings
        self.seq_len = seq_len
        self.min_confidence = min_confidence
        self.tp_atr_mult = tp_atr_mult
        self.sl_atr_mult = sl_atr_mult
        self.horizon = horizon

        # Populated by train() / load(); None / False until then
        self.model = None
        self.is_trained = False
        self.metadata: Dict = {}
        self.encoder = CandlestickEncoder(seq_len=seq_len)

        # Ensure the on-disk model directory exists up front
        MODELS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Entraînement
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
    def train(self, data: pd.DataFrame) -> Dict:
        """
        Train the CNN on OHLCV data.

        Uses the same labels as MLStrategyModel (LabelGenerator.generate_atr_based).
        Walk-forward validation over 2 temporal folds.
        The model is saved to disk automatically after training.

        Args:
            data: OHLCV DataFrame (at least 200 bars recommended)

        Returns:
            Dict with wf_metrics, label_dist, n_samples, trained_at —
            or {'error': ...} when torch is missing or data is insufficient.
        """
        if not TORCH_AVAILABLE:
            return {'error': 'PyTorch non disponible'}

        logger.info(f"Début entraînement CNNStrategyModel pour {self.symbol}/{self.timeframe}")
        logger.info(f" Données : {len(data)} barres, seq_len={self.seq_len}")

        # 1. Label generation (same method as MLStrategyModel)
        gen = LabelGenerator(horizon=self.horizon)
        labels_series = gen.generate_atr_based(
            data,
            atr_tp_mult=self.tp_atr_mult,
            atr_sl_mult=self.sl_atr_mult,
        )

        # 2. Encode OHLCV sequences
        sequences = self.encoder.encode(data, seq_len=self.seq_len)

        if len(sequences) == 0:
            return {'error': 'Pas assez de données pour encoder des séquences'}

        # 3. Align labels with sequences.
        # sequence[i] covers bars [i : i+seq_len]; its label is the one of the
        # last bar in the window (i + seq_len - 1).
        # NOTE(review): the encoder silently drops NaN sequences, which would
        # shift this positional alignment — assumes NaN-free input; verify.
        n_samples = len(sequences)
        label_indices = []
        for i in range(n_samples):
            target_idx = i + self.seq_len - 1
            if target_idx < len(labels_series):
                label_indices.append(target_idx)
            else:
                label_indices.append(None)

        # Keep only sequences whose label exists and is not within the final
        # `horizon` bars (those labels could not be fully resolved yet)
        valid_mask = []
        labels_aligned = []
        for i, idx in enumerate(label_indices):
            if idx is not None and idx < len(labels_series) - self.horizon:
                raw_label = labels_series.iloc[idx]
                labels_aligned.append(LABEL_TO_INDEX.get(raw_label, 2))
                valid_mask.append(i)

        if len(valid_mask) < 50:
            return {'error': f'Trop peu de données valides : {len(valid_mask)} échantillons'}

        X = sequences[valid_mask]
        y = np.array(labels_aligned, dtype=np.int64)

        logger.info(f" {len(X)} échantillons après alignement")
        n_long = (y == 0).sum()
        n_short = (y == 1).sum()
        n_neutral = (y == 2).sum()
        logger.info(f" Distribution : LONG={n_long}, SHORT={n_short}, NEUTRAL={n_neutral}")

        # 4. Walk-forward validation (2 folds)
        wf_metrics = self._walk_forward_eval(X, y, n_folds=2)

        # 5. Final training on the full dataset
        self.model = TradingCNN(n_features=5, n_classes=3)
        class_weights = self._compute_class_weights(y)
        self._train_model(self.model, X, y, class_weights, max_epochs=100, patience=10)
        self.is_trained = True

        # 6. Metadata (also persisted as JSON next to the weights)
        self.metadata = {
            'symbol': self.symbol,
            'timeframe': self.timeframe,
            'model_type': self.model_type,
            'trained_at': datetime.utcnow().isoformat(),
            'n_samples': len(X),
            'seq_len': self.seq_len,
            'tp_atr_mult': self.tp_atr_mult,
            'sl_atr_mult': self.sl_atr_mult,
            'horizon': self.horizon,
            'label_dist': {
                'long': int(n_long),
                'short': int(n_short),
                'neutral': int(n_neutral),
            },
            'wf_metrics': wf_metrics,
        }

        # 7. Persist to disk
        self.save()

        logger.info(f"Entraînement CNN terminé. WF accuracy={wf_metrics.get('avg_accuracy', 0):.2%}")
        return self.metadata
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Prédiction
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
    def predict(self, data: pd.DataFrame) -> Dict:
        """
        Predict the signal for the latest bars.

        Args:
            data: Recent OHLCV DataFrame (at least seq_len bars)

        Returns:
            Dict: {
                'signal': 1 (LONG) / -1 (SHORT) / 0 (NEUTRAL),
                'confidence': float [0..1],
                'probas': {'long': float, 'short': float, 'neutral': float},
                'tradeable': bool (confidence >= min_confidence and signal != 0)
            }
            On any failure, returns a neutral non-tradeable dict with 'error'.
        """
        # Guard clauses: untrained model or missing torch → neutral, not tradeable
        if not self.is_trained or self.model is None:
            return {'signal': 0, 'confidence': 0.0, 'tradeable': False, 'error': 'Modèle non entraîné'}

        if not TORCH_AVAILABLE:
            return {'signal': 0, 'confidence': 0.0, 'tradeable': False, 'error': 'PyTorch non disponible'}

        try:
            # Encode only the most recent window
            seq = self.encoder.encode_last(data, seq_len=self.seq_len)
            if len(seq) == 0:
                return {'signal': 0, 'confidence': 0.0, 'tradeable': False, 'error': 'Données insuffisantes'}

            # Inference: eval mode disables dropout / batch-norm updates
            self.model.eval()
            x_tensor = torch.FloatTensor(seq)

            with torch.no_grad():
                probas_tensor = self.model.predict_proba(x_tensor)
                probas_np = probas_tensor.numpy()[0]

            # Map the argmax class index (0/1/2) to a trading signal (1/-1/0)
            pred_idx = int(np.argmax(probas_np))
            signal = CLASS_MAP[pred_idx]
            confidence = float(probas_np[pred_idx])

            probas = {
                'long': float(probas_np[0]),
                'short': float(probas_np[1]),
                'neutral': float(probas_np[2]),
            }

            # Tradeable only when confident enough AND directional
            tradeable = confidence >= self.min_confidence and signal != 0

            return {
                'signal': signal,
                'confidence': confidence,
                'probas': probas,
                'tradeable': tradeable,
            }

        except Exception as e:
            # Fail closed: any inference error yields a neutral, non-tradeable signal
            logger.error(f"Erreur prédiction CNN : {e}")
            return {'signal': 0, 'confidence': 0.0, 'tradeable': False, 'error': str(e)}
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Sauvegarde / Chargement
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
    def save(self) -> None:
        """
        Persist the PyTorch state_dict plus JSON metadata under MODELS_DIR.

        Raises:
            RuntimeError: if the model is not trained (or torch is missing).
        """
        if not TORCH_AVAILABLE or not self.is_trained or self.model is None:
            raise RuntimeError("Modèle non entraîné")

        MODELS_DIR.mkdir(parents=True, exist_ok=True)

        # One file pair per (symbol, timeframe): weights (.pt) + metadata (.json)
        model_id = f"{self.symbol}_{self.timeframe}_cnn"
        model_path = MODELS_DIR / f"{model_id}.pt"
        meta_path = MODELS_DIR / f"{model_id}_meta.json"

        # Save the state_dict together with the constructor config so load()
        # can rebuild an identical instance.
        torch.save({
            'state_dict': self.model.state_dict(),
            'config': {
                'symbol': self.symbol,
                'timeframe': self.timeframe,
                'seq_len': self.seq_len,
                'min_confidence': self.min_confidence,
                'tp_atr_mult': self.tp_atr_mult,
                'sl_atr_mult': self.sl_atr_mult,
                'horizon': self.horizon,
            },
        }, model_path)

        # Save JSON metadata (default=str stringifies datetimes etc.)
        with open(meta_path, 'w') as f:
            json.dump(self.metadata, f, indent=2, default=str)

        logger.info(f"Modèle CNN sauvegardé : {model_path}")
|
||||||
|
|
||||||
|
    @classmethod
    def load(cls, symbol: str, timeframe: str) -> 'CNNStrategyModel':
        """
        Load a CNN model from disk.

        Args:
            symbol: Pair (e.g. 'EURUSD')
            timeframe: Timeframe (e.g. '1h')

        Returns:
            CNNStrategyModel instance ready for prediction

        Raises:
            FileNotFoundError: if the model does not exist
            RuntimeError: if PyTorch is not available
        """
        if not TORCH_AVAILABLE:
            raise RuntimeError("PyTorch non disponible")

        model_id = f"{symbol}_{timeframe}_cnn"
        model_path = MODELS_DIR / f"{model_id}.pt"
        meta_path = MODELS_DIR / f"{model_id}_meta.json"

        if not model_path.exists():
            raise FileNotFoundError(f"Modèle CNN non trouvé : {model_path}")

        # Load the checkpoint.
        # NOTE(review): weights_only=False unpickles arbitrary objects — only
        # acceptable because these checkpoints are written locally by save().
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
        cfg = checkpoint.get('config', {})

        # Rebuild the instance from the saved config (arguments as fallback)
        instance = cls(
            symbol=cfg.get('symbol', symbol),
            timeframe=cfg.get('timeframe', timeframe),
            seq_len=cfg.get('seq_len', 64),
            min_confidence=cfg.get('min_confidence', 0.55),
            tp_atr_mult=cfg.get('tp_atr_mult', 2.0),
            sl_atr_mult=cfg.get('sl_atr_mult', 1.0),
            horizon=cfg.get('horizon', 30),
        )

        # Rebuild the network and load the weights
        instance.model = TradingCNN(n_features=5, n_classes=3)
        instance.model.load_state_dict(checkpoint['state_dict'])
        instance.model.eval()
        instance.is_trained = True

        # Load metadata when available (best-effort: missing/corrupt is fine)
        if meta_path.exists():
            try:
                with open(meta_path) as f:
                    instance.metadata = json.load(f)
            except Exception:
                pass

        logger.info(f"Modèle CNN chargé depuis {model_path}")
        return instance
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_trained_models(cls) -> List[Dict]:
|
||||||
|
"""Liste les modèles CNN entraînés disponibles."""
|
||||||
|
if not MODELS_DIR.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
models = []
|
||||||
|
for f in MODELS_DIR.glob("*_meta.json"):
|
||||||
|
try:
|
||||||
|
with open(f) as fp:
|
||||||
|
meta = json.load(fp)
|
||||||
|
models.append({
|
||||||
|
'symbol': meta.get('symbol', '?'),
|
||||||
|
'timeframe': meta.get('timeframe', '?'),
|
||||||
|
'model_type': 'cnn',
|
||||||
|
'trained_at': meta.get('trained_at', '?'),
|
||||||
|
'n_samples': meta.get('n_samples', 0),
|
||||||
|
'wf_accuracy': meta.get('wf_metrics', {}).get('avg_accuracy', 0),
|
||||||
|
})
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return models
|
||||||
|
|
||||||
|
def get_feature_importance(self, top_n: int = 10) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Pour CNN : pas de feature importance classique.
|
||||||
|
|
||||||
|
Retourne une liste vide — les CNN n'ont pas d'importance par feature
|
||||||
|
au sens des arbres de décision. Une analyse par gradient (GradCAM)
|
||||||
|
serait possible mais hors scope pour l'instant.
|
||||||
|
"""
|
||||||
|
return []
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Walk-forward évaluation
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
    def _walk_forward_eval(self, X: np.ndarray, y: np.ndarray, n_folds: int = 2) -> Dict:
        """
        Evaluate the CNN with temporal walk-forward validation.

        Expanding-window split: with fold_size = n // (n_folds + 1), fold k
        trains on the first (k+1) * fold_size samples and tests on the next
        fold_size samples — training data always precedes test data.

        Args:
            X: Sequences (N, seq_len, 5)
            y: Class indices (N,)
            n_folds: Number of temporal folds

        Returns:
            Dict with avg_accuracy, avg_precision, avg_recall, fold_accuracies
            (all zeros / empty list when every fold was skipped)
        """
        n = len(X)
        fold_size = n // (n_folds + 1)
        accuracies, precisions, recalls = [], [], []

        for fold in range(n_folds):
            train_end = fold_size * (fold + 1)
            test_end = train_end + fold_size

            if test_end > n:
                break

            X_tr, y_tr = X[:train_end], y[:train_end]
            X_te, y_te = X[train_end:test_end], y[train_end:test_end]

            # Skip folds too small to train/evaluate meaningfully
            if len(X_tr) < 30 or len(X_te) < 10:
                logger.warning(f" Fold {fold + 1} ignoré : pas assez de données")
                continue

            # Train a throwaway model on this fold's training slice
            model = TradingCNN(n_features=5, n_classes=3)
            class_weights = self._compute_class_weights(y_tr)
            self._train_model(model, X_tr, y_tr, class_weights, max_epochs=50, patience=5)

            # Evaluate on the out-of-sample slice
            model.eval()
            with torch.no_grad():
                x_tensor = torch.FloatTensor(X_te)
                logits = model(x_tensor)
                y_pred = logits.argmax(dim=1).numpy()

            acc = (y_pred == y_te).mean()

            if SKLEARN_AVAILABLE:
                prec, rec, _, _ = precision_recall_fscore_support(
                    y_te, y_pred, average='macro', zero_division=0
                )
            else:
                # sklearn missing: report accuracy only
                prec, rec = 0.0, 0.0

            accuracies.append(acc)
            precisions.append(prec)
            recalls.append(rec)
            logger.info(f" Fold {fold + 1}/{n_folds} : acc={acc:.2%}, prec={prec:.2%}, rec={rec:.2%}")

        if not accuracies:
            return {'avg_accuracy': 0.0, 'avg_precision': 0.0, 'avg_recall': 0.0, 'fold_accuracies': []}

        return {
            'avg_accuracy': float(np.mean(accuracies)),
            'avg_precision': float(np.mean(precisions)),
            'avg_recall': float(np.mean(recalls)),
            'fold_accuracies': [float(a) for a in accuracies],
        }
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Entraînement interne
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
def _train_model(
|
||||||
|
self,
|
||||||
|
model: 'nn.Module',
|
||||||
|
X: np.ndarray,
|
||||||
|
y: np.ndarray,
|
||||||
|
class_weights: 'torch.Tensor' = None,
|
||||||
|
max_epochs: int = 100,
|
||||||
|
patience: int = 10,
|
||||||
|
lr: float = 1e-3,
|
||||||
|
batch_size: int = 64,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Boucle d'entraînement PyTorch avec early stopping.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Instance TradingCNN
|
||||||
|
X: Séquences (N, seq_len, 5)
|
||||||
|
y: Labels (N,) — indices 0, 1, 2
|
||||||
|
class_weights: Poids par classe pour CrossEntropy
|
||||||
|
max_epochs: Nombre max d'époques
|
||||||
|
patience: Nombre d'époques sans amélioration avant arrêt
|
||||||
|
lr: Learning rate
|
||||||
|
batch_size: Taille des batchs
|
||||||
|
"""
|
||||||
|
# Séparer un petit set de validation (derniers 15%)
|
||||||
|
val_size = max(int(len(X) * 0.15), 10)
|
||||||
|
X_train, X_val = X[:-val_size], X[-val_size:]
|
||||||
|
y_train, y_val = y[:-val_size], y[-val_size:]
|
||||||
|
|
||||||
|
# Tenseurs
|
||||||
|
X_train_t = torch.FloatTensor(X_train)
|
||||||
|
y_train_t = torch.LongTensor(y_train)
|
||||||
|
X_val_t = torch.FloatTensor(X_val)
|
||||||
|
y_val_t = torch.LongTensor(y_val)
|
||||||
|
|
||||||
|
dataset = TensorDataset(X_train_t, y_train_t)
|
||||||
|
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
||||||
|
|
||||||
|
# Loss avec poids de classe
|
||||||
|
if class_weights is not None:
|
||||||
|
criterion = nn.CrossEntropyLoss(weight=class_weights)
|
||||||
|
else:
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
|
optimizer = optim.Adam(model.parameters(), lr=lr)
|
||||||
|
|
||||||
|
# Early stopping
|
||||||
|
best_val_loss = float('inf')
|
||||||
|
best_state = None
|
||||||
|
epochs_without_improvement = 0
|
||||||
|
|
||||||
|
model.train()
|
||||||
|
for epoch in range(max_epochs):
|
||||||
|
# Phase entraînement
|
||||||
|
total_loss = 0.0
|
||||||
|
n_batches = 0
|
||||||
|
for x_batch, y_batch in loader:
|
||||||
|
optimizer.zero_grad()
|
||||||
|
logits = model(x_batch)
|
||||||
|
loss = criterion(logits, y_batch)
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
total_loss += loss.item()
|
||||||
|
n_batches += 1
|
||||||
|
|
||||||
|
# Phase validation
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
val_logits = model(X_val_t)
|
||||||
|
val_loss = criterion(val_logits, y_val_t).item()
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
# Early stopping
|
||||||
|
if val_loss < best_val_loss:
|
||||||
|
best_val_loss = val_loss
|
||||||
|
best_state = {k: v.clone() for k, v in model.state_dict().items()}
|
||||||
|
epochs_without_improvement = 0
|
||||||
|
else:
|
||||||
|
epochs_without_improvement += 1
|
||||||
|
|
||||||
|
if epochs_without_improvement >= patience:
|
||||||
|
logger.info(f" Early stopping à l'époque {epoch + 1} (patience={patience})")
|
||||||
|
break
|
||||||
|
|
||||||
|
if (epoch + 1) % 20 == 0:
|
||||||
|
avg_loss = total_loss / max(n_batches, 1)
|
||||||
|
logger.info(f" Époque {epoch + 1}/{max_epochs} — loss={avg_loss:.4f}, val_loss={val_loss:.4f}")
|
||||||
|
|
||||||
|
# Restaurer les meilleurs poids
|
||||||
|
if best_state is not None:
|
||||||
|
model.load_state_dict(best_state)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Utilitaires
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
@staticmethod
|
||||||
|
def _compute_class_weights(y: np.ndarray) -> 'torch.Tensor':
|
||||||
|
"""
|
||||||
|
Calcule les poids inversement proportionnels à la fréquence des classes.
|
||||||
|
|
||||||
|
Permet de compenser le déséquilibre (ex: trop de NEUTRAL).
|
||||||
|
"""
|
||||||
|
if not TORCH_AVAILABLE:
|
||||||
|
return None
|
||||||
|
|
||||||
|
classes, counts = np.unique(y, return_counts=True)
|
||||||
|
n_samples = len(y)
|
||||||
|
n_classes = 3
|
||||||
|
|
||||||
|
weights = np.ones(n_classes, dtype=np.float32)
|
||||||
|
for cls, count in zip(classes, counts):
|
||||||
|
if cls < n_classes:
|
||||||
|
weights[int(cls)] = n_samples / (n_classes * count)
|
||||||
|
|
||||||
|
return torch.FloatTensor(weights)
|
||||||
11
src/ml/cnn_image/__init__.py
Normal file
11
src/ml/cnn_image/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
"""
|
||||||
|
Module CNN Image-Based — Analyse visuelle de graphiques en chandeliers.
|
||||||
|
|
||||||
|
Ce module implémente un CNN Conv2D qui analyse des graphiques de chandeliers
|
||||||
|
rendus comme de vraies images (128×128 RGB), pour reconnaître les patterns
|
||||||
|
visuels exactement comme un trader humain devant TradingView.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.ml.cnn_image.cnn_image_strategy_model import CNNImageStrategyModel
|
||||||
|
|
||||||
|
__all__ = ['CNNImageStrategyModel']
|
||||||
472
src/ml/cnn_image/chart_renderer.py
Normal file
472
src/ml/cnn_image/chart_renderer.py
Normal file
@@ -0,0 +1,472 @@
|
|||||||
|
"""
|
||||||
|
CandlestickImageRenderer — Convertit des données OHLCV en images de graphiques.
|
||||||
|
|
||||||
|
Ce module transforme des séquences de bougies OHLCV en images 128×128 RGB
|
||||||
|
qui peuvent être passées à un CNN Conv2D pour l'analyse visuelle des patterns.
|
||||||
|
|
||||||
|
Rendu (pur numpy, très rapide) :
|
||||||
|
- Fond noir (#0d1117), style TradingView
|
||||||
|
- Bougies vertes (#26a69a) pour la hausse, rouges (#ef5350) pour la baisse
|
||||||
|
- Mèches (high/low), corps (open/close), volume en bas (20% de hauteur)
|
||||||
|
- Pas d'axes, pas de labels, pas de titre
|
||||||
|
- Taille fixe : 128×128 pixels, 3 canaux RGB
|
||||||
|
|
||||||
|
Perf : ~5 000 images/s (vs ~5/s avec mplfinance).
|
||||||
|
mplfinance reste utilisable via _render_with_mplfinance() pour l'affichage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Couleurs TradingView (normalisées [0, 1])
|
||||||
|
BG_COLOR = np.array([13 / 255, 17 / 255, 23 / 255], dtype=np.float32) # #0d1117
|
||||||
|
GREEN_COLOR = np.array([38 / 255, 166 / 255, 154 / 255], dtype=np.float32) # #26a69a
|
||||||
|
RED_COLOR = np.array([239 / 255, 83 / 255, 80 / 255], dtype=np.float32) # #ef5350
|
||||||
|
|
||||||
|
IMAGE_SIZE = 128
|
||||||
|
VOLUME_RATIO = 0.20 # 20% de la hauteur pour le volume
|
||||||
|
|
||||||
|
# --- Détection optionnelle de mplfinance (pour affichage uniquement) ---
|
||||||
|
try:
|
||||||
|
import mplfinance as mpf
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use('Agg')
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import io
|
||||||
|
MPLFINANCE_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
MPLFINANCE_AVAILABLE = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
from PIL import Image
|
||||||
|
PIL_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
PIL_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
class CandlestickImageRenderer:
|
||||||
|
"""
|
||||||
|
Convertit des données OHLCV en images de graphiques en chandeliers.
|
||||||
|
|
||||||
|
Chaque image est un instantané visuel de `seq_len` bougies consécutives,
|
||||||
|
rendu via pur numpy (rapide, ~5 000 images/s) ou mplfinance (lent, haute qualité).
|
||||||
|
Le résultat est normalisé en float32 dans [0, 1].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_size: Taille carrée de l'image en pixels (défaut 128)
|
||||||
|
use_mplfinance: Forcer mplfinance même pour encode() (lent, déconseillé)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, image_size: int = IMAGE_SIZE, use_mplfinance: bool = False):
|
||||||
|
self.image_size = image_size
|
||||||
|
self.use_mplfinance = use_mplfinance and MPLFINANCE_AVAILABLE and PIL_AVAILABLE
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Interface publique
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def encode(self, df: pd.DataFrame, seq_len: int = 64) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Encode toutes les fenêtres glissantes du DataFrame en images.
|
||||||
|
|
||||||
|
Produit N = len(df) - seq_len fenêtres, chacune rendue en image
|
||||||
|
128×128 RGB normalisée [0, 1].
|
||||||
|
|
||||||
|
Utilise le renderer pur numpy (rapide) par défaut.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: DataFrame OHLCV avec colonnes open/high/low/close/volume
|
||||||
|
seq_len: Nombre de bougies par fenêtre (défaut 64)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray de forme (N, 3, 128, 128), dtype float32, valeurs [0, 1]
|
||||||
|
Retourne un tableau vide (0, 3, 128, 128) si df est trop court.
|
||||||
|
"""
|
||||||
|
df = self._prepare_df(df)
|
||||||
|
n_windows = len(df) - seq_len
|
||||||
|
|
||||||
|
if n_windows <= 0:
|
||||||
|
logger.warning(
|
||||||
|
f"DataFrame trop court ({len(df)} barres) pour seq_len={seq_len}"
|
||||||
|
)
|
||||||
|
return np.zeros((0, 3, self.image_size, self.image_size), dtype=np.float32)
|
||||||
|
|
||||||
|
# Extraction des valeurs OHLCV en numpy (une seule conversion)
|
||||||
|
ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values.astype(np.float32)
|
||||||
|
|
||||||
|
if self.use_mplfinance:
|
||||||
|
return self._encode_mplfinance(df, seq_len, n_windows)
|
||||||
|
else:
|
||||||
|
return self._encode_numpy(ohlcv, seq_len, n_windows)
|
||||||
|
|
||||||
|
def encode_last(self, df: pd.DataFrame, seq_len: int = 64) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Encode uniquement la dernière fenêtre du DataFrame.
|
||||||
|
|
||||||
|
Utilisé pour la prédiction en temps réel : retourne (1, 3, 128, 128).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: DataFrame OHLCV
|
||||||
|
seq_len: Nombre de bougies pour la dernière fenêtre (défaut 64)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray de forme (1, 3, 128, 128), dtype float32, valeurs [0, 1]
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError si df est trop court pour seq_len bougies
|
||||||
|
"""
|
||||||
|
df = self._prepare_df(df)
|
||||||
|
|
||||||
|
if len(df) < seq_len:
|
||||||
|
raise ValueError(
|
||||||
|
f"DataFrame insuffisant : {len(df)} barres < seq_len={seq_len}"
|
||||||
|
)
|
||||||
|
|
||||||
|
window = df.iloc[-seq_len:]
|
||||||
|
ohlcv = window[['open', 'high', 'low', 'close', 'volume']].values.astype(np.float32)
|
||||||
|
img = self._render_numpy(ohlcv)
|
||||||
|
return img[np.newaxis, ...] # (1, 3, H, W)
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Renderer pur numpy (rapide)
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _encode_numpy(
|
||||||
|
self,
|
||||||
|
ohlcv: np.ndarray,
|
||||||
|
seq_len: int,
|
||||||
|
n_windows: int,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Encode toutes les fenêtres en batch vectorisé (sans boucle Python sur N).
|
||||||
|
|
||||||
|
Utilise sliding_window_view pour créer toutes les fenêtres d'un coup,
|
||||||
|
puis boucle sur les seq_len bougies (64) en traitant toutes les N fenêtres
|
||||||
|
simultanément. Numpy libère le GIL pendant les opérations vectorisées,
|
||||||
|
ce qui préserve la réactivité de l'event loop FastAPI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ohlcv: (T, 5) float32 [open, high, low, close, volume]
|
||||||
|
seq_len: Longueur de chaque fenêtre
|
||||||
|
n_windows: Nombre total de fenêtres à générer
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(N, 3, H, W) float32
|
||||||
|
"""
|
||||||
|
H = W = self.image_size
|
||||||
|
|
||||||
|
# --- Toutes les fenêtres d'un coup : (N, seq_len, 5) ---
|
||||||
|
# sliding_window_view est O(1) en mémoire (vue sans copie)
|
||||||
|
windows = np.lib.stride_tricks.sliding_window_view(ohlcv, (seq_len, 5))
|
||||||
|
windows = windows[:n_windows, 0, :, :].astype(np.float32) # (N, seq_len, 5)
|
||||||
|
|
||||||
|
opens = windows[:, :, 0] # (N, seq_len)
|
||||||
|
highs = windows[:, :, 1]
|
||||||
|
lows = windows[:, :, 2]
|
||||||
|
closes = windows[:, :, 3]
|
||||||
|
vols = windows[:, :, 4]
|
||||||
|
|
||||||
|
# --- Fond de toutes les images ---
|
||||||
|
images = np.empty((n_windows, 3, H, W), dtype=np.float32)
|
||||||
|
images[:, 0, :, :] = BG_COLOR[0]
|
||||||
|
images[:, 1, :, :] = BG_COLOR[1]
|
||||||
|
images[:, 2, :, :] = BG_COLOR[2]
|
||||||
|
|
||||||
|
# --- Normalisation des prix par fenêtre ---
|
||||||
|
price_min = lows.min(axis=1, keepdims=True) # (N, 1)
|
||||||
|
price_max = highs.max(axis=1, keepdims=True) # (N, 1)
|
||||||
|
price_rng = np.maximum(price_max - price_min, 1e-8)
|
||||||
|
|
||||||
|
vol_h = max(1, int(H * VOLUME_RATIO))
|
||||||
|
chart_h = H - vol_h
|
||||||
|
|
||||||
|
def to_row(prices: np.ndarray) -> np.ndarray:
|
||||||
|
"""Convertit prix → rangée pixel (0 = haut, chart_h-1 = bas)."""
|
||||||
|
norm = (prices - price_min) / price_rng
|
||||||
|
rows = ((1.0 - norm) * (chart_h - 1)).astype(np.int32)
|
||||||
|
return np.clip(rows, 0, chart_h - 1)
|
||||||
|
|
||||||
|
row_h = to_row(highs) # (N, seq_len)
|
||||||
|
row_l = to_row(lows)
|
||||||
|
row_o = to_row(opens)
|
||||||
|
row_c = to_row(closes)
|
||||||
|
|
||||||
|
body_top = np.minimum(row_o, row_c) # (N, seq_len)
|
||||||
|
body_bot = np.maximum(row_o, row_c)
|
||||||
|
|
||||||
|
is_bull = (closes >= opens) # (N, seq_len) bool
|
||||||
|
|
||||||
|
# Volume normalisé [0, 1] par fenêtre
|
||||||
|
vol_max = np.maximum(vols.max(axis=1, keepdims=True), 1e-8)
|
||||||
|
vol_norm = vols / vol_max # (N, seq_len)
|
||||||
|
|
||||||
|
# Positions X des bougies
|
||||||
|
candle_w = max(1, W // seq_len)
|
||||||
|
half_w = max(1, candle_w // 2)
|
||||||
|
x_centers = ((np.arange(seq_len) + 0.5) * W / seq_len).astype(np.int32)
|
||||||
|
|
||||||
|
row_idx = np.arange(chart_h) # (chart_h,) pour les masques
|
||||||
|
vol_idx = np.arange(vol_h) # (vol_h,)
|
||||||
|
|
||||||
|
# --- Boucle sur les bougies (64 itérations, GIL libéré entre chaque) ---
|
||||||
|
for i in range(seq_len):
|
||||||
|
x_c = int(x_centers[i])
|
||||||
|
x0 = max(0, x_c - half_w)
|
||||||
|
x1 = min(W, x_c + half_w + 1)
|
||||||
|
wick_x = min(x_c, W - 1)
|
||||||
|
|
||||||
|
rh = row_h[:, i] # (N,)
|
||||||
|
rl = row_l[:, i]
|
||||||
|
bt = body_top[:, i]
|
||||||
|
bb = body_bot[:, i]
|
||||||
|
bull = is_bull[:, i] # (N,) bool
|
||||||
|
|
||||||
|
# Masques vectorisés sur toutes les N fenêtres
|
||||||
|
# wick_mask[w, r] = True si rh[w] <= r <= rl[w]
|
||||||
|
wick_mask = (
|
||||||
|
(row_idx[None, :] >= rh[:, None]) &
|
||||||
|
(row_idx[None, :] <= rl[:, None])
|
||||||
|
) # (N, chart_h)
|
||||||
|
|
||||||
|
body_mask = (
|
||||||
|
(row_idx[None, :] >= bt[:, None]) &
|
||||||
|
(row_idx[None, :] <= bb[:, None])
|
||||||
|
) # (N, chart_h)
|
||||||
|
|
||||||
|
# Volume : barre du bas
|
||||||
|
vol_bar_h = (vol_norm[:, i] * vol_h).astype(np.int32)
|
||||||
|
vol_thresh = np.maximum(vol_h - vol_bar_h, 0)
|
||||||
|
vol_mask = vol_idx[None, :] >= vol_thresh[:, None] # (N, vol_h)
|
||||||
|
|
||||||
|
for c in range(3):
|
||||||
|
g = float(GREEN_COLOR[c])
|
||||||
|
r = float(RED_COLOR[c])
|
||||||
|
# Couleur par fenêtre : vert si haussière, rouge sinon
|
||||||
|
color_w = np.where(bull, g, r).astype(np.float32) # (N,)
|
||||||
|
|
||||||
|
# Mèche
|
||||||
|
images[:, c, :chart_h, wick_x] = np.where(
|
||||||
|
wick_mask,
|
||||||
|
color_w[:, None],
|
||||||
|
images[:, c, :chart_h, wick_x],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Corps
|
||||||
|
images[:, c, :chart_h, x0:x1] = np.where(
|
||||||
|
body_mask[:, :, None],
|
||||||
|
color_w[:, None, None],
|
||||||
|
images[:, c, :chart_h, x0:x1],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Volume (opacité 60%)
|
||||||
|
images[:, c, H - vol_h:H, x0:x1] = np.where(
|
||||||
|
vol_mask[:, :, None],
|
||||||
|
color_w[:, None, None] * 0.6,
|
||||||
|
images[:, c, H - vol_h:H, x0:x1],
|
||||||
|
)
|
||||||
|
|
||||||
|
return images
|
||||||
|
|
||||||
|
def _render_numpy(self, ohlcv: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Rend une seule fenêtre OHLCV en image (3, H, W) via pur numpy.
|
||||||
|
|
||||||
|
Dessin :
|
||||||
|
- Fond #0d1117
|
||||||
|
- Mèche : ligne verticale high→low (1 pixel de large)
|
||||||
|
- Corps : rectangle open→close (largeur proportionnelle)
|
||||||
|
- Volume : barre en bas (20% hauteur), opacité 60%
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ohlcv: (N, 5) float32 [open, high, low, close, volume]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(3, image_size, image_size) float32 [0, 1]
|
||||||
|
"""
|
||||||
|
H = W = self.image_size
|
||||||
|
n = len(ohlcv)
|
||||||
|
|
||||||
|
if n == 0:
|
||||||
|
return np.zeros((3, H, W), dtype=np.float32)
|
||||||
|
|
||||||
|
# --- Fond ---
|
||||||
|
img = np.empty((3, H, W), dtype=np.float32)
|
||||||
|
for c in range(3):
|
||||||
|
img[c] = BG_COLOR[c]
|
||||||
|
|
||||||
|
opens = ohlcv[:, 0]
|
||||||
|
highs = ohlcv[:, 1]
|
||||||
|
lows = ohlcv[:, 2]
|
||||||
|
closes = ohlcv[:, 3]
|
||||||
|
vols = ohlcv[:, 4]
|
||||||
|
|
||||||
|
# --- Normalisation des prix ---
|
||||||
|
price_min = lows.min()
|
||||||
|
price_max = highs.max()
|
||||||
|
if price_max <= price_min:
|
||||||
|
return img # données dégénérées
|
||||||
|
|
||||||
|
vol_h = max(1, int(H * VOLUME_RATIO))
|
||||||
|
chart_h = H - vol_h # hauteur zone prix
|
||||||
|
|
||||||
|
def price_to_row(price: float) -> int:
|
||||||
|
"""Convertit un prix en ligne pixel (0 = haut, chart_h-1 = bas)."""
|
||||||
|
norm = (price - price_min) / (price_max - price_min)
|
||||||
|
row = int((1.0 - norm) * (chart_h - 1))
|
||||||
|
return max(0, min(chart_h - 1, row))
|
||||||
|
|
||||||
|
# --- Largeur des bougies ---
|
||||||
|
candle_w = max(1, W // n)
|
||||||
|
half_w = max(1, candle_w // 2)
|
||||||
|
|
||||||
|
# --- Volume ---
|
||||||
|
vol_max = vols.max()
|
||||||
|
|
||||||
|
for i in range(n):
|
||||||
|
x_c = int((i + 0.5) * W / n)
|
||||||
|
x0 = max(0, x_c - half_w)
|
||||||
|
x1 = min(W, x_c + half_w + 1)
|
||||||
|
|
||||||
|
is_bull = closes[i] >= opens[i]
|
||||||
|
color = GREEN_COLOR if is_bull else RED_COLOR
|
||||||
|
|
||||||
|
row_h = price_to_row(highs[i])
|
||||||
|
row_l = price_to_row(lows[i])
|
||||||
|
row_o = price_to_row(opens[i])
|
||||||
|
row_c = price_to_row(closes[i])
|
||||||
|
|
||||||
|
body_top = min(row_o, row_c)
|
||||||
|
body_bot = max(row_o, row_c)
|
||||||
|
|
||||||
|
# Mèche (1 pixel de large, centré)
|
||||||
|
wick_x = min(x_c, W - 1)
|
||||||
|
for c in range(3):
|
||||||
|
img[c, row_h: row_l + 1, wick_x] = color[c]
|
||||||
|
|
||||||
|
# Corps
|
||||||
|
if body_bot >= body_top:
|
||||||
|
for c in range(3):
|
||||||
|
img[c, body_top: body_bot + 1, x0: x1] = color[c]
|
||||||
|
|
||||||
|
# Volume (bas de l'image)
|
||||||
|
if vol_max > 0:
|
||||||
|
bar_h = int((vols[i] / vol_max) * vol_h)
|
||||||
|
if bar_h > 0:
|
||||||
|
row_v0 = H - bar_h
|
||||||
|
for c in range(3):
|
||||||
|
img[c, row_v0: H, x0: x1] = color[c] * 0.6
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Renderer mplfinance (lent, haute qualité, pour affichage uniquement)
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _encode_mplfinance(
|
||||||
|
self,
|
||||||
|
df: pd.DataFrame,
|
||||||
|
seq_len: int,
|
||||||
|
n_windows: int,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Encode via mplfinance (lent — ne pas utiliser pour l'entraînement)."""
|
||||||
|
H = W = self.image_size
|
||||||
|
images = np.zeros((n_windows, 3, H, W), dtype=np.float32)
|
||||||
|
import warnings
|
||||||
|
for i in range(n_windows):
|
||||||
|
window = df.iloc[i: i + seq_len]
|
||||||
|
try:
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
images[i] = self._render_with_mplfinance(window)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Erreur mplfinance fenêtre {i} : {e}")
|
||||||
|
return images
|
||||||
|
|
||||||
|
def _render_with_mplfinance(self, df_window: pd.DataFrame) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Rendu haute qualité via mplfinance (pour affichage / debug).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImportError si mplfinance ou PIL non disponible.
|
||||||
|
"""
|
||||||
|
if not MPLFINANCE_AVAILABLE or not PIL_AVAILABLE:
|
||||||
|
raise ImportError("mplfinance ou PIL non disponible")
|
||||||
|
|
||||||
|
style = mpf.make_mpf_style(
|
||||||
|
base_mpf_style='nightclouds',
|
||||||
|
marketcolors=mpf.make_marketcolors(
|
||||||
|
up='#26a69a', down='#ef5350',
|
||||||
|
wick={'up': '#26a69a', 'down': '#ef5350'},
|
||||||
|
edge={'up': '#26a69a', 'down': '#ef5350'},
|
||||||
|
volume={'up': '#26a69a', 'down': '#ef5350'},
|
||||||
|
),
|
||||||
|
facecolor='#0d1117',
|
||||||
|
figcolor='#0d1117',
|
||||||
|
gridcolor='#0d1117',
|
||||||
|
)
|
||||||
|
|
||||||
|
buf = io.BytesIO()
|
||||||
|
import warnings
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
fig, axes = mpf.plot(
|
||||||
|
df_window,
|
||||||
|
type='candle',
|
||||||
|
style=style,
|
||||||
|
volume=True,
|
||||||
|
axisoff=True,
|
||||||
|
tight_layout=True,
|
||||||
|
returnfig=True,
|
||||||
|
figsize=(1.28, 1.28),
|
||||||
|
)
|
||||||
|
|
||||||
|
for ax in axes:
|
||||||
|
ax.set_axis_off()
|
||||||
|
ax.set_facecolor('#0d1117')
|
||||||
|
for spine in ax.spines.values():
|
||||||
|
spine.set_visible(False)
|
||||||
|
|
||||||
|
fig.savefig(buf, format='png', dpi=100, bbox_inches='tight',
|
||||||
|
pad_inches=0, facecolor='#0d1117')
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
buf.seek(0)
|
||||||
|
pil_img = Image.open(buf).convert('RGB')
|
||||||
|
pil_img = pil_img.resize((self.image_size, self.image_size), Image.LANCZOS)
|
||||||
|
|
||||||
|
arr = np.array(pil_img, dtype=np.float32) / 255.0
|
||||||
|
return arr.transpose(2, 0, 1) # HWC → CHW
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Utilitaires
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _prepare_df(df: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Normalise le DataFrame : colonnes en minuscules, index DatetimeIndex.
|
||||||
|
"""
|
||||||
|
df = df.copy()
|
||||||
|
df.columns = [c.lower() for c in df.columns]
|
||||||
|
|
||||||
|
if not isinstance(df.index, pd.DatetimeIndex):
|
||||||
|
try:
|
||||||
|
df.index = pd.to_datetime(df.index)
|
||||||
|
except Exception:
|
||||||
|
df.index = pd.date_range(
|
||||||
|
start='2020-01-01', periods=len(df), freq='1h'
|
||||||
|
)
|
||||||
|
|
||||||
|
required = ['open', 'high', 'low', 'close', 'volume']
|
||||||
|
for col in required:
|
||||||
|
if col not in df.columns:
|
||||||
|
df[col] = 0.0
|
||||||
|
|
||||||
|
df = df[required].dropna(subset=['open', 'high', 'low', 'close'])
|
||||||
|
df = df.ffill().bfill()
|
||||||
|
return df
|
||||||
163
src/ml/cnn_image/cnn_image_model.py
Normal file
163
src/ml/cnn_image/cnn_image_model.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
"""
|
||||||
|
CandlestickCNN — Réseau de neurones convolutif pour l'analyse visuelle de graphiques.
|
||||||
|
|
||||||
|
Architecture Conv2D 4-blocs conçue pour reconnaître les patterns visuels dans
|
||||||
|
des images de chandeliers 128×128 RGB (bougies, volumes, supports/résistances).
|
||||||
|
|
||||||
|
Input : (batch, 3, 128, 128) — images RGB normalisées [0, 1]
|
||||||
|
Output : (batch, 3) — logits pour 3 classes : SHORT(0), NEUTRAL(1), LONG(2)
|
||||||
|
|
||||||
|
Classes encodées :
|
||||||
|
0 → SHORT (-1)
|
||||||
|
1 → NEUTRAL ( 0)
|
||||||
|
2 → LONG (+1)
|
||||||
|
|
||||||
|
L'encodage +1 est cohérent avec MLStrategyModel et CNNStrategyModel.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# --- Détection optionnelle de PyTorch ---
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
TORCH_AVAILABLE = True
|
||||||
|
logger.debug("PyTorch disponible — CandlestickCNN activé")
|
||||||
|
except ImportError:
|
||||||
|
TORCH_AVAILABLE = False
|
||||||
|
# Stubs pour permettre l'import du module sans PyTorch
|
||||||
|
nn = None
|
||||||
|
torch = None
|
||||||
|
logger.warning("PyTorch non disponible — CandlestickCNN désactivé")
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Définition conditionnelle du modèle
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
if TORCH_AVAILABLE:
|
||||||
|
|
||||||
|
class CandlestickCNN(nn.Module):
|
||||||
|
"""
|
||||||
|
CNN vision pour graphiques chandeliers 128×128 RGB.
|
||||||
|
|
||||||
|
Architecture :
|
||||||
|
Bloc 1 : Conv2d(3→32, 3×3) + BatchNorm2d(32) + ReLU + MaxPool2d(2) → (32, 64, 64)
|
||||||
|
Bloc 2 : Conv2d(32→64, 3×3) + BatchNorm2d(64) + ReLU + MaxPool2d(2) → (64, 32, 32)
|
||||||
|
Bloc 3 : Conv2d(64→128,3×3) + BatchNorm2d(128) + ReLU + MaxPool2d(2) → (128, 16, 16)
|
||||||
|
Bloc 4 : Conv2d(128→256,3×3)+ BatchNorm2d(256) + ReLU + AdaptiveAvgPool2d(1) → (256,)
|
||||||
|
|
||||||
|
Classifieur :
|
||||||
|
Linear(256→128) + Dropout(0.4) + ReLU
|
||||||
|
Linear(128→3)
|
||||||
|
|
||||||
|
Output : logits bruts (3 classes) — appliquer softmax pour les probabilités.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_classes: Nombre de classes de sortie (défaut 3 : SHORT/NEUTRAL/LONG)
|
||||||
|
dropout: Taux de dropout dans le classifieur (défaut 0.4)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, n_classes: int = 3, dropout: float = 0.4):
|
||||||
|
super(CandlestickCNN, self).__init__()
|
||||||
|
|
||||||
|
# --- Bloc convolutif 1 : (3, 128, 128) → (32, 64, 64) ---
|
||||||
|
self.bloc1 = nn.Sequential(
|
||||||
|
nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
|
||||||
|
nn.BatchNorm2d(32),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Bloc convolutif 2 : (32, 64, 64) → (64, 32, 32) ---
|
||||||
|
self.bloc2 = nn.Sequential(
|
||||||
|
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
|
||||||
|
nn.BatchNorm2d(64),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Bloc convolutif 3 : (64, 32, 32) → (128, 16, 16) ---
|
||||||
|
self.bloc3 = nn.Sequential(
|
||||||
|
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
|
||||||
|
nn.BatchNorm2d(128),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Bloc convolutif 4 : (128, 16, 16) → (256, 1, 1) → (256,) ---
|
||||||
|
self.bloc4 = nn.Sequential(
|
||||||
|
nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
|
||||||
|
nn.BatchNorm2d(256),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.AdaptiveAvgPool2d(output_size=1), # Global average pooling → (256, 1, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Classifieur dense ---
|
||||||
|
self.classifieur = nn.Sequential(
|
||||||
|
nn.Linear(256, 128),
|
||||||
|
nn.Dropout(p=dropout),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Linear(128, n_classes),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: 'torch.Tensor') -> 'torch.Tensor':
|
||||||
|
"""
|
||||||
|
Passe avant du réseau.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Tensor (batch, 3, 128, 128) — images RGB normalisées
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor (batch, 3) — logits bruts
|
||||||
|
"""
|
||||||
|
x = self.bloc1(x) # (batch, 32, 64, 64)
|
||||||
|
x = self.bloc2(x) # (batch, 64, 32, 32)
|
||||||
|
x = self.bloc3(x) # (batch, 128, 16, 16)
|
||||||
|
x = self.bloc4(x) # (batch, 256, 1, 1)
|
||||||
|
x = x.flatten(1) # (batch, 256)
|
||||||
|
x = self.classifieur(x) # (batch, 3)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def predict_proba(self, x: 'torch.Tensor') -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Calcule les probabilités softmax pour chaque classe.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Tensor (batch, 3, 128, 128) — images RGB normalisées
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray de forme (batch, 3), valeurs dans [0, 1], somme = 1
|
||||||
|
Ordre des colonnes : [P_SHORT, P_NEUTRAL, P_LONG]
|
||||||
|
"""
|
||||||
|
self.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
logits = self.forward(x) # (batch, 3)
|
||||||
|
probas = F.softmax(logits, dim=1) # (batch, 3)
|
||||||
|
return probas.cpu().numpy().astype(np.float32)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Stub si PyTorch n'est pas disponible
|
||||||
|
class CandlestickCNN:
|
||||||
|
"""
|
||||||
|
Stub de CandlestickCNN — PyTorch non disponible.
|
||||||
|
|
||||||
|
Toutes les méthodes lèvent une RuntimeError explicite.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, n_classes: int = 3, dropout: float = 0.4):
|
||||||
|
raise RuntimeError(
|
||||||
|
"PyTorch non disponible. Installer torch>=2.0.0 pour utiliser CandlestickCNN."
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
raise RuntimeError("PyTorch non disponible.")
|
||||||
|
|
||||||
|
def predict_proba(self, x):
|
||||||
|
raise RuntimeError("PyTorch non disponible.")
|
||||||
657
src/ml/cnn_image/cnn_image_strategy_model.py
Normal file
657
src/ml/cnn_image/cnn_image_strategy_model.py
Normal file
@@ -0,0 +1,657 @@
|
|||||||
|
"""
|
||||||
|
CNNImageStrategyModel — Modèle CNN image-based pour la prédiction de signaux de trading.
|
||||||
|
|
||||||
|
Ce module entraîne un CNN Conv2D (CandlestickCNN) sur des images de graphiques
|
||||||
|
en chandeliers pour prédire LONG (+1), NEUTRAL (0) ou SHORT (-1).
|
||||||
|
|
||||||
|
Pipeline :
|
||||||
|
1. Génération de labels ATR-based (LabelGenerator partagé)
|
||||||
|
2. Rendu des fenêtres OHLCV en images 128×128 RGB (CandlestickImageRenderer)
|
||||||
|
3. Walk-forward (2 folds temporels) pour évaluation out-of-sample
|
||||||
|
4. Entraînement sur toutes les données (Adam, CrossEntropyLoss weighted)
|
||||||
|
5. Sauvegarde : EURUSD_1h.pt (state_dict) + EURUSD_1h_meta.json
|
||||||
|
|
||||||
|
Interface identique à MLStrategyModel et CNNStrategyModel :
|
||||||
|
model = CNNImageStrategyModel(symbol='EURUSD', timeframe='1h')
|
||||||
|
result = model.train(df_ohlcv)
|
||||||
|
signal = model.predict(df) # {signal, confidence, probas, tradeable}
|
||||||
|
model.save()
|
||||||
|
model.load(symbol, timeframe)
|
||||||
|
model.list_trained_models()
|
||||||
|
model.get_feature_importance() # retourne [] (CNN = boîte noire)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from src.ml.features.label_generator import LabelGenerator
|
||||||
|
from src.ml.cnn_image.chart_renderer import CandlestickImageRenderer, MPLFINANCE_AVAILABLE
|
||||||
|
from src.ml.cnn_image.cnn_image_model import CandlestickCNN, TORCH_AVAILABLE
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Répertoire de sauvegarde des modèles CNN image
|
||||||
|
MODELS_DIR = Path(__file__).parent.parent.parent.parent / "models" / "cnn_image_strategy"
|
||||||
|
|
||||||
|
# Décodage des classes encodées vers les signaux de trading
|
||||||
|
# Classe 0 → SHORT (-1), Classe 1 → NEUTRAL (0), Classe 2 → LONG (+1)
|
||||||
|
CLASS_MAP = {0: -1, 1: 0, 2: 1}
|
||||||
|
|
||||||
|
# Conditional PyTorch imports: the module stays importable when torch is
# absent (train()/predict() then return explicit error dicts).
if TORCH_AVAILABLE:
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
else:
    torch = None
    nn = None
    DataLoader = None
    TensorDataset = None

# scikit-learn availability is independent of torch: guard it in its own
# try/except so a missing sklearn never crashes module import (the previous
# version imported sklearn unguarded inside the TORCH_AVAILABLE branch).
try:
    from sklearn.model_selection import TimeSeriesSplit
    from sklearn.metrics import precision_recall_fscore_support
    SKLEARN_AVAILABLE = True
except ImportError:
    TimeSeriesSplit = None
    precision_recall_fscore_support = None
    SKLEARN_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
class CNNImageStrategyModel:
|
||||||
|
"""
|
||||||
|
Modèle CNN image-based qui reconnaît les patterns visuels de trading.
|
||||||
|
|
||||||
|
Le modèle :
|
||||||
|
- Convertit des bougies OHLCV en images 128×128 RGB (mplfinance)
|
||||||
|
- Reconnaît visuellement : marteaux, doji, double top/bottom, supports/résistances
|
||||||
|
- Prédit LONG (1) / SHORT (-1) / NEUTRAL (0)
|
||||||
|
- Donne un score de confiance [0..1] par prédiction
|
||||||
|
- Se sauvegarde en format PyTorch (.pt) pour rechargement sans ré-entraînement
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Paire tradée (ex: 'EURUSD')
|
||||||
|
timeframe: Timeframe (ex: '1h', '15m')
|
||||||
|
seq_len: Nombre de bougies par image (défaut 64)
|
||||||
|
tp_atr_mult: Multiplicateur ATR pour le TP (génération labels)
|
||||||
|
sl_atr_mult: Multiplicateur ATR pour le SL (génération labels)
|
||||||
|
horizon: Nombre de barres pour évaluer TP/SL
|
||||||
|
min_confidence: Seuil de confiance pour le signal tradeable
|
||||||
|
epochs: Nombre maximum d'époques d'entraînement
|
||||||
|
batch_size: Taille du batch (défaut 32)
|
||||||
|
lr: Taux d'apprentissage Adam (défaut 1e-3)
|
||||||
|
patience: Patience pour l'early stopping (défaut 7)
|
||||||
|
"""
|
||||||
|
|
||||||
|
MODELS_DIR = MODELS_DIR
|
||||||
|
CLASS_MAP = CLASS_MAP
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
symbol: str = 'EURUSD',
|
||||||
|
timeframe: str = '1h',
|
||||||
|
seq_len: int = 64,
|
||||||
|
tp_atr_mult: float = 2.0,
|
||||||
|
sl_atr_mult: float = 1.0,
|
||||||
|
horizon: int = 30,
|
||||||
|
min_confidence: float = 0.55,
|
||||||
|
epochs: int = 50,
|
||||||
|
batch_size: int = 32,
|
||||||
|
lr: float = 1e-3,
|
||||||
|
patience: int = 7,
|
||||||
|
):
|
||||||
|
self.symbol = symbol
|
||||||
|
self.timeframe = timeframe
|
||||||
|
self.seq_len = seq_len
|
||||||
|
self.tp_atr_mult = tp_atr_mult
|
||||||
|
self.sl_atr_mult = sl_atr_mult
|
||||||
|
self.horizon = horizon
|
||||||
|
self.min_confidence = min_confidence
|
||||||
|
self.epochs = epochs
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.lr = lr
|
||||||
|
self.patience = patience
|
||||||
|
|
||||||
|
self.model: Optional['CandlestickCNN'] = None
|
||||||
|
self.is_trained = False
|
||||||
|
self.metadata: Dict = {}
|
||||||
|
|
||||||
|
self.MODELS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Entraînement
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
|
||||||
|
    def train(self, data: pd.DataFrame) -> Dict:
        """
        Train the image-based CNN on the supplied OHLCV data.

        Steps:
            1. Generate ATR-based labels (LabelGenerator)
            2. Render sliding OHLCV windows into images (CandlestickImageRenderer)
            3. Walk-forward evaluation (2 temporal folds)
            4. Final training on all data
            5. Automatic save to disk

        Args:
            data: OHLCV DataFrame (open/high/low/close/volume columns;
                  200+ bars recommended per the original docstring)

        Returns:
            Dict of metrics: wf_metrics, label_dist, n_samples, trained_at, ...
            or {'error': ...} on failure.
        """
        # Availability guards — return error dicts instead of raising so the
        # caller (e.g. an API endpoint) can surface the message.
        if not TORCH_AVAILABLE:
            return {'error': 'PyTorch non disponible — installer torch>=2.0.0'}
        if not SKLEARN_AVAILABLE:
            return {'error': 'scikit-learn non disponible'}

        logger.info(
            f"Début entraînement CNNImageStrategyModel pour "
            f"{self.symbol}/{self.timeframe} — seq_len={self.seq_len}"
        )

        # 1. Label generation (ATR-based, consistent with MLStrategyModel)
        gen = LabelGenerator(horizon=self.horizon)
        labels = gen.generate_atr_based(
            data,
            atr_tp_mult=self.tp_atr_mult,
            atr_sl_mult=self.sl_atr_mult,
        )

        # Encode [-1, 0, 1] -> [0, 1, 2] for CrossEntropyLoss
        y_enc = (labels + 1).values  # np.ndarray

        # 2. Image encoding (sliding windows)
        logger.info(f" Rendu des images (seq_len={self.seq_len})...")
        renderer = CandlestickImageRenderer()
        X = renderer.encode(data, seq_len=self.seq_len)  # (N, 3, 128, 128)

        # 3. Align X and y.
        # encode() yields N = len(data) - seq_len images; image i covers
        # data.iloc[i : i + seq_len], so its label is y_enc[i + seq_len - 1]
        # (last bar of the window). — assumed from the renderer contract,
        # confirm against CandlestickImageRenderer.encode.
        n_images = len(X)
        if n_images == 0:
            return {
                'error': f'Pas assez de données : {len(data)} barres < seq_len={self.seq_len}'
            }

        # Label index for image i is i + seq_len - 1
        label_indices = np.arange(self.seq_len - 1, self.seq_len - 1 + n_images)
        # Drop trailing bars whose labels are unreliable (truncated horizon)
        valid_mask = label_indices < (len(y_enc) - self.horizon)

        X = X[valid_mask]
        y = y_enc[label_indices[valid_mask]]

        n_samples = len(X)
        if n_samples < 50:
            return {'error': f'Trop peu d\'échantillons valides : {n_samples}'}

        logger.info(
            f" {n_samples} images valides — "
            f"LONG={(y==2).sum()}, SHORT={(y==0).sum()}, NEUTRAL={(y==1).sum()}"
        )

        # 4. Class weights inversely proportional to class frequency
        class_weights = self._compute_class_weights(y)

        # 5. Walk-forward evaluation (2 temporal folds)
        wf_metrics = self._walk_forward_eval(X, y, class_weights, n_splits=2)

        # 6. Final training on all data
        logger.info(" Entraînement final sur toutes les données...")
        self.model = CandlestickCNN(n_classes=3)
        self._train_model(self.model, X, y, class_weights)
        self.is_trained = True

        # Label distribution, decoded for readability
        label_dist = {
            'long': int((y == 2).sum()),
            'short': int((y == 0).sum()),
            'neutral': int((y == 1).sum()),
        }

        self.metadata = {
            'symbol': self.symbol,
            'timeframe': self.timeframe,
            'seq_len': self.seq_len,
            'trained_at': datetime.utcnow().isoformat(),
            'n_samples': n_samples,
            'tp_atr_mult': self.tp_atr_mult,
            'sl_atr_mult': self.sl_atr_mult,
            'horizon': self.horizon,
            'label_dist': label_dist,
            'wf_metrics': wf_metrics,
            'mplfinance': MPLFINANCE_AVAILABLE,
        }

        # 7. Persist model + metadata
        self.save()

        logger.info(
            f"Entraînement terminé. WF accuracy={wf_metrics.get('avg_accuracy', 0):.2%}"
        )
        return self.metadata
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Prédiction
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
|
||||||
|
    def predict(self, data: pd.DataFrame) -> Dict:
        """
        Predict the signal for the most recent window of *data*.

        Args:
            data: Recent OHLCV DataFrame (at least seq_len bars)

        Returns:
            Dict: {
                'signal': 1 (LONG) / -1 (SHORT) / 0 (NEUTRAL),
                'confidence': float [0..1] — max(P_LONG, P_SHORT),
                'probas': {'long': float, 'short': float, 'neutral': float},
                'tradeable': bool — confidence >= min_confidence and signal != 0,
            }
            On failure, a neutral dict with an 'error' key instead of 'probas'.
        """
        # Guard clauses: never raise from predict — return a neutral result
        if not TORCH_AVAILABLE:
            return {
                'signal': 0, 'confidence': 0.0, 'tradeable': False,
                'error': 'PyTorch non disponible'
            }
        if not self.is_trained or self.model is None:
            return {
                'signal': 0, 'confidence': 0.0, 'tradeable': False,
                'error': 'Modèle non entraîné'
            }

        # Render only the last window; encode_last raises ValueError when
        # data is shorter than seq_len (assumed from usage — confirm in renderer)
        try:
            renderer = CandlestickImageRenderer()
            img = renderer.encode_last(data, seq_len=self.seq_len)  # (1, 3, 128, 128)
        except ValueError as e:
            return {
                'signal': 0, 'confidence': 0.0, 'tradeable': False,
                'error': str(e)
            }

        # Convert to a float tensor for the network
        x_tensor = torch.from_numpy(img).float()  # (1, 3, 128, 128)

        # Forward pass; probabilities ordered [P_SHORT, P_NEUTRAL, P_LONG]
        proba_arr = self.model.predict_proba(x_tensor)  # (1, 3)
        proba = proba_arr[0]  # (3,)

        pred_class = int(np.argmax(proba))
        signal = self.CLASS_MAP[pred_class]  # decode: 0->-1, 1->0, 2->1

        probas = {
            'short': float(proba[0]),
            'neutral': float(proba[1]),
            'long': float(proba[2]),
        }
        # Confidence is the strongest directional probability (NEUTRAL excluded)
        confidence = float(max(probas['long'], probas['short']))

        return {
            'signal': signal,
            'confidence': confidence,
            'probas': probas,
            'tradeable': confidence >= self.min_confidence and signal != 0,
        }
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Sauvegarde / Chargement
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
|
||||||
|
    def save(self) -> Path:
        """
        Persist the trained model and its metadata to disk.

        Layout:
            {symbol}_{timeframe}.pt        — PyTorch checkpoint (state_dict + config + metadata)
            {symbol}_{timeframe}_meta.json — metadata as JSON

        Returns:
            Path to the saved .pt file

        Raises:
            RuntimeError: if the model is not trained or PyTorch is unavailable
        """
        if not TORCH_AVAILABLE:
            raise RuntimeError("PyTorch non disponible")
        if not self.is_trained or self.model is None:
            raise RuntimeError("Modèle non entraîné — appeler train() avant save()")

        model_id = f"{self.symbol}_{self.timeframe}"
        model_path = self.MODELS_DIR / f"{model_id}.pt"
        meta_path = self.MODELS_DIR / f"{model_id}_meta.json"

        # Save state_dict plus the full config so load() can rebuild the
        # instance without guessing hyper-parameters.
        torch.save(
            {
                'state_dict': self.model.state_dict(),
                'config': {
                    'symbol': self.symbol,
                    'timeframe': self.timeframe,
                    'seq_len': self.seq_len,
                    'tp_atr_mult': self.tp_atr_mult,
                    'sl_atr_mult': self.sl_atr_mult,
                    'horizon': self.horizon,
                    'min_confidence': self.min_confidence,
                    'epochs': self.epochs,
                    'batch_size': self.batch_size,
                    'lr': self.lr,
                    'patience': self.patience,
                },
                'metadata': self.metadata,
            },
            model_path
        )

        # Sidecar JSON: default=str stringifies non-JSON values (e.g. datetimes)
        with open(meta_path, 'w') as f:
            json.dump(self.metadata, f, indent=2, default=str)

        logger.info(f"Modèle CNN image sauvegardé : {model_path}")
        return model_path
|
||||||
|
|
||||||
|
    @classmethod
    def load(cls, symbol: str, timeframe: str) -> 'CNNImageStrategyModel':
        """
        Load an existing model from disk.

        Args:
            symbol: Pair (e.g. 'EURUSD')
            timeframe: Timeframe (e.g. '1h')

        Returns:
            A CNNImageStrategyModel instance ready to predict

        Raises:
            RuntimeError: if PyTorch is unavailable
            FileNotFoundError: if the model file does not exist
        """
        if not TORCH_AVAILABLE:
            raise RuntimeError("PyTorch non disponible")

        model_path = MODELS_DIR / f"{symbol}_{timeframe}.pt"

        if not model_path.exists():
            raise FileNotFoundError(f"Modèle non trouvé : {model_path}")

        # map_location='cpu' so a GPU-trained checkpoint loads on any host
        checkpoint = torch.load(model_path, map_location='cpu')
        cfg = checkpoint.get('config', {})

        # Rebuild the instance from the saved config; the caller-supplied
        # symbol/timeframe only serve as fallbacks for old checkpoints.
        instance = cls(
            symbol=cfg.get('symbol', symbol),
            timeframe=cfg.get('timeframe', timeframe),
            seq_len=cfg.get('seq_len', 64),
            tp_atr_mult=cfg.get('tp_atr_mult', 2.0),
            sl_atr_mult=cfg.get('sl_atr_mult', 1.0),
            horizon=cfg.get('horizon', 30),
            min_confidence=cfg.get('min_confidence', 0.55),
            epochs=cfg.get('epochs', 50),
            batch_size=cfg.get('batch_size', 32),
            lr=cfg.get('lr', 1e-3),
            patience=cfg.get('patience', 7),
        )

        # Rebuild the network and restore the trained weights
        instance.model = CandlestickCNN(n_classes=3)
        instance.model.load_state_dict(checkpoint['state_dict'])
        instance.model.eval()
        instance.is_trained = True
        instance.metadata = checkpoint.get('metadata', {})

        logger.info(f"Modèle CNN image chargé depuis {model_path}")
        return instance
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def list_trained_models() -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Retourne la liste des modèles entraînés disponibles.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Liste de dicts avec symbol, timeframe, trained_at, n_samples, wf_accuracy
|
||||||
|
"""
|
||||||
|
if not MODELS_DIR.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
models = []
|
||||||
|
for f in MODELS_DIR.glob("*_meta.json"):
|
||||||
|
try:
|
||||||
|
with open(f) as fp:
|
||||||
|
meta = json.load(fp)
|
||||||
|
# Nom du fichier : EURUSD_1h_meta.json → EURUSD, 1h
|
||||||
|
stem = f.stem.replace('_meta', '') # ex: EURUSD_1h
|
||||||
|
parts = stem.split('_', 1)
|
||||||
|
models.append({
|
||||||
|
'symbol': meta.get('symbol', parts[0] if parts else '?'),
|
||||||
|
'timeframe': meta.get('timeframe', parts[1] if len(parts) > 1 else '?'),
|
||||||
|
'trained_at': meta.get('trained_at', '?'),
|
||||||
|
'n_samples': meta.get('n_samples', 0),
|
||||||
|
'seq_len': meta.get('seq_len', 64),
|
||||||
|
'wf_accuracy': meta.get('wf_metrics', {}).get('avg_accuracy', 0),
|
||||||
|
'mplfinance': meta.get('mplfinance', False),
|
||||||
|
})
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return models
|
||||||
|
|
||||||
|
def get_feature_importance(self) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Retourne l'importance des features.
|
||||||
|
|
||||||
|
Note : Les CNN sont des boîtes noires — pas d'importance de feature
|
||||||
|
interprétable comme XGBoost. Retourne une liste vide.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[] — CNN = boîte noire visuelle
|
||||||
|
"""
|
||||||
|
return []
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Entraînement interne
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
|
||||||
|
    def _train_model(
        self,
        model: 'CandlestickCNN',
        X: np.ndarray,
        y: np.ndarray,
        class_weights: 'torch.Tensor',
    ) -> None:
        """
        Train the CNN in place with early stopping on validation loss.

        Args:
            model: CandlestickCNN instance to train (mutated in place)
            X: Images (N, 3, 128, 128) float32
            y: Encoded labels (N,) int, values {0, 1, 2}
            class_weights: Class weights (3,) for CrossEntropyLoss
        """
        # Cap CPU threads so training does not starve the FastAPI event loop
        torch.set_num_threads(4)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)
        class_weights = class_weights.to(device)

        # Convert to tensors (int64 required by CrossEntropyLoss targets)
        X_t = torch.from_numpy(X).float()
        y_t = torch.from_numpy(y.astype(np.int64))

        # Temporal 80/20 train/validation split (no shuffling — time series)
        n_train = int(len(X_t) * 0.8)
        X_tr, X_val = X_t[:n_train], X_t[n_train:]
        y_tr, y_val = y_t[:n_train], y_t[n_train:]

        # DataLoaders; shuffle=False preserves chronological order
        train_ds = TensorDataset(X_tr, y_tr)
        val_ds = TensorDataset(X_val, y_val)
        train_dl = DataLoader(train_ds, batch_size=self.batch_size, shuffle=False)
        val_dl = DataLoader(val_ds, batch_size=self.batch_size, shuffle=False)

        # Optimizer and weighted loss (compensates class imbalance)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)
        criterion = nn.CrossEntropyLoss(weight=class_weights)

        best_val_loss = float('inf')
        patience_count = 0
        best_state = None  # best weights snapshot (on CPU)

        for epoch in range(self.epochs):
            # --- Training phase ---
            model.train()
            train_loss = 0.0
            for X_batch, y_batch in train_dl:
                X_batch = X_batch.to(device)
                y_batch = y_batch.to(device)
                optimizer.zero_grad()
                logits = model(X_batch)
                loss = criterion(logits, y_batch)
                loss.backward()
                optimizer.step()
                # Weight batch loss by batch size for a correct epoch average
                train_loss += loss.item() * len(X_batch)
            train_loss /= max(len(train_ds), 1)

            # --- Validation phase ---
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for X_batch, y_batch in val_dl:
                    X_batch = X_batch.to(device)
                    y_batch = y_batch.to(device)
                    logits = model(X_batch)
                    loss = criterion(logits, y_batch)
                    val_loss += loss.item() * len(X_batch)
            val_loss /= max(len(val_ds), 1)

            logger.debug(
                f" Epoch {epoch+1}/{self.epochs} — "
                f"train_loss={train_loss:.4f}, val_loss={val_loss:.4f}"
            )

            # Early stopping: snapshot best weights, stop after `patience`
            # epochs without improvement
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_count = 0
                # Clone to CPU so the snapshot is independent of the live model
                best_state = {
                    k: v.cpu().clone() for k, v in model.state_dict().items()
                }
            else:
                patience_count += 1
                if patience_count >= self.patience:
                    logger.info(
                        f" Early stopping à l'époque {epoch+1} "
                        f"(patience={self.patience}, best_val_loss={best_val_loss:.4f})"
                    )
                    break

        # Restore the best weights seen during training
        if best_state is not None:
            model.load_state_dict(best_state)

        # Leave the model in eval mode on CPU for inference
        model.eval()
        model.to('cpu')
|
||||||
|
|
||||||
|
    def _walk_forward_eval(
        self,
        X: np.ndarray,
        y: np.ndarray,
        class_weights: 'torch.Tensor',
        n_splits: int = 2,
    ) -> Dict:
        """
        Evaluate the model with temporal cross-validation (walk-forward).

        Args:
            X: Images (N, 3, 128, 128)
            y: Encoded labels (N,)
            class_weights: Class weights for CrossEntropyLoss
            n_splits: Number of temporal folds

        Returns:
            Dict: avg_accuracy, avg_precision, fold_accuracies
        """
        tscv = TimeSeriesSplit(n_splits=n_splits)
        accuracies = []
        precisions = []

        for fold, (train_idx, test_idx) in enumerate(tscv.split(X)):
            X_tr, X_te = X[train_idx], X[test_idx]
            y_tr, y_te = y[train_idx], y[test_idx]

            # Fresh model per fold — no leakage across folds
            fold_model = CandlestickCNN(n_classes=3)
            self._train_model(fold_model, X_tr, y_tr, class_weights)

            # Out-of-sample evaluation
            x_tensor = torch.from_numpy(X_te).float()
            proba_arr = fold_model.predict_proba(x_tensor)  # (N_te, 3)
            y_pred = np.argmax(proba_arr, axis=1)

            acc = float((y_pred == y_te).mean())
            accuracies.append(acc)

            # Precision on directional samples: keep a sample when either
            # the true label or the prediction is non-NEUTRAL (NEUTRAL=1)
            mask = (y_te != 1) | (y_pred != 1)
            if mask.sum() > 0 and SKLEARN_AVAILABLE:
                prec, _, _, _ = precision_recall_fscore_support(
                    y_te[mask], y_pred[mask],
                    average='macro', zero_division=0
                )
                precisions.append(float(prec))
            else:
                precisions.append(0.0)

            logger.info(
                f" Fold {fold+1}/{n_splits} : acc={acc:.2%}, "
                f"prec={precisions[-1]:.2%}"
            )

        return {
            'avg_accuracy': float(np.mean(accuracies)) if accuracies else 0.0,
            'avg_precision': float(np.mean(precisions)) if precisions else 0.0,
            'fold_accuracies': [float(a) for a in accuracies],
        }
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _compute_class_weights(y: np.ndarray) -> 'torch.Tensor':
|
||||||
|
"""
|
||||||
|
Calcule les poids de classes inversement proportionnels à leur fréquence.
|
||||||
|
|
||||||
|
Permet de compenser les déséquilibres (ex: beaucoup de NEUTRAL, peu de trades).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y: Labels encodés {0, 1, 2}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor de forme (3,) — poids pour CrossEntropyLoss
|
||||||
|
"""
|
||||||
|
weights = np.ones(3, dtype=np.float32)
|
||||||
|
n_total = len(y)
|
||||||
|
|
||||||
|
if n_total == 0:
|
||||||
|
return torch.from_numpy(weights)
|
||||||
|
|
||||||
|
for cls in range(3):
|
||||||
|
count = (y == cls).sum()
|
||||||
|
if count > 0:
|
||||||
|
# Poids inversement proportionnel : classes rares → poids plus élevé
|
||||||
|
weights[cls] = n_total / (3.0 * count)
|
||||||
|
|
||||||
|
# Normalisation pour que la somme des poids = 3
|
||||||
|
weights = weights / weights.mean()
|
||||||
|
|
||||||
|
return torch.from_numpy(weights)
|
||||||
3
src/ml/ensemble/__init__.py
Normal file
3
src/ml/ensemble/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .ensemble_model import EnsembleModel
|
||||||
|
|
||||||
|
__all__ = ['EnsembleModel']
|
||||||
262
src/ml/ensemble/ensemble_model.py
Normal file
262
src/ml/ensemble/ensemble_model.py
Normal file
@@ -0,0 +1,262 @@
|
|||||||
|
"""
|
||||||
|
Ensemble Model — Combine plusieurs modèles ML pour un signal de trading robuste.
|
||||||
|
|
||||||
|
L'EnsembleModel agrège les prédictions de modèles indépendants (XGBoost, CNN,
|
||||||
|
et plus tard RL) via une moyenne pondérée. Un signal n'est émis que si les
|
||||||
|
modèles actifs sont en accord ET que le score pondéré dépasse un seuil.
|
||||||
|
|
||||||
|
Duck typing : ce module n'importe PAS directement MLStrategyModel ni
|
||||||
|
CNNStrategyModel. Tout objet exposant `.predict(df)` → dict et `.is_trained`
|
||||||
|
→ bool est compatible.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EnsembleModel:
|
||||||
|
"""
|
||||||
|
Combine plusieurs modèles ML pour produire un signal de trading robuste.
|
||||||
|
|
||||||
|
Logique :
|
||||||
|
- Chaque modèle prédit indépendamment (signal + confidence)
|
||||||
|
- Score final = somme pondérée des confidences pour les modèles en accord
|
||||||
|
- Signal validé uniquement si :
|
||||||
|
1. Au moins 2 modèles actifs sont en accord sur la direction
|
||||||
|
2. Score pondéré >= min_confidence
|
||||||
|
|
||||||
|
Poids par défaut : xgboost=0.30, cnn=0.30, cnn_image=0.40 (le CNN image est
légèrement favorisé car il voit les données brutes sans biais de feature engineering)
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_WEIGHTS = {
|
||||||
|
'xgboost': 0.30,
|
||||||
|
'cnn': 0.30,
|
||||||
|
'cnn_image': 0.40, # CNN Vision — favorisé (patterns visuels bruts)
|
||||||
|
'rl': 0.00, # Réservé Phase 4d
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
weights: Optional[Dict[str, float]] = None,
|
||||||
|
min_confidence: float = 0.60,
|
||||||
|
require_agreement: bool = True,
|
||||||
|
):
|
||||||
|
self.weights = dict(weights) if weights else dict(self.DEFAULT_WEIGHTS)
|
||||||
|
self.min_confidence = min_confidence
|
||||||
|
self.require_agreement = require_agreement
|
||||||
|
|
||||||
|
# Modèles attachés (duck typing : .predict(df), .is_trained)
|
||||||
|
self._models: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
# Référence directe au modèle CNN Image (accès rapide pour auto-attach)
|
||||||
|
self._cnn_image_model = None
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"EnsembleModel initialisé — poids={self.weights}, "
|
||||||
|
f"seuil={self.min_confidence}, accord_requis={self.require_agreement}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Attachement des modèles
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
    def attach_xgboost(self, model) -> None:
        """Attach a trained MLStrategyModel under the 'xgboost' slot."""
        self._attach('xgboost', model)
|
||||||
|
|
||||||
|
    def attach_cnn(self, model) -> None:
        """Attach a trained CNNStrategyModel under the 'cnn' slot."""
        self._attach('cnn', model)
|
||||||
|
|
||||||
|
    def attach_cnn_image(self, model) -> None:
        """Attach a trained CNNImageStrategyModel (Phase 4c-bis).

        Also keeps a direct reference for fast auto-attach access.
        """
        self._cnn_image_model = model
        self._attach('cnn_image', model)
|
||||||
|
|
||||||
|
    def attach_rl(self, model) -> None:
        """Attach an RL agent (reserved for Phase 4d; default weight is 0)."""
        self._attach('rl', model)
|
||||||
|
|
||||||
|
def _attach(self, name: str, model) -> None:
|
||||||
|
"""Attache un modèle générique avec vérification duck typing."""
|
||||||
|
if not hasattr(model, 'predict') or not callable(model.predict):
|
||||||
|
raise ValueError(f"Le modèle '{name}' doit exposer une méthode predict()")
|
||||||
|
if not hasattr(model, 'is_trained'):
|
||||||
|
raise ValueError(f"Le modèle '{name}' doit exposer un attribut is_trained")
|
||||||
|
self._models[name] = model
|
||||||
|
# Ajouter le poids par défaut s'il n'existe pas
|
||||||
|
if name not in self.weights:
|
||||||
|
self.weights[name] = 0.0
|
||||||
|
logger.warning(f"Poids pour '{name}' non défini — initialisé à 0.0")
|
||||||
|
logger.info(f"Modèle '{name}' attaché (is_trained={model.is_trained})")
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Prédiction combinée
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
    def predict(self, df: pd.DataFrame) -> Dict:
        """
        Produce the combined signal from all active models.

        Returns:
            {
                'signal': int,        # 1 LONG, -1 SHORT, 0 NEUTRAL
                'confidence': float,  # weighted score [0..1]
                'tradeable': bool,
                'agreement': bool,    # True if all active models concur
                'components': dict,   # per-model individual results
            }
        """
        components: Dict[str, Dict] = {}
        # Neutral fallback result; 'components' aliases the dict above
        neutral_result = {
            'signal': 0, 'confidence': 0.0, 'tradeable': False,
            'agreement': False, 'components': components,
        }

        # 1. Collect predictions from attached, trained, positively-weighted
        #    models; a failing model is logged and skipped, never fatal.
        for name, model in self._models.items():
            if not model.is_trained:
                logger.debug(f"Ensemble : modèle '{name}' non entraîné, ignoré")
                continue
            if self.weights.get(name, 0.0) <= 0.0:
                logger.debug(f"Ensemble : modèle '{name}' poids=0, ignoré")
                continue
            try:
                result = model.predict(df)
                components[name] = {
                    'signal': result.get('signal', 0),
                    'confidence': result.get('confidence', 0.0),
                }
            except Exception as e:
                logger.warning(f"Ensemble : erreur predict '{name}' — {e}")
                continue

        if not components:
            logger.debug("Ensemble : aucun modèle actif n'a produit de prédiction")
            neutral_result['components'] = components
            return neutral_result

        # 2. Keep only non-neutral (directional) votes
        directional = {
            k: v for k, v in components.items() if v['signal'] != 0
        }

        if not directional:
            # All active models are neutral — agreement is trivially True
            return {
                'signal': 0, 'confidence': 0.0, 'tradeable': False,
                'agreement': True, 'components': components,
            }

        # 3. Check directional agreement (single unique direction)
        directions = set(v['signal'] for v in directional.values())
        agreement = len(directions) == 1

        if self.require_agreement and not agreement:
            logger.debug(
                f"Ensemble : désaccord entre modèles — {directional}"
            )
            return {
                'signal': 0, 'confidence': 0.0, 'tradeable': False,
                'agreement': False, 'components': components,
            }

        # 4. Require at least 2 directional models for a signal
        if len(directional) < 2:
            logger.debug("Ensemble : un seul modèle directionnel, signal insuffisant")
            return {
                'signal': 0, 'confidence': 0.0, 'tradeable': False,
                'agreement': True, 'components': components,
            }

        # 5. Weighted score, normalized over the active directional models
        consensus_dir = directions.pop()  # the single agreed direction
        total_weight = sum(self.weights.get(k, 0.0) for k in directional)

        if total_weight <= 0:
            # Defensive: all directional models carry zero weight
            return {
                'signal': 0, 'confidence': 0.0, 'tradeable': False,
                'agreement': agreement, 'components': components,
            }

        weighted_score = sum(
            self.weights.get(k, 0.0) * v['confidence']
            for k, v in directional.items()
        ) / total_weight

        # 6. Final decision against the confidence threshold
        tradeable = weighted_score >= self.min_confidence

        logger.info(
            f"Ensemble : direction={'LONG' if consensus_dir == 1 else 'SHORT'} | "
            f"score={weighted_score:.2%} | accord={agreement} | tradeable={tradeable}"
        )

        return {
            'signal': consensus_dir,
            'confidence': weighted_score,
            'tradeable': tradeable,
            'agreement': agreement,
            'components': components,
        }
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Statut et configuration
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
def is_ready(self) -> bool:
|
||||||
|
"""True si au moins 2 modèles parmi {xgboost, cnn, cnn_image} sont attachés et entraînés."""
|
||||||
|
eligible_names = {'xgboost', 'cnn', 'cnn_image'}
|
||||||
|
trained = sum(
|
||||||
|
1 for name, model in self._models.items()
|
||||||
|
if name in eligible_names
|
||||||
|
and model.is_trained
|
||||||
|
and self.weights.get(name, 0.0) > 0
|
||||||
|
)
|
||||||
|
return trained >= 2
|
||||||
|
|
||||||
|
def get_status(self) -> Dict:
|
||||||
|
"""Statut de chaque composant + poids actifs."""
|
||||||
|
status = {
|
||||||
|
'ready': self.is_ready(),
|
||||||
|
'min_confidence': self.min_confidence,
|
||||||
|
'require_agreement': self.require_agreement,
|
||||||
|
'weights': dict(self.weights),
|
||||||
|
'models': {},
|
||||||
|
}
|
||||||
|
for name, model in self._models.items():
|
||||||
|
status['models'][name] = {
|
||||||
|
'attached': True,
|
||||||
|
'is_trained': model.is_trained,
|
||||||
|
'weight': self.weights.get(name, 0.0),
|
||||||
|
}
|
||||||
|
# Modèles non attachés mais présents dans les poids
|
||||||
|
for name in self.weights:
|
||||||
|
if name not in status['models']:
|
||||||
|
status['models'][name] = {
|
||||||
|
'attached': False,
|
||||||
|
'is_trained': False,
|
||||||
|
'weight': self.weights[name],
|
||||||
|
}
|
||||||
|
return status
|
||||||
|
|
||||||
|
def update_weights(self, weights: Dict[str, float]) -> None:
|
||||||
|
"""
|
||||||
|
Mise à jour dynamique des poids.
|
||||||
|
|
||||||
|
Si la somme != 1.0, normalise automatiquement et log un warning.
|
||||||
|
"""
|
||||||
|
total = sum(weights.values())
|
||||||
|
if total <= 0:
|
||||||
|
raise ValueError("La somme des poids doit être > 0")
|
||||||
|
|
||||||
|
if abs(total - 1.0) > 1e-6:
|
||||||
|
logger.warning(
|
||||||
|
f"Somme des poids = {total:.4f} != 1.0 — normalisation automatique"
|
||||||
|
)
|
||||||
|
weights = {k: v / total for k, v in weights.items()}
|
||||||
|
|
||||||
|
self.weights.update(weights)
|
||||||
|
logger.info(f"Poids mis à jour : {self.weights}")
|
||||||
@@ -156,19 +156,23 @@ class MLStrategyModel:
|
|||||||
logger.info(f" {len(X)} échantillons, {len(self.feature_names)} features")
|
logger.info(f" {len(X)} échantillons, {len(self.feature_names)} features")
|
||||||
logger.info(f" Distribution : LONG={( y==1).sum()}, SHORT={(y==-1).sum()}, NEUTRAL={(y==0).sum()}")
|
logger.info(f" Distribution : LONG={( y==1).sum()}, SHORT={(y==-1).sum()}, NEUTRAL={(y==0).sum()}")
|
||||||
|
|
||||||
|
# Encodage labels [-1,0,1] → [0,1,2] (requis par XGBoost ≥ 2.x)
|
||||||
|
y_enc = y + 1
|
||||||
|
|
||||||
# 3. Walk-forward cross-validation (3 folds temporels)
|
# 3. Walk-forward cross-validation (3 folds temporels)
|
||||||
wf_metrics = self._walk_forward_eval(X, y, n_splits=3)
|
wf_metrics = self._walk_forward_eval(X, y_enc, n_splits=3)
|
||||||
|
|
||||||
# 4. Entraînement sur la totalité des données
|
# 4. Entraînement sur la totalité des données
|
||||||
self.scaler = StandardScaler()
|
self.scaler = StandardScaler()
|
||||||
X_scaled = self.scaler.fit_transform(X)
|
X_scaled = self.scaler.fit_transform(X)
|
||||||
|
|
||||||
self.model = self._build_model()
|
self.model = self._build_model()
|
||||||
self.model.fit(X_scaled, y)
|
self.model.fit(X_scaled, y_enc)
|
||||||
self.is_trained = True
|
self.is_trained = True
|
||||||
|
|
||||||
# 5. Évaluation finale (in-sample — indicative)
|
# 5. Évaluation finale (in-sample — indicative)
|
||||||
y_pred = self.model.predict(X_scaled)
|
y_pred_enc = self.model.predict(X_scaled)
|
||||||
|
y_pred = y_pred_enc - 1 # décodage [0,1,2] → [-1,0,1]
|
||||||
report = classification_report(y, y_pred, labels=[-1, 0, 1],
|
report = classification_report(y, y_pred, labels=[-1, 0, 1],
|
||||||
target_names=['SHORT', 'NEUTRAL', 'LONG'],
|
target_names=['SHORT', 'NEUTRAL', 'LONG'],
|
||||||
output_dict=True, zero_division=0)
|
output_dict=True, zero_division=0)
|
||||||
@@ -235,24 +239,25 @@ class MLStrategyModel:
|
|||||||
last = last[self.feature_names].fillna(0)
|
last = last[self.feature_names].fillna(0)
|
||||||
|
|
||||||
X_scaled = self.scaler.transform(last)
|
X_scaled = self.scaler.transform(last)
|
||||||
pred = self.model.predict(X_scaled)[0]
|
pred_enc = self.model.predict(X_scaled)[0]
|
||||||
|
pred = int(pred_enc) - 1 # décodage [0,1,2] → [-1,0,1]
|
||||||
|
|
||||||
# Probabilités si disponibles
|
# Probabilités si disponibles
|
||||||
probas = {'long': 0.0, 'short': 0.0, 'neutral': 1.0}
|
probas = {'long': 0.0, 'short': 0.0, 'neutral': 1.0}
|
||||||
confidence = 0.0
|
confidence = 0.0
|
||||||
if hasattr(self.model, 'predict_proba'):
|
if hasattr(self.model, 'predict_proba'):
|
||||||
proba_arr = self.model.predict_proba(X_scaled)[0]
|
proba_arr = self.model.predict_proba(X_scaled)[0]
|
||||||
classes = list(self.model.classes_)
|
classes = list(self.model.classes_) # [0, 1, 2] encodés
|
||||||
prob_map = {c: p for c, p in zip(classes, proba_arr)}
|
prob_map = {c: p for c, p in zip(classes, proba_arr)}
|
||||||
probas = {
|
probas = {
|
||||||
'long': float(prob_map.get(1, 0.0)),
|
'long': float(prob_map.get(2, 0.0)), # encodé 2 = LONG (1)
|
||||||
'short': float(prob_map.get(-1, 0.0)),
|
'short': float(prob_map.get(0, 0.0)), # encodé 0 = SHORT (-1)
|
||||||
'neutral': float(prob_map.get(0, 1.0)),
|
'neutral': float(prob_map.get(1, 1.0)), # encodé 1 = NEUTRAL (0)
|
||||||
}
|
}
|
||||||
confidence = float(max(probas['long'], probas['short']))
|
confidence = float(max(probas['long'], probas['short']))
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'signal': int(pred),
|
'signal': pred,
|
||||||
'confidence': confidence,
|
'confidence': confidence,
|
||||||
'probas': probas,
|
'probas': probas,
|
||||||
'tradeable': confidence >= self.min_confidence and pred != 0,
|
'tradeable': confidence >= self.min_confidence and pred != 0,
|
||||||
@@ -383,7 +388,8 @@ class MLStrategyModel:
|
|||||||
|
|
||||||
acc = (y_pred == y_te.values).mean()
|
acc = (y_pred == y_te.values).mean()
|
||||||
# Précision/Recall sur les signaux directionnels uniquement
|
# Précision/Recall sur les signaux directionnels uniquement
|
||||||
mask = (y_te != 0) | (y_pred != 0)
|
# y encodé : 0=SHORT, 1=NEUTRAL, 2=LONG → NEUTRAL=1
|
||||||
|
mask = (y_te != 1) | (y_pred != 1)
|
||||||
prec, rec, _, _ = precision_recall_fscore_support(
|
prec, rec, _, _ = precision_recall_fscore_support(
|
||||||
y_te[mask], y_pred[mask], average='macro', zero_division=0
|
y_te[mask], y_pred[mask], average='macro', zero_division=0
|
||||||
) if mask.sum() > 0 else (0, 0, 0, 0)
|
) if mask.sum() > 0 else (0, 0, 0, 0)
|
||||||
|
|||||||
@@ -11,10 +11,12 @@ les différents régimes de marché:
|
|||||||
Permet d'adapter les stratégies selon le régime actuel.
|
Permet d'adapter les stratégies selon le régime actuel.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datetime import datetime
|
from datetime import datetime, timedelta
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -24,8 +26,18 @@ except ImportError:
|
|||||||
HMMLEARN_AVAILABLE = False
|
HMMLEARN_AVAILABLE = False
|
||||||
logging.warning("hmmlearn not installed. Install with: pip install hmmlearn")
|
logging.warning("hmmlearn not installed. Install with: pip install hmmlearn")
|
||||||
|
|
||||||
|
try:
|
||||||
|
import joblib
|
||||||
|
JOBLIB_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
JOBLIB_AVAILABLE = False
|
||||||
|
logging.warning("joblib not installed. La persistance HMM sera désactivée.")
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Répertoire de persistance des modèles HMM
|
||||||
|
MODELS_DIR = Path(__file__).parent.parent.parent / "models" / "hmm"
|
||||||
|
|
||||||
|
|
||||||
class RegimeDetector:
|
class RegimeDetector:
|
||||||
"""
|
"""
|
||||||
@@ -79,6 +91,9 @@ class RegimeDetector:
|
|||||||
|
|
||||||
self.is_fitted = False
|
self.is_fitted = False
|
||||||
self.feature_names = []
|
self.feature_names = []
|
||||||
|
# Métadonnées d'entraînement pour la persistance
|
||||||
|
self._trained_at: Optional[datetime] = None
|
||||||
|
self._n_samples: int = 0
|
||||||
|
|
||||||
logger.info(f"RegimeDetector initialized with {n_regimes} regimes")
|
logger.info(f"RegimeDetector initialized with {n_regimes} regimes")
|
||||||
|
|
||||||
@@ -115,11 +130,147 @@ class RegimeDetector:
|
|||||||
try:
|
try:
|
||||||
self.model.fit(X)
|
self.model.fit(X)
|
||||||
self.is_fitted = True
|
self.is_fitted = True
|
||||||
logger.info("✅ HMM model fitted successfully")
|
self._trained_at = datetime.now()
|
||||||
|
self._n_samples = len(X)
|
||||||
|
logger.info("Modèle HMM entraîné avec succès")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fitting HMM: {e}")
|
logger.error(f"Erreur lors de l'entraînement HMM : {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_trained(self) -> bool:
|
||||||
|
"""True si le modèle HMM a été ajusté (fit)."""
|
||||||
|
return self.is_fitted
|
||||||
|
|
||||||
|
def needs_retrain(self, max_age_hours: int = 24) -> bool:
|
||||||
|
"""
|
||||||
|
Indique si le modèle doit être ré-entraîné.
|
||||||
|
|
||||||
|
Un ré-entraînement est nécessaire si :
|
||||||
|
- Le modèle n'a jamais été entraîné
|
||||||
|
- La date d'entraînement est inconnue
|
||||||
|
- Le modèle est plus vieux que max_age_hours
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_age_hours: Âge maximum du modèle en heures (défaut 24h)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True si un ré-entraînement est nécessaire
|
||||||
|
"""
|
||||||
|
if not self.is_fitted or self._trained_at is None:
|
||||||
|
return True
|
||||||
|
age = datetime.now() - self._trained_at
|
||||||
|
return age > timedelta(hours=max_age_hours)
|
||||||
|
|
||||||
|
def save(self, symbol: str, timeframe: str) -> bool:
|
||||||
|
"""
|
||||||
|
Sauvegarde le modèle HMM entraîné sur disque avec joblib.
|
||||||
|
|
||||||
|
Le modèle est sauvegardé dans :
|
||||||
|
models/hmm/{symbol}_{timeframe}.joblib
|
||||||
|
|
||||||
|
Les métadonnées (date, n_samples, n_components, labels) sont
|
||||||
|
stockées dans un fichier JSON compagnon :
|
||||||
|
models/hmm/{symbol}_{timeframe}_meta.json
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Symbole de l'instrument (ex : "EURUSD")
|
||||||
|
timeframe: Unité de temps (ex : "1h")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True si la sauvegarde a réussi, False sinon
|
||||||
|
"""
|
||||||
|
if not self.is_fitted:
|
||||||
|
logger.warning("Impossible de sauvegarder : modèle non entraîné")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not JOBLIB_AVAILABLE:
|
||||||
|
logger.warning("joblib indisponible — sauvegarde HMM ignorée")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Créer le répertoire si nécessaire
|
||||||
|
MODELS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
base = f"{symbol}_{timeframe}"
|
||||||
|
model_path = MODELS_DIR / f"{base}.joblib"
|
||||||
|
meta_path = MODELS_DIR / f"{base}_meta.json"
|
||||||
|
|
||||||
|
# Sauvegarder le modèle HMM + feature_names
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"feature_names": self.feature_names,
|
||||||
|
"n_regimes": self.n_regimes,
|
||||||
|
"random_state": self.random_state,
|
||||||
|
}
|
||||||
|
joblib.dump(payload, model_path)
|
||||||
|
|
||||||
|
# Sauvegarder les métadonnées en JSON
|
||||||
|
meta = {
|
||||||
|
"trained_at": self._trained_at.isoformat() if self._trained_at else None,
|
||||||
|
"n_samples": self._n_samples,
|
||||||
|
"n_components": self.n_regimes,
|
||||||
|
"regime_labels": self.REGIME_NAMES,
|
||||||
|
"symbol": symbol,
|
||||||
|
"timeframe": timeframe,
|
||||||
|
}
|
||||||
|
meta_path.write_text(json.dumps(meta, indent=2, default=str))
|
||||||
|
|
||||||
|
logger.info(f"Modèle HMM sauvegardé : {model_path}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Erreur lors de la sauvegarde du modèle HMM : {exc}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def load(self, symbol: str, timeframe: str) -> bool:
|
||||||
|
"""
|
||||||
|
Charge un modèle HMM depuis le disque.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Symbole de l'instrument (ex : "EURUSD")
|
||||||
|
timeframe: Unité de temps (ex : "1h")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True si le chargement a réussi, False sinon
|
||||||
|
"""
|
||||||
|
if not JOBLIB_AVAILABLE:
|
||||||
|
logger.warning("joblib indisponible — chargement HMM impossible")
|
||||||
|
return False
|
||||||
|
|
||||||
|
base = f"{symbol}_{timeframe}"
|
||||||
|
model_path = MODELS_DIR / f"{base}.joblib"
|
||||||
|
meta_path = MODELS_DIR / f"{base}_meta.json"
|
||||||
|
|
||||||
|
if not model_path.exists():
|
||||||
|
logger.debug(f"Aucun modèle HMM trouvé pour {symbol}/{timeframe}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = joblib.load(model_path)
|
||||||
|
self.model = payload["model"]
|
||||||
|
self.feature_names = payload["feature_names"]
|
||||||
|
self.n_regimes = payload["n_regimes"]
|
||||||
|
self.random_state = payload["random_state"]
|
||||||
|
self.is_fitted = True
|
||||||
|
|
||||||
|
# Charger les métadonnées si disponibles
|
||||||
|
if meta_path.exists():
|
||||||
|
meta = json.loads(meta_path.read_text())
|
||||||
|
trained_at_raw = meta.get("trained_at")
|
||||||
|
self._trained_at = (
|
||||||
|
datetime.fromisoformat(trained_at_raw) if trained_at_raw else None
|
||||||
|
)
|
||||||
|
self._n_samples = meta.get("n_samples", 0)
|
||||||
|
|
||||||
|
logger.info(f"Modèle HMM chargé depuis {model_path} (entraîné le {self._trained_at})")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Erreur lors du chargement du modèle HMM : {exc}")
|
||||||
|
self.is_fitted = False
|
||||||
|
return False
|
||||||
|
|
||||||
def predict_regime(self, data: pd.DataFrame) -> np.ndarray:
|
def predict_regime(self, data: pd.DataFrame) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Prédit les régimes pour toutes les barres.
|
Prédit les régimes pour toutes les barres.
|
||||||
|
|||||||
24
src/ml/rl/__init__.py
Normal file
24
src/ml/rl/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
"""
|
||||||
|
Module RL (Reinforcement Learning) — Agent PPO pour le trading algorithmique.
|
||||||
|
|
||||||
|
Ce module implémente un agent PPO (Proximal Policy Optimization) entraîné
|
||||||
|
par renforcement sur un environnement de trading simulé.
|
||||||
|
|
||||||
|
Composants :
|
||||||
|
TradingEnv — Environnement gymnasium conforme (observation 20 features)
|
||||||
|
PPOModel — Réseau Actor-Critic MLP 256→128→64 + entraînement PPO
|
||||||
|
RLStrategyModel — Interface unifiée (identique à MLStrategyModel)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.ml.rl.trading_env import TradingEnv, GYM_AVAILABLE
|
||||||
|
from src.ml.rl.ppo_model import PPOModel, TORCH_AVAILABLE
|
||||||
|
from src.ml.rl.rl_strategy_model import RLStrategyModel, RL_AVAILABLE
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'TradingEnv',
|
||||||
|
'PPOModel',
|
||||||
|
'RLStrategyModel',
|
||||||
|
'RL_AVAILABLE',
|
||||||
|
'TORCH_AVAILABLE',
|
||||||
|
'GYM_AVAILABLE',
|
||||||
|
]
|
||||||
681
src/ml/rl/ppo_model.py
Normal file
681
src/ml/rl/ppo_model.py
Normal file
@@ -0,0 +1,681 @@
|
|||||||
|
"""
|
||||||
|
PPOModel — Implémentation manuelle de l'algorithme PPO (Proximal Policy Optimization).
|
||||||
|
|
||||||
|
Architecture Actor-Critic :
|
||||||
|
- Réseau partagé : MLP 3 couches (256 → 128 → 64) avec BatchNorm + ReLU
|
||||||
|
- Tête Actor : couche linéaire → logits (n_actions=3)
|
||||||
|
- Tête Critic : couche linéaire → valeur scalaire V(s)
|
||||||
|
|
||||||
|
Hyperparamètres PPO :
|
||||||
|
- clip ε=0.2 — clip ratio de probabilité pour éviter les grandes mises à jour
|
||||||
|
- entropy_coef=0.01 — bonus d'entropie pour exploration
|
||||||
|
- value_coef=0.5 — coefficient de la loss valeur
|
||||||
|
- n_steps=2048 — nombre de transitions collectées avant mise à jour
|
||||||
|
- n_epochs=10 — passes sur les données collectées
|
||||||
|
- batch_size=64 — taille des mini-batchs
|
||||||
|
- gamma=0.99 — facteur d'actualisation
|
||||||
|
- gae_lambda=0.95 — λ pour Generalized Advantage Estimation
|
||||||
|
|
||||||
|
Sauvegarde :
|
||||||
|
- models/rl_strategy/{symbol}_{timeframe}.pt (state_dict + config)
|
||||||
|
- models/rl_strategy/{symbol}_{timeframe}_meta.json
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Répertoire de sauvegarde des modèles RL
|
||||||
|
MODELS_DIR = Path(__file__).parent.parent.parent.parent / "models" / "rl_strategy"
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Import conditionnel PyTorch
|
||||||
|
# ──────────────────────────────────────────────────────────────────────────────
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from torch.distributions import Categorical
|
||||||
|
TORCH_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
torch = None
|
||||||
|
nn = None
|
||||||
|
optim = None
|
||||||
|
Categorical = None
|
||||||
|
TORCH_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Réseau Actor-Critic
|
||||||
|
# ──────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
if TORCH_AVAILABLE:
|
||||||
|
class ActorCriticNetwork(nn.Module):
|
||||||
|
"""
|
||||||
|
Réseau Actor-Critic partagé pour PPO.
|
||||||
|
|
||||||
|
Architecture :
|
||||||
|
- Tronc commun : Linear(obs_dim → 256) → BN → ReLU
|
||||||
|
Linear(256 → 128) → BN → ReLU
|
||||||
|
Linear(128 → 64) → ReLU
|
||||||
|
- Tête Actor : Linear(64 → n_actions)
|
||||||
|
- Tête Critic : Linear(64 → 1)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obs_dim: Dimension de l'espace d'observation
|
||||||
|
n_actions: Nombre d'actions discrètes
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, obs_dim: int = 20, n_actions: int = 3):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Tronc commun
|
||||||
|
self.trunk = nn.Sequential(
|
||||||
|
nn.Linear(obs_dim, 256),
|
||||||
|
nn.BatchNorm1d(256),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(256, 128),
|
||||||
|
nn.BatchNorm1d(128),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(128, 64),
|
||||||
|
nn.ReLU(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Tête Actor : logits pour chaque action
|
||||||
|
self.actor_head = nn.Linear(64, n_actions)
|
||||||
|
|
||||||
|
# Tête Critic : estimation de la valeur d'état V(s)
|
||||||
|
self.critic_head = nn.Linear(64, 1)
|
||||||
|
|
||||||
|
# Initialisation des poids (orthogonale pour stabilité PPO)
|
||||||
|
self._init_weights()
|
||||||
|
|
||||||
|
def _init_weights(self):
|
||||||
|
"""Initialisation orthogonale recommandée pour les réseaux PPO."""
|
||||||
|
for module in self.modules():
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
|
||||||
|
nn.init.constant_(module.bias, 0.0)
|
||||||
|
# Gain plus faible pour les têtes finales
|
||||||
|
nn.init.orthogonal_(self.actor_head.weight, gain=0.01)
|
||||||
|
nn.init.orthogonal_(self.critic_head.weight, gain=1.0)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, x: 'torch.Tensor'
|
||||||
|
) -> Tuple['torch.Tensor', 'torch.Tensor']:
|
||||||
|
"""
|
||||||
|
Calcule les logits actor et la valeur critic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Tensor d'observations (batch_size, obs_dim)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple (logits_actor, value_critic)
|
||||||
|
logits_actor : (batch_size, n_actions)
|
||||||
|
value_critic : (batch_size, 1)
|
||||||
|
"""
|
||||||
|
features = self.trunk(x)
|
||||||
|
logits = self.actor_head(features)
|
||||||
|
value = self.critic_head(features)
|
||||||
|
return logits, value
|
||||||
|
|
||||||
|
def get_action_and_value(
|
||||||
|
self, x: 'torch.Tensor', action: Optional['torch.Tensor'] = None
|
||||||
|
) -> Tuple['torch.Tensor', 'torch.Tensor', 'torch.Tensor', 'torch.Tensor']:
|
||||||
|
"""
|
||||||
|
Retourne l'action, son log-prob, l'entropie et la valeur.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Observations (batch_size, obs_dim)
|
||||||
|
action: Si fourni, calcule le log-prob de cet action existante
|
||||||
|
(pour la phase de mise à jour PPO)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple (action, log_prob, entropy, value)
|
||||||
|
"""
|
||||||
|
logits, value = self.forward(x)
|
||||||
|
dist = Categorical(logits=logits)
|
||||||
|
|
||||||
|
if action is None:
|
||||||
|
action = dist.sample()
|
||||||
|
|
||||||
|
log_prob = dist.log_prob(action)
|
||||||
|
entropy = dist.entropy()
|
||||||
|
|
||||||
|
return action, log_prob, entropy, value.squeeze(-1)
|
||||||
|
|
||||||
|
def get_value(self, x: 'torch.Tensor') -> 'torch.Tensor':
|
||||||
|
"""Retourne uniquement la valeur critic V(s)."""
|
||||||
|
_, value = self.forward(x)
|
||||||
|
return value.squeeze(-1)
|
||||||
|
|
||||||
|
|
||||||
|
class PPOModel:
|
||||||
|
"""
|
||||||
|
Agent PPO (Proximal Policy Optimization) pour le trading.
|
||||||
|
|
||||||
|
Implémentation manuelle sans stable-baselines3, utilisant PyTorch pur.
|
||||||
|
Suit le pseudo-code de Schulman et al. (2017) avec GAE.
|
||||||
|
|
||||||
|
Méthodes publiques :
|
||||||
|
train(env, total_timesteps) — Lance l'entraînement
|
||||||
|
predict(obs) — Retourne l'action et le log-prob
|
||||||
|
save(symbol, timeframe) — Sauvegarde le modèle
|
||||||
|
load(symbol, timeframe) — Charge un modèle existant (classmethod)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obs_dim: Dimension de l'observation (défaut: 20)
|
||||||
|
n_actions: Nombre d'actions discrètes (défaut: 3)
|
||||||
|
lr: Taux d'apprentissage Adam (défaut: 3e-4)
|
||||||
|
n_steps: Transitions par rollout avant mise à jour (défaut: 2048)
|
||||||
|
n_epochs: Passes d'optimisation par rollout (défaut: 10)
|
||||||
|
batch_size: Taille des mini-batchs (défaut: 64)
|
||||||
|
gamma: Facteur d'actualisation (défaut: 0.99)
|
||||||
|
gae_lambda: Lambda GAE (défaut: 0.95)
|
||||||
|
clip_eps: Epsilon de clip PPO (défaut: 0.2)
|
||||||
|
entropy_coef: Coefficient bonus entropie (défaut: 0.01)
|
||||||
|
value_coef: Coefficient loss valeur (défaut: 0.5)
|
||||||
|
max_grad_norm: Norme maximale du gradient (défaut: 0.5)
|
||||||
|
"""
|
||||||
|
|
||||||
|
MODELS_DIR = MODELS_DIR
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
obs_dim: int = 20,
|
||||||
|
n_actions: int = 3,
|
||||||
|
lr: float = 3e-4,
|
||||||
|
n_steps: int = 2048,
|
||||||
|
n_epochs: int = 10,
|
||||||
|
batch_size: int = 64,
|
||||||
|
gamma: float = 0.99,
|
||||||
|
gae_lambda: float = 0.95,
|
||||||
|
clip_eps: float = 0.2,
|
||||||
|
entropy_coef: float = 0.01,
|
||||||
|
value_coef: float = 0.5,
|
||||||
|
max_grad_norm: float = 0.5,
|
||||||
|
):
|
||||||
|
self.obs_dim = obs_dim
|
||||||
|
self.n_actions = n_actions
|
||||||
|
self.lr = lr
|
||||||
|
self.n_steps = n_steps
|
||||||
|
self.n_epochs = n_epochs
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.gamma = gamma
|
||||||
|
self.gae_lambda = gae_lambda
|
||||||
|
self.clip_eps = clip_eps
|
||||||
|
self.entropy_coef = entropy_coef
|
||||||
|
self.value_coef = value_coef
|
||||||
|
self.max_grad_norm = max_grad_norm
|
||||||
|
|
||||||
|
self.network: Optional['ActorCriticNetwork'] = None
|
||||||
|
self.optimizer: Optional['optim.Adam'] = None
|
||||||
|
self.is_trained: bool = False
|
||||||
|
self.metadata: Dict = {}
|
||||||
|
|
||||||
|
self.MODELS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
if TORCH_AVAILABLE:
|
||||||
|
self._init_network()
|
||||||
|
|
||||||
|
def _init_network(self) -> None:
|
||||||
|
"""Initialise le réseau Actor-Critic et l'optimiseur Adam."""
|
||||||
|
self.network = ActorCriticNetwork(self.obs_dim, self.n_actions)
|
||||||
|
self.optimizer = torch.optim.Adam(self.network.parameters(), lr=self.lr, eps=1e-5)
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
|
# Entraînement
|
||||||
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def train(
|
||||||
|
self,
|
||||||
|
env,
|
||||||
|
total_timesteps: int = 100_000,
|
||||||
|
symbol: str = 'EURUSD',
|
||||||
|
timeframe: str = '1h',
|
||||||
|
) -> Dict:
|
||||||
|
"""
|
||||||
|
Entraîne l'agent PPO sur l'environnement de trading fourni.
|
||||||
|
|
||||||
|
Algorithme :
|
||||||
|
Boucle principale :
|
||||||
|
1. Collecte de n_steps transitions (rollout)
|
||||||
|
2. Calcul des avantages GAE
|
||||||
|
3. n_epochs passes d'optimisation sur mini-batchs mélangés
|
||||||
|
Fin :
|
||||||
|
Sauvegarde du modèle et retour des métriques
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: Environnement TradingEnv
|
||||||
|
total_timesteps: Nombre total de transitions à collecter
|
||||||
|
symbol: Symbole pour la sauvegarde (ex: 'EURUSD')
|
||||||
|
timeframe: Timeframe pour la sauvegarde (ex: '1h')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict avec métriques : total_timesteps, n_updates, mean_reward,
|
||||||
|
mean_ep_return, policy_loss, value_loss, entropy_loss
|
||||||
|
"""
|
||||||
|
if not TORCH_AVAILABLE:
|
||||||
|
return {'error': 'PyTorch non disponible — installer torch>=2.0.0'}
|
||||||
|
|
||||||
|
# Limiter les threads CPU pour ne pas saturer l'event loop FastAPI
|
||||||
|
torch.set_num_threads(4)
|
||||||
|
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
self.network.to(device)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Début entraînement PPO {symbol}/{timeframe} — "
|
||||||
|
f"total_timesteps={total_timesteps}, n_steps={self.n_steps}, device={device}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Buffers de rollout ──────────────────────────────────────────────
|
||||||
|
obs_buf = np.zeros((self.n_steps, self.obs_dim), dtype=np.float32)
|
||||||
|
actions_buf = np.zeros((self.n_steps,), dtype=np.int64)
|
||||||
|
rewards_buf = np.zeros((self.n_steps,), dtype=np.float32)
|
||||||
|
dones_buf = np.zeros((self.n_steps,), dtype=np.float32)
|
||||||
|
values_buf = np.zeros((self.n_steps,), dtype=np.float32)
|
||||||
|
logprobs_buf = np.zeros((self.n_steps,), dtype=np.float32)
|
||||||
|
|
||||||
|
# ── Statistiques d'entraînement ─────────────────────────────────────
|
||||||
|
all_episode_returns = []
|
||||||
|
all_policy_losses = []
|
||||||
|
all_value_losses = []
|
||||||
|
all_entropy_losses = []
|
||||||
|
ep_return = 0.0
|
||||||
|
n_updates = 0
|
||||||
|
|
||||||
|
obs, _ = env.reset()
|
||||||
|
done = False
|
||||||
|
|
||||||
|
timestep = 0
|
||||||
|
rollout_idx = 0
|
||||||
|
|
||||||
|
while timestep < total_timesteps:
|
||||||
|
# ── Phase de collecte (rollout) ───────────────────────────────────
|
||||||
|
self.network.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
obs_t = torch.from_numpy(obs).float().unsqueeze(0).to(device)
|
||||||
|
action, log_prob, _, value = self.network.get_action_and_value(obs_t)
|
||||||
|
|
||||||
|
action_np = int(action.cpu().item())
|
||||||
|
log_prob_np = float(log_prob.cpu().item())
|
||||||
|
value_np = float(value.cpu().item())
|
||||||
|
|
||||||
|
next_obs, reward, terminated, truncated, _ = env.step(action_np)
|
||||||
|
done = terminated or truncated
|
||||||
|
|
||||||
|
obs_buf[rollout_idx] = obs
|
||||||
|
actions_buf[rollout_idx] = action_np
|
||||||
|
rewards_buf[rollout_idx] = reward
|
||||||
|
dones_buf[rollout_idx] = float(done)
|
||||||
|
values_buf[rollout_idx] = value_np
|
||||||
|
logprobs_buf[rollout_idx] = log_prob_np
|
||||||
|
|
||||||
|
ep_return += reward
|
||||||
|
obs = next_obs
|
||||||
|
timestep += 1
|
||||||
|
rollout_idx += 1
|
||||||
|
|
||||||
|
if done:
|
||||||
|
all_episode_returns.append(ep_return)
|
||||||
|
ep_return = 0.0
|
||||||
|
obs, _ = env.reset()
|
||||||
|
|
||||||
|
# ── Phase de mise à jour PPO ──────────────────────────────────────
|
||||||
|
if rollout_idx == self.n_steps:
|
||||||
|
# Calcul de la valeur du dernier état (bootstrap)
|
||||||
|
with torch.no_grad():
|
||||||
|
last_obs_t = torch.from_numpy(obs).float().unsqueeze(0).to(device)
|
||||||
|
last_value = self.network.get_value(last_obs_t).cpu().numpy()
|
||||||
|
|
||||||
|
# Calcul des avantages GAE et des retours
|
||||||
|
advantages, returns = self._compute_gae(
|
||||||
|
rewards_buf, values_buf, dones_buf, float(last_value)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Conversion en tensors
|
||||||
|
obs_t = torch.from_numpy(obs_buf).float().to(device)
|
||||||
|
actions_t = torch.from_numpy(actions_buf).long().to(device)
|
||||||
|
logprobs_t = torch.from_numpy(logprobs_buf).float().to(device)
|
||||||
|
advs_t = torch.from_numpy(advantages).float().to(device)
|
||||||
|
returns_t = torch.from_numpy(returns).float().to(device)
|
||||||
|
|
||||||
|
# Normalisation des avantages
|
||||||
|
advs_t = (advs_t - advs_t.mean()) / (advs_t.std() + 1e-8)
|
||||||
|
|
||||||
|
# n_epochs passes d'optimisation
|
||||||
|
n_samples = self.n_steps
|
||||||
|
policy_losses_ep = []
|
||||||
|
value_losses_ep = []
|
||||||
|
entropy_losses_ep = []
|
||||||
|
|
||||||
|
self.network.train()
|
||||||
|
for _ in range(self.n_epochs):
|
||||||
|
indices = np.random.permutation(n_samples)
|
||||||
|
|
||||||
|
for start in range(0, n_samples, self.batch_size):
|
||||||
|
end = start + self.batch_size
|
||||||
|
batch = indices[start:end]
|
||||||
|
if len(batch) < 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
_, new_logprob, entropy, new_value = (
|
||||||
|
self.network.get_action_and_value(
|
||||||
|
obs_t[batch], actions_t[batch]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ratio de probabilité
|
||||||
|
log_ratio = new_logprob - logprobs_t[batch]
|
||||||
|
ratio = torch.exp(log_ratio)
|
||||||
|
|
||||||
|
# Clip PPO
|
||||||
|
adv_batch = advs_t[batch]
|
||||||
|
pg_loss1 = -adv_batch * ratio
|
||||||
|
pg_loss2 = -adv_batch * torch.clamp(
|
||||||
|
ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps
|
||||||
|
)
|
||||||
|
policy_loss = torch.max(pg_loss1, pg_loss2).mean()
|
||||||
|
|
||||||
|
# Loss valeur (avec clip optionnel)
|
||||||
|
value_loss = nn.functional.mse_loss(new_value, returns_t[batch])
|
||||||
|
|
||||||
|
# Bonus d'entropie (encourage l'exploration)
|
||||||
|
entropy_loss = -entropy.mean()
|
||||||
|
|
||||||
|
# Loss totale
|
||||||
|
total_loss = (
|
||||||
|
policy_loss
|
||||||
|
+ self.value_coef * value_loss
|
||||||
|
+ self.entropy_coef * entropy_loss
|
||||||
|
)
|
||||||
|
|
||||||
|
# Optimisation
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
total_loss.backward()
|
||||||
|
nn.utils.clip_grad_norm_(
|
||||||
|
self.network.parameters(), self.max_grad_norm
|
||||||
|
)
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
|
policy_losses_ep.append(float(policy_loss.item()))
|
||||||
|
value_losses_ep.append(float(value_loss.item()))
|
||||||
|
entropy_losses_ep.append(float(entropy_loss.item()))
|
||||||
|
|
||||||
|
all_policy_losses.extend(policy_losses_ep)
|
||||||
|
all_value_losses.extend(value_losses_ep)
|
||||||
|
all_entropy_losses.extend(entropy_losses_ep)
|
||||||
|
n_updates += 1
|
||||||
|
rollout_idx = 0
|
||||||
|
|
||||||
|
if n_updates % 10 == 0:
|
||||||
|
mean_ret = float(np.mean(all_episode_returns[-20:])) if all_episode_returns else 0.0
|
||||||
|
logger.info(
|
||||||
|
f" Timestep {timestep}/{total_timesteps} | "
|
||||||
|
f"Updates={n_updates} | MeanReturn(20ep)={mean_ret:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.network.eval()
|
||||||
|
self.network.to('cpu')
|
||||||
|
self.is_trained = True
|
||||||
|
|
||||||
|
metrics = {
|
||||||
|
'symbol': symbol,
|
||||||
|
'timeframe': timeframe,
|
||||||
|
'total_timesteps': total_timesteps,
|
||||||
|
'n_updates': n_updates,
|
||||||
|
'mean_reward': float(np.mean(all_episode_returns)) if all_episode_returns else 0.0,
|
||||||
|
'mean_ep_return': float(np.mean(all_episode_returns[-20:])) if all_episode_returns else 0.0,
|
||||||
|
'n_episodes': len(all_episode_returns),
|
||||||
|
'policy_loss': float(np.mean(all_policy_losses[-100:])) if all_policy_losses else 0.0,
|
||||||
|
'value_loss': float(np.mean(all_value_losses[-100:])) if all_value_losses else 0.0,
|
||||||
|
'entropy_loss': float(np.mean(all_entropy_losses[-100:])) if all_entropy_losses else 0.0,
|
||||||
|
'trained_at': datetime.utcnow().isoformat(),
|
||||||
|
'hyperparams': {
|
||||||
|
'lr': self.lr,
|
||||||
|
'n_steps': self.n_steps,
|
||||||
|
'n_epochs': self.n_epochs,
|
||||||
|
'batch_size': self.batch_size,
|
||||||
|
'gamma': self.gamma,
|
||||||
|
'gae_lambda': self.gae_lambda,
|
||||||
|
'clip_eps': self.clip_eps,
|
||||||
|
'entropy_coef': self.entropy_coef,
|
||||||
|
'value_coef': self.value_coef,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
self.metadata = metrics
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Entraînement PPO terminé. "
|
||||||
|
f"MeanReturn={metrics['mean_ep_return']:.4f} | "
|
||||||
|
f"Episodes={metrics['n_episodes']}"
|
||||||
|
)
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
def predict(self, obs: np.ndarray) -> Tuple[int, float]:
|
||||||
|
"""
|
||||||
|
Prédit l'action optimale pour une observation donnée.
|
||||||
|
|
||||||
|
En mode inférence (deterministic=True) : argmax des logits.
|
||||||
|
En mode exploration : sampling de la distribution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obs: Vecteur d'observation (obs_dim,) en float32
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple (action, log_prob) où action ∈ {0, 1, 2}
|
||||||
|
"""
|
||||||
|
if not TORCH_AVAILABLE or not self.is_trained or self.network is None:
|
||||||
|
return 0, 0.0
|
||||||
|
|
||||||
|
self.network.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
obs_t = torch.from_numpy(obs.astype(np.float32)).unsqueeze(0)
|
||||||
|
action, log_prob, _, _ = self.network.get_action_and_value(obs_t)
|
||||||
|
|
||||||
|
return int(action.item()), float(log_prob.item())
|
||||||
|
|
||||||
|
def predict_deterministic(self, obs: np.ndarray) -> int:
|
||||||
|
"""
|
||||||
|
Prédit l'action déterministe (argmax) sans exploration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obs: Vecteur d'observation (obs_dim,) en float32
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Action déterministe (argmax des logits) ∈ {0, 1, 2}
|
||||||
|
"""
|
||||||
|
if not TORCH_AVAILABLE or not self.is_trained or self.network is None:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
self.network.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
obs_t = torch.from_numpy(obs.astype(np.float32)).unsqueeze(0)
|
||||||
|
logits, _ = self.network(obs_t)
|
||||||
|
action = int(logits.argmax(dim=-1).item())
|
||||||
|
|
||||||
|
return action
|
||||||
|
|
||||||
|
def get_action_probas(self, obs: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Retourne les probabilités de chaque action via softmax des logits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obs: Vecteur d'observation (obs_dim,) en float32
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray de forme (3,) : [P(HOLD), P(LONG), P(SHORT)]
|
||||||
|
"""
|
||||||
|
if not TORCH_AVAILABLE or not self.is_trained or self.network is None:
|
||||||
|
return np.array([1.0, 0.0, 0.0], dtype=np.float32)
|
||||||
|
|
||||||
|
self.network.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
obs_t = torch.from_numpy(obs.astype(np.float32)).unsqueeze(0)
|
||||||
|
logits, _ = self.network(obs_t)
|
||||||
|
probas = torch.softmax(logits, dim=-1).squeeze(0).numpy()
|
||||||
|
|
||||||
|
return probas.astype(np.float32)
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
|
# Sauvegarde / Chargement
|
||||||
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
    def save(self, symbol: str = 'EURUSD', timeframe: str = '1h') -> Path:
        """Persist the trained network and its metadata to disk.

        File layout:
            {symbol}_{timeframe}.pt        — PyTorch state_dict + config
            {symbol}_{timeframe}_meta.json — JSON metadata

        Args:
            symbol: Traded pair (e.g. 'EURUSD')
            timeframe: Timeframe (e.g. '1h')

        Returns:
            Path to the saved .pt file

        Raises:
            RuntimeError: if PyTorch is unavailable or the model is untrained
        """
        if not TORCH_AVAILABLE:
            raise RuntimeError("PyTorch non disponible")
        if not self.is_trained or self.network is None:
            raise RuntimeError("Modèle non entraîné — appeler train() avant save()")

        model_id = f"{symbol}_{timeframe}"
        model_path = self.MODELS_DIR / f"{model_id}.pt"
        meta_path = self.MODELS_DIR / f"{model_id}_meta.json"

        # The checkpoint bundles weights AND hyperparameters so load() can
        # rebuild an identical network without any external configuration.
        torch.save(
            {
                'state_dict': self.network.state_dict(),
                'config': {
                    'obs_dim': self.obs_dim,
                    'n_actions': self.n_actions,
                    'lr': self.lr,
                    'n_steps': self.n_steps,
                    'n_epochs': self.n_epochs,
                    'batch_size': self.batch_size,
                    'gamma': self.gamma,
                    'gae_lambda': self.gae_lambda,
                    'clip_eps': self.clip_eps,
                    'entropy_coef': self.entropy_coef,
                    'value_coef': self.value_coef,
                    'max_grad_norm': self.max_grad_norm,
                },
                'metadata': self.metadata,
            },
            model_path
        )

        # Human-readable metadata next to the binary checkpoint.
        # default=str makes non-JSON types (datetime, Path, ...) serializable.
        with open(meta_path, 'w') as f:
            json.dump(self.metadata, f, indent=2, default=str)

        logger.info(f"Modèle PPO sauvegardé : {model_path}")
        return model_path
|
||||||
|
    @classmethod
    def load(cls, symbol: str, timeframe: str) -> 'PPOModel':
        """Load an existing PPO model from disk.

        Args:
            symbol: Pair (e.g. 'EURUSD')
            timeframe: Timeframe (e.g. '1h')

        Returns:
            A PPOModel instance ready for prediction

        Raises:
            RuntimeError: if PyTorch is unavailable
            FileNotFoundError: if the checkpoint does not exist
        """
        if not TORCH_AVAILABLE:
            raise RuntimeError("PyTorch non disponible")

        model_path = MODELS_DIR / f"{symbol}_{timeframe}.pt"
        if not model_path.exists():
            raise FileNotFoundError(f"Modèle PPO non trouvé : {model_path}")

        # map_location='cpu': inference always runs on CPU regardless of
        # the device used during training.
        checkpoint = torch.load(model_path, map_location='cpu')
        cfg = checkpoint.get('config', {})

        # Rebuild the network with the exact hyperparameters it was saved
        # with; the .get() defaults keep older checkpoints loadable.
        instance = cls(
            obs_dim = cfg.get('obs_dim', 20),
            n_actions = cfg.get('n_actions', 3),
            lr = cfg.get('lr', 3e-4),
            n_steps = cfg.get('n_steps', 2048),
            n_epochs = cfg.get('n_epochs', 10),
            batch_size = cfg.get('batch_size', 64),
            gamma = cfg.get('gamma', 0.99),
            gae_lambda = cfg.get('gae_lambda', 0.95),
            clip_eps = cfg.get('clip_eps', 0.2),
            entropy_coef = cfg.get('entropy_coef', 0.01),
            value_coef = cfg.get('value_coef', 0.5),
            max_grad_norm = cfg.get('max_grad_norm', 0.5),
        )

        instance.network.load_state_dict(checkpoint['state_dict'])
        instance.network.eval()
        instance.is_trained = True
        instance.metadata = checkpoint.get('metadata', {})

        logger.info(f"Modèle PPO chargé depuis {model_path}")
        return instance
||||||
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
|
# Calcul des avantages (GAE)
|
||||||
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
    def _compute_gae(
        self,
        rewards: np.ndarray,
        values: np.ndarray,
        dones: np.ndarray,
        last_value: float,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute Generalized Advantage Estimation (GAE) and target returns.

        GAE recursion (processed backwards over the rollout):
            δt = rt + γ·V(st+1)·(1-dt) - V(st)
            Ât = δt + γλ·Ât+1·(1-dt)
            Rt = Ât + V(st)

        Args:
            rewards: Rollout rewards (n_steps,)
            values: Critic value estimates (n_steps,)
            dones: Episode-termination flags (n_steps,)
            last_value: Bootstrap value of the state after the rollout (scalar)

        Returns:
            Tuple (advantages, returns), each of shape (n_steps,)
        """
        n = len(rewards)
        advantages = np.zeros(n, dtype=np.float32)
        last_gae = 0.0

        # Backward pass: each advantage depends on the next step's advantage.
        for t in reversed(range(n)):
            if t == n - 1:
                next_non_terminal = 1.0 - dones[t]
                next_value = last_value  # bootstrap beyond the rollout end
            else:
                next_non_terminal = 1.0 - dones[t]
                next_value = values[t + 1]

            # TD error; next_non_terminal zeroes the bootstrap across
            # episode boundaries so no credit leaks between episodes.
            delta = rewards[t] + self.gamma * next_value * next_non_terminal - values[t]
            last_gae = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae
            advantages[t] = last_gae

        returns = advantages + values
        return advantages, returns
||||||
557
src/ml/rl/rl_strategy_model.py
Normal file
557
src/ml/rl/rl_strategy_model.py
Normal file
@@ -0,0 +1,557 @@
|
|||||||
|
"""
|
||||||
|
RLStrategyModel — Interface unifiée pour l'agent PPO de trading.
|
||||||
|
|
||||||
|
Interface identique à MLStrategyModel et CNNImageStrategyModel :
|
||||||
|
model = RLStrategyModel(symbol='EURUSD', timeframe='1h')
|
||||||
|
result = model.train(df_ohlcv)
|
||||||
|
signal = model.predict(df) # {signal, confidence, probas, tradeable}
|
||||||
|
model.save()
|
||||||
|
model.load(symbol, timeframe)
|
||||||
|
model.list_trained_models()
|
||||||
|
model.get_feature_importance() # retourne [] (RL = boîte noire)
|
||||||
|
|
||||||
|
Pipeline d'entraînement :
|
||||||
|
1. Validation des données OHLCV (minimum 500 barres recommandées)
|
||||||
|
2. Création de TradingEnv sur les données d'entraînement
|
||||||
|
3. Entraînement PPO (total_timesteps configurable)
|
||||||
|
4. Évaluation sur holdout (20% temporel final)
|
||||||
|
5. Sauvegarde du modèle PPO + métadonnées JSON
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from src.ml.rl.trading_env import TradingEnv, GYM_AVAILABLE, N_FEATURES
|
||||||
|
from src.ml.rl.ppo_model import PPOModel, TORCH_AVAILABLE, MODELS_DIR
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Decode RL actions into trading signals:
#   action 0 (HOLD)  → signal  0 (NEUTRAL)
#   action 1 (LONG)  → signal  1 (LONG)
#   action 2 (SHORT) → signal -1 (SHORT)
ACTION_TO_SIGNAL = {0: 0, 1: 1, 2: -1}

# Global availability of the RL stack (requires both torch and gymnasium).
RL_AVAILABLE = TORCH_AVAILABLE and GYM_AVAILABLE
||||||
|
|
||||||
|
|
||||||
|
class RLStrategyModel:
|
||||||
|
"""
|
||||||
|
Modèle de trading basé sur un agent PPO (Reinforcement Learning).
|
||||||
|
|
||||||
|
L'agent apprend directement à maximiser le PnL via interaction avec
|
||||||
|
un environnement de trading simulé (TradingEnv), sans supervision
|
||||||
|
explicite de labels.
|
||||||
|
|
||||||
|
Avantages par rapport aux modèles supervisés :
|
||||||
|
- Pas besoin de labels manuels (LONG/SHORT/NEUTRAL)
|
||||||
|
- Optimise directement l'objectif de trading (PnL)
|
||||||
|
- Apprend à gérer les positions (durée, timing de sortie)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Paire tradée (ex: 'EURUSD')
|
||||||
|
timeframe: Timeframe (ex: '1h', '15m')
|
||||||
|
total_timesteps: Nombre de timesteps d'entraînement PPO (défaut: 100_000)
|
||||||
|
sl_atr_mult: Multiplicateur ATR pour le stop-loss (défaut: 1.0)
|
||||||
|
tp_atr_mult: Multiplicateur ATR pour le take-profit (défaut: 2.0)
|
||||||
|
min_confidence: Seuil de confiance minimum pour signaux (défaut: 0.50)
|
||||||
|
train_ratio: Fraction des données pour l'entraînement (défaut: 0.80)
|
||||||
|
initial_capital: Capital initial pour la simulation (défaut: 10_000)
|
||||||
|
lr: Taux d'apprentissage Adam (défaut: 3e-4)
|
||||||
|
n_steps: Transitions par rollout (défaut: 2048)
|
||||||
|
n_epochs: Passes d'optimisation par rollout (défaut: 10)
|
||||||
|
batch_size: Taille des mini-batchs (défaut: 64)
|
||||||
|
"""
|
||||||
|
|
||||||
|
    # Directory where trained models are persisted (shared with PPOModel).
    MODELS_DIR = MODELS_DIR

    def __init__(
        self,
        symbol: str = 'EURUSD',
        timeframe: str = '1h',
        total_timesteps: int = 100_000,
        sl_atr_mult: float = 1.0,
        tp_atr_mult: float = 2.0,
        min_confidence: float = 0.50,
        train_ratio: float = 0.80,
        initial_capital: float = 10_000.0,
        lr: float = 3e-4,
        n_steps: int = 2048,
        n_epochs: int = 10,
        batch_size: int = 64,
    ):
        # Instrument / environment configuration
        self.symbol = symbol
        self.timeframe = timeframe
        self.total_timesteps = total_timesteps
        self.sl_atr_mult = sl_atr_mult
        self.tp_atr_mult = tp_atr_mult
        self.min_confidence = min_confidence
        self.train_ratio = train_ratio
        self.initial_capital = initial_capital
        # PPO hyperparameters (forwarded to PPOModel in train())
        self.lr = lr
        self.n_steps = n_steps
        self.n_epochs = n_epochs
        self.batch_size = batch_size

        # Populated by train() / load()
        self.ppo_model: Optional[PPOModel] = None
        self.is_trained = False
        self.metadata: Dict = {}

        # Ensure the persistence directory exists up front.
        self.MODELS_DIR.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
|
# Entraînement
|
||||||
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
    def train(self, data: pd.DataFrame) -> Dict:
        """Train the PPO agent on the supplied OHLCV data.

        Steps:
            1. Validate the data (minimum 200 bars)
            2. Chronological train/holdout split (80/20 by default)
            3. Build a TradingEnv on the training slice
            4. Train PPO for total_timesteps
            5. Evaluate on the holdout slice
            6. Auto-save the model

        Args:
            data: OHLCV DataFrame (columns open/high/low/close/volume).
                At least 200 bars; ideally 2000+ for useful learning.

        Returns:
            Dict of metrics:
                - total_timesteps, n_episodes, mean_ep_return (training)
                - eval_total_pnl, eval_return_pct, eval_sharpe (holdout)
                - n_samples, trained_at, error (on failure)
        """
        if not TORCH_AVAILABLE:
            return {'error': 'PyTorch non disponible — installer torch>=2.0.0'}
        if not GYM_AVAILABLE:
            return {'error': 'gymnasium non disponible — installer gymnasium>=0.26'}

        logger.info(
            f"Début entraînement RLStrategyModel pour "
            f"{self.symbol}/{self.timeframe} — total_timesteps={self.total_timesteps}"
        )

        # 1. Validate and normalise the input data (case-insensitive columns).
        df = data.copy()
        df.columns = [c.lower() for c in df.columns]

        required_cols = {'open', 'high', 'low', 'close', 'volume'}
        missing = required_cols - set(df.columns)
        if missing:
            return {'error': f'Colonnes manquantes : {missing}'}

        df = df.dropna(subset=list(required_cols))

        if len(df) < 200:
            return {
                'error': f'Données insuffisantes : {len(df)} barres (minimum 200)'
            }

        # 2. Chronological train/holdout split (no shuffling: time series).
        n_train = int(len(df) * self.train_ratio)
        df_train = df.iloc[:n_train].reset_index(drop=True)
        df_holdout = df.iloc[n_train:].reset_index(drop=True)

        logger.info(
            f" Split : train={len(df_train)} barres, holdout={len(df_holdout)} barres"
        )

        if len(df_train) < 100:
            return {'error': f'Données d\'entraînement insuffisantes : {len(df_train)} barres'}

        # 3. Training environment on the training slice only.
        train_env = TradingEnv(
            df = df_train,
            sl_atr_mult = self.sl_atr_mult,
            tp_atr_mult = self.tp_atr_mult,
            initial_capital = self.initial_capital,
        )

        # 4. PPO initialisation + training.
        self.ppo_model = PPOModel(
            obs_dim = N_FEATURES,
            n_actions = 3,
            lr = self.lr,
            # Safety cap: a rollout cannot exceed the data length.
            n_steps = min(self.n_steps, len(df_train) - 30),
            n_epochs = self.n_epochs,
            batch_size = self.batch_size,
        )

        train_metrics = self.ppo_model.train(
            env = train_env,
            total_timesteps = self.total_timesteps,
            symbol = self.symbol,
            timeframe = self.timeframe,
        )

        if 'error' in train_metrics:
            return train_metrics

        self.is_trained = True

        # 5. Holdout evaluation (skipped when the holdout is too small).
        eval_metrics = {}
        if len(df_holdout) >= 50:
            eval_metrics = self._evaluate_on_holdout(df_holdout)
            logger.info(
                f" Évaluation holdout : "
                f"PnL={eval_metrics.get('eval_total_pnl', 0):.4f}, "
                f"Return={eval_metrics.get('eval_return_pct', 0):.2%}, "
                f"Sharpe≈{eval_metrics.get('eval_sharpe', 0):.2f}"
            )

        # 6. Assemble the metadata exposed to callers and persisted by save().
        self.metadata = {
            'symbol': self.symbol,
            'timeframe': self.timeframe,
            'trained_at': datetime.utcnow().isoformat(),
            'n_samples': len(df),
            'n_train': len(df_train),
            'n_holdout': len(df_holdout),
            'total_timesteps': train_metrics.get('total_timesteps', self.total_timesteps),
            'n_episodes': train_metrics.get('n_episodes', 0),
            'n_updates': train_metrics.get('n_updates', 0),
            'mean_ep_return': train_metrics.get('mean_ep_return', 0.0),
            'policy_loss': train_metrics.get('policy_loss', 0.0),
            'value_loss': train_metrics.get('value_loss', 0.0),
            'entropy_loss': train_metrics.get('entropy_loss', 0.0),
            'sl_atr_mult': self.sl_atr_mult,
            'tp_atr_mult': self.tp_atr_mult,
            'min_confidence': self.min_confidence,
            'hyperparams': train_metrics.get('hyperparams', {}),
            **eval_metrics,
        }

        # 7. Auto-save so a successful training run is never lost.
        self.save()

        logger.info(
            f"Entraînement RL terminé pour {self.symbol}/{self.timeframe} — "
            f"N_episodes={self.metadata['n_episodes']}, "
            f"MeanReturn={self.metadata['mean_ep_return']:.4f}"
        )
        return self.metadata
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
|
# Prédiction
|
||||||
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def predict(self, data: pd.DataFrame) -> Dict:
|
||||||
|
"""
|
||||||
|
Prédit le signal de trading pour la dernière barre disponible.
|
||||||
|
|
||||||
|
L'agent évalue les 20 dernières barres (fenêtre lookback) et retourne
|
||||||
|
son action déterministe (HOLD/LONG/SHORT) avec les probabilités associées.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: DataFrame OHLCV récent (minimum 25 barres pour les indicateurs)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict : {
|
||||||
|
'signal': 1 (LONG) / -1 (SHORT) / 0 (NEUTRAL/HOLD),
|
||||||
|
'confidence': float [0..1] — max(P_LONG, P_SHORT),
|
||||||
|
'probas': {'hold': float, 'long': float, 'short': float},
|
||||||
|
'tradeable': bool — confidence >= min_confidence et signal != 0,
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
if not TORCH_AVAILABLE:
|
||||||
|
return {
|
||||||
|
'signal': 0, 'confidence': 0.0, 'tradeable': False,
|
||||||
|
'error': 'PyTorch non disponible'
|
||||||
|
}
|
||||||
|
if not self.is_trained or self.ppo_model is None:
|
||||||
|
return {
|
||||||
|
'signal': 0, 'confidence': 0.0, 'tradeable': False,
|
||||||
|
'error': 'Modèle non entraîné — appeler train() d\'abord'
|
||||||
|
}
|
||||||
|
|
||||||
|
df = data.copy()
|
||||||
|
df.columns = [c.lower() for c in df.columns]
|
||||||
|
|
||||||
|
if len(df) < 25:
|
||||||
|
return {
|
||||||
|
'signal': 0, 'confidence': 0.0, 'tradeable': False,
|
||||||
|
'error': f'Pas assez de données : {len(df)} barres (minimum 25)'
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Créer un environnement temporaire pour obtenir l'observation
|
||||||
|
temp_env = TradingEnv(
|
||||||
|
df = df,
|
||||||
|
sl_atr_mult = self.sl_atr_mult,
|
||||||
|
tp_atr_mult = self.tp_atr_mult,
|
||||||
|
initial_capital = self.initial_capital,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Positionner l'environnement à la dernière barre
|
||||||
|
temp_env._current_step = len(df) - 1
|
||||||
|
obs = temp_env._get_observation()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
'signal': 0, 'confidence': 0.0, 'tradeable': False,
|
||||||
|
'error': f'Erreur construction observation : {e}'
|
||||||
|
}
|
||||||
|
|
||||||
|
# Prédiction déterministe (argmax des logits)
|
||||||
|
action = self.ppo_model.predict_deterministic(obs)
|
||||||
|
probas = self.ppo_model.get_action_probas(obs) # [P(HOLD), P(LONG), P(SHORT)]
|
||||||
|
|
||||||
|
signal = ACTION_TO_SIGNAL.get(action, 0)
|
||||||
|
confidence = float(max(probas[1], probas[2])) # max(P_LONG, P_SHORT)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'signal': signal,
|
||||||
|
'confidence': confidence,
|
||||||
|
'probas': {
|
||||||
|
'hold': float(probas[0]),
|
||||||
|
'long': float(probas[1]),
|
||||||
|
'short': float(probas[2]),
|
||||||
|
},
|
||||||
|
'tradeable': confidence >= self.min_confidence and signal != 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
|
# Sauvegarde / Chargement
|
||||||
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
    def save(self) -> Path:
        """Persist the PPO model and the wrapper metadata to disk.

        Returns:
            Path to the saved .pt file

        Raises:
            RuntimeError: if the model has not been trained
        """
        if not self.is_trained or self.ppo_model is None:
            raise RuntimeError("Modèle non entraîné — appeler train() avant save()")

        # Save the PPO network itself (it also writes its own metadata
        # file, which is overwritten below with the richer wrapper version).
        model_path = self.ppo_model.save(self.symbol, self.timeframe)

        # Full metadata including the RLStrategyModel parameters, so load()
        # can restore an identically-configured wrapper.
        meta = self.metadata.copy()
        meta.update({
            'rl_config': {
                'sl_atr_mult': self.sl_atr_mult,
                'tp_atr_mult': self.tp_atr_mult,
                'min_confidence': self.min_confidence,
                'train_ratio': self.train_ratio,
                'initial_capital': self.initial_capital,
                'total_timesteps': self.total_timesteps,
            }
        })

        meta_path = self.MODELS_DIR / f"{self.symbol}_{self.timeframe}_meta.json"
        with open(meta_path, 'w') as f:
            # default=str makes datetimes/Paths JSON-serializable.
            json.dump(meta, f, indent=2, default=str)

        return model_path
|
|
||||||
|
    @classmethod
    def load(cls, symbol: str, timeframe: str) -> 'RLStrategyModel':
        """Load an existing RL model from disk.

        Args:
            symbol: Pair (e.g. 'EURUSD')
            timeframe: Timeframe (e.g. '1h')

        Returns:
            An RLStrategyModel instance ready for prediction

        Raises:
            RuntimeError: if PyTorch or gymnasium is unavailable
            FileNotFoundError: if the model files do not exist
        """
        if not TORCH_AVAILABLE:
            raise RuntimeError("PyTorch non disponible")

        # Best-effort restore of the wrapper configuration so the instance
        # matches the parameters it was trained with.
        meta_path = MODELS_DIR / f"{symbol}_{timeframe}_meta.json"
        rl_config = {}
        if meta_path.exists():
            try:
                with open(meta_path) as f:
                    saved_meta = json.load(f)
                    rl_config = saved_meta.get('rl_config', {})
            except Exception:
                pass

        instance = cls(
            symbol = symbol,
            timeframe = timeframe,
            sl_atr_mult = rl_config.get('sl_atr_mult', 1.0),
            tp_atr_mult = rl_config.get('tp_atr_mult', 2.0),
            min_confidence = rl_config.get('min_confidence', 0.50),
            train_ratio = rl_config.get('train_ratio', 0.80),
            initial_capital = rl_config.get('initial_capital', 10_000.0),
            total_timesteps = rl_config.get('total_timesteps', 100_000),
        )

        # Load the underlying PPO network (raises FileNotFoundError if absent).
        instance.ppo_model = PPOModel.load(symbol, timeframe)
        instance.is_trained = True

        # Restore the full saved metadata (best effort, non-fatal on failure).
        if meta_path.exists():
            try:
                with open(meta_path) as f:
                    instance.metadata = json.load(f)
            except Exception:
                pass

        logger.info(f"Modèle RL chargé depuis {MODELS_DIR}/{symbol}_{timeframe}.pt")
        return instance
|
|
||||||
|
    @staticmethod
    def list_trained_models() -> List[Dict]:
        """Return the list of available trained RL models.

        Scans MODELS_DIR for *_meta.json files; unreadable entries are
        skipped silently (best-effort listing).

        Returns:
            List of dicts with symbol, timeframe, trained_at, n_samples,
            mean_ep_return, eval_return_pct
        """
        if not MODELS_DIR.exists():
            return []

        models = []
        for f in MODELS_DIR.glob("*_meta.json"):
            try:
                with open(f) as fp:
                    meta = json.load(fp)
                    # Fall back to the filename when the metadata lacks
                    # symbol/timeframe, e.g. stem 'EURUSD_1h'.
                    stem = f.stem.replace('_meta', '')
                    parts = stem.split('_', 1)
                    models.append({
                        'symbol': meta.get('symbol', parts[0] if parts else '?'),
                        'timeframe': meta.get('timeframe', parts[1] if len(parts) > 1 else '?'),
                        'trained_at': meta.get('trained_at', '?'),
                        'n_samples': meta.get('n_samples', 0),
                        'total_timesteps': meta.get('total_timesteps', 0),
                        'n_episodes': meta.get('n_episodes', 0),
                        'mean_ep_return': meta.get('mean_ep_return', 0.0),
                        'eval_return_pct': meta.get('eval_return_pct', 0.0),
                        'eval_sharpe': meta.get('eval_sharpe', 0.0),
                    })
            except Exception:
                pass

        return models
|
|
||||||
|
def get_feature_importance(self) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Retourne l'importance des features.
|
||||||
|
|
||||||
|
Note : Les agents RL (PPO) sont des boîtes noires — pas d'importance
|
||||||
|
de feature interprétable. Retourne les noms des 20 features pour
|
||||||
|
information, sans score numérique.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Liste de dicts avec feature_name et description
|
||||||
|
"""
|
||||||
|
feature_names = [
|
||||||
|
{'rank': 1, 'feature': 'body_ratio', 'description': '(close-open)/ATR — corps de la bougie'},
|
||||||
|
{'rank': 2, 'feature': 'amplitude', 'description': '(high-low)/ATR — amplitude de la bougie'},
|
||||||
|
{'rank': 3, 'feature': 'momentum_1', 'description': '(close-close[-1])/ATR — momentum 1 barre'},
|
||||||
|
{'rank': 4, 'feature': 'momentum_5', 'description': '(close-close[-5])/ATR — momentum 5 barres'},
|
||||||
|
{'rank': 5, 'feature': 'rsi_norm', 'description': 'RSI(14)/100 — RSI normalisé'},
|
||||||
|
{'rank': 6, 'feature': 'bb_position', 'description': '(close-BB_upper)/ATR — position vs Bollinger'},
|
||||||
|
{'rank': 7, 'feature': 'bb_width', 'description': '(BB_upper-BB_lower)/ATR — largeur bandes'},
|
||||||
|
{'rank': 8, 'feature': 'macd_norm', 'description': 'MACD/ATR — MACD normalisé'},
|
||||||
|
{'rank': 9, 'feature': 'macd_signal_norm', 'description': 'Signal MACD/ATR — signal MACD normalisé'},
|
||||||
|
{'rank': 10, 'feature': 'vol_ratio', 'description': 'Volume/MA20(Volume) — ratio volume'},
|
||||||
|
{'rank': 11, 'feature': 'position', 'description': 'Position courante (-1=short, 0=flat, 1=long)'},
|
||||||
|
{'rank': 12, 'feature': 'unrealized_pnl', 'description': 'PnL_non_réalisé/ATR — gain/perte courant'},
|
||||||
|
{'rank': 13, 'feature': 'bars_in_trade', 'description': 'Durée du trade normalisée'},
|
||||||
|
{'rank': 14, 'feature': 'hour', 'description': 'Heure de la journée normalisée'},
|
||||||
|
{'rank': 15, 'feature': 'day_of_week', 'description': 'Jour de semaine normalisé'},
|
||||||
|
{'rank': 16, 'feature': 'ema9_dist', 'description': '(close-EMA9)/ATR — distance EMA9'},
|
||||||
|
{'rank': 17, 'feature': 'ema21_dist', 'description': '(close-EMA21)/ATR — distance EMA21'},
|
||||||
|
{'rank': 18, 'feature': 'ema9_ema21_cross', 'description': '(EMA9-EMA21)/ATR — croisement EMA court'},
|
||||||
|
{'rank': 19, 'feature': 'ema50_dist', 'description': '(close-EMA50)/ATR — distance EMA50'},
|
||||||
|
{'rank': 20, 'feature': 'ema200_dist', 'description': '(close-EMA200)/ATR — distance EMA200'},
|
||||||
|
]
|
||||||
|
return feature_names
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
|
# Évaluation interne
|
||||||
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
    def _evaluate_on_holdout(self, df_holdout: pd.DataFrame) -> Dict:
        """Evaluate the trained agent on the holdout set.

        Runs one full episode over the holdout data in deterministic mode
        (argmax of the logits, no sampling) and returns performance metrics.

        Args:
            df_holdout: OHLCV DataFrame of the holdout slice

        Returns:
            Dict with eval_total_pnl, eval_return_pct, eval_sharpe,
            eval_n_trades, eval_win_rate (all zeros if evaluation fails)
        """
        try:
            eval_env = TradingEnv(
                df = df_holdout,
                sl_atr_mult = self.sl_atr_mult,
                tp_atr_mult = self.tp_atr_mult,
                initial_capital = self.initial_capital,
            )

            obs, _ = eval_env.reset()
            done = False
            rewards_hist = []
            n_trades = 0
            win_trades = 0

            while not done:
                action = self.ppo_model.predict_deterministic(obs)
                obs, reward, terminated, truncated, info = eval_env.step(action)
                done = terminated or truncated

                # Heuristic: a non-negligible reward marks a position close.
                # NOTE(review): assumes intermediate step rewards are ~0 —
                # confirm against TradingEnv's reward scheme.
                if abs(reward) > 0.001:
                    n_trades += 1
                    if reward > 0:
                        win_trades += 1

                rewards_hist.append(reward)

            stats = eval_env.get_episode_stats()
            win_rate = win_trades / max(n_trades, 1)
            rewards_arr = np.array(rewards_hist)
            # Annualised Sharpe-like ratio over per-step rewards;
            # 1e-8 guards against a zero standard deviation.
            sharpe = (
                float(rewards_arr.mean() / (rewards_arr.std() + 1e-8) * np.sqrt(252))
                if len(rewards_arr) > 1 else 0.0
            )

            return {
                'eval_total_pnl': float(stats.get('total_pnl', 0.0)),
                'eval_return_pct': float(stats.get('return_pct', 0.0)),
                'eval_sharpe': float(sharpe),
                'eval_n_trades': n_trades,
                'eval_win_rate': float(win_rate),
                'eval_n_steps': int(stats.get('n_steps', 0)),
            }

        except Exception as e:
            # Evaluation is best-effort: a failure must not abort training.
            logger.warning(f"Évaluation holdout échouée : {e}")
            return {
                'eval_total_pnl': 0.0,
                'eval_return_pct': 0.0,
                'eval_sharpe': 0.0,
                'eval_n_trades': 0,
                'eval_win_rate': 0.0,
            }
||||||
524
src/ml/rl/trading_env.py
Normal file
524
src/ml/rl/trading_env.py
Normal file
@@ -0,0 +1,524 @@
|
|||||||
|
"""
|
||||||
|
TradingEnv — Environnement Gymnasium pour l'agent RL de trading.
|
||||||
|
|
||||||
|
Conforme à l'API gymnasium (reset/step) et adapté aux marchés forex.
|
||||||
|
|
||||||
|
Espace d'observation : 20 features normalisées (fenêtre glissante de 20 barres)
|
||||||
|
Espace d'action : Discrete(3) → {0=HOLD, 1=LONG, 2=SHORT}
|
||||||
|
Récompense : PnL réalisé + pénalité drawdown + bonus Sharpe
|
||||||
|
|
||||||
|
L'environnement gère :
|
||||||
|
- Les transitions de position (flat → long, flat → short, inversions)
|
||||||
|
- Le stop-loss / take-profit automatiques basés sur l'ATR
|
||||||
|
- La normalisation des observations par l'ATR
|
||||||
|
- La fenêtre glissante de 20 barres (lookback)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Conditional imports: prefer gymnasium, fall back to legacy gym.
# GYM_AVAILABLE / GYM_MODULE let the rest of the module (and TradingEnv's
# constructor) fail fast with a clear error instead of an ImportError.
try:
    import gymnasium as gym
    from gymnasium import spaces
    GYM_AVAILABLE = True
    GYM_MODULE = 'gymnasium'
except ImportError:
    try:
        import gym
        from gym import spaces
        GYM_AVAILABLE = True
        GYM_MODULE = 'gym'
    except ImportError:
        gym = None
        spaces = None
        GYM_AVAILABLE = False
        GYM_MODULE = None


# ──────────────────────────────────────────────────────────────────────────────
# Actions
# ──────────────────────────────────────────────────────────────────────────────
ACTION_HOLD = 0
ACTION_LONG = 1
ACTION_SHORT = 2

# Number of features in the observation vector
N_FEATURES = 20

# Lookback window (number of past bars included in the observation)
LOOKBACK = 20
|
||||||
|
|
||||||
|
|
||||||
|
class TradingEnv:
    """
    Trading environment conforming to the gymnasium API for the PPO agent.

    Each step() represents the trading decision taken at the close of a
    candle. The agent receives 20 normalized features describing the market
    context and its current position, then chooses HOLD / LONG / SHORT.

    Position management:
        - A single position at a time (no pyramiding)
        - Direct reversal allowed (SHORT → LONG without going through FLAT)
        - ATR-based SL/TP (sl_atr_mult×ATR and tp_atr_mult×ATR)

    Reward:
        - Realized PnL at position close (in ATR multiples)
        - Penalty proportional to the current drawdown
        - No bonus for holding a position (avoids overtrading)

    Args:
        df: OHLCV DataFrame (lowercase columns: open/high/low/close/volume)
        sl_atr_mult: ATR multiplier for the stop-loss (default: 1.0)
        tp_atr_mult: ATR multiplier for the take-profit (default: 2.0)
        atr_period: ATR computation period (default: 14)
        initial_capital: Initial capital (for drawdown computation) (default: 10000)
        drawdown_penalty: Drawdown penalty coefficient (default: 0.1)
    """

    metadata = {'render_modes': []}

    def __init__(
        self,
        df: pd.DataFrame,
        sl_atr_mult: float = 1.0,
        tp_atr_mult: float = 2.0,
        atr_period: int = 14,
        initial_capital: float = 10_000.0,
        drawdown_penalty: float = 0.1,
    ):
        if not GYM_AVAILABLE:
            raise RuntimeError(
                "gymnasium ou gym requis — installer gymnasium>=0.26 ou gym>=0.21"
            )

        # NOTE(review): reset_index(drop=True) discards any DatetimeIndex, so
        # _get_hour/_get_dow below will always fall back to 0.5 — confirm the
        # temporal features are intended to be neutral, or preserve the index.
        self.df = df.copy().reset_index(drop=True)
        self.df.columns = [c.lower() for c in self.df.columns]
        self.sl_atr_mult = sl_atr_mult
        self.tp_atr_mult = tp_atr_mult
        self.atr_period = atr_period
        self.initial_capital = initial_capital
        self.drawdown_penalty = drawdown_penalty

        # ATR over the whole series (optimization: avoids recomputation in step)
        self._atr_series = self._compute_atr_series()

        # EMAs over the whole series
        self._ema9 = self.df['close'].ewm(span=9, adjust=False).mean()
        self._ema21 = self.df['close'].ewm(span=21, adjust=False).mean()
        self._ema50 = self.df['close'].ewm(span=50, adjust=False).mean()
        self._ema200 = self.df['close'].ewm(span=200, adjust=False).mean()

        # Bollinger Bands (20, 2σ)
        bb_ma = self.df['close'].rolling(20).mean()
        bb_std = self.df['close'].rolling(20).std()
        self._bb_upper = bb_ma + 2 * bb_std
        self._bb_lower = bb_ma - 2 * bb_std

        # RSI(14)
        self._rsi = self._compute_rsi(period=14)

        # MACD (12, 26, 9)
        ema12 = self.df['close'].ewm(span=12, adjust=False).mean()
        ema26 = self.df['close'].ewm(span=26, adjust=False).mean()
        self._macd = ema12 - ema26
        self._macd_sig = self._macd.ewm(span=9, adjust=False).mean()

        # Volume MA(20) for normalization
        self._vol_ma20 = self.df['volume'].rolling(20).mean().fillna(
            self.df['volume'].mean()
        )

        # gymnasium spaces
        self.observation_space = spaces.Box(
            low = -10.0,
            high = 10.0,
            shape = (N_FEATURES,),
            dtype = np.float32,
        )
        self.action_space = spaces.Discrete(3)

        # Internal state (initialized in reset())
        self._current_step = LOOKBACK
        self._position = 0  # -1=short, 0=flat, 1=long
        self._entry_price = 0.0
        self._entry_atr = 0.0
        self._bars_in_trade = 0
        self._capital = initial_capital
        self._peak_capital = initial_capital
        self._total_pnl = 0.0
        self._pnl_history = []  # For Sharpe computation at end of episode
        self._done = False

    # ──────────────────────────────────────────────────────────────────────────
    # gymnasium interface
    # ──────────────────────────────────────────────────────────────────────────

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
        """
        Reset the environment at the start of an episode.

        Args:
            seed: Random seed (ignored here, the data is deterministic)
            options: Additional options (unused)

        Returns:
            Tuple (observation, info) per the gymnasium API
        """
        if seed is not None:
            np.random.seed(seed)

        self._current_step = LOOKBACK
        self._position = 0
        self._entry_price = 0.0
        self._entry_atr = 0.0
        self._bars_in_trade = 0
        self._capital = self.initial_capital
        self._peak_capital = self.initial_capital
        self._total_pnl = 0.0
        self._pnl_history = []
        self._done = False

        obs = self._get_observation()
        info = {}
        return obs, info

    def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, dict]:
        """
        Execute an action and return the transition.

        Args:
            action: 0=HOLD, 1=LONG, 2=SHORT

        Returns:
            Tuple (observation, reward, terminated, truncated, info) per gymnasium
        """
        if self._done:
            obs, info = self.reset()
            return obs, 0.0, True, False, info

        reward = 0.0
        info = {}

        current_close = float(self.df['close'].iloc[self._current_step])
        current_atr = float(self._atr_series.iloc[self._current_step])
        if current_atr <= 0:
            # Degenerate ATR guard: fall back to 0.1% of price
            current_atr = float(self.df['close'].iloc[self._current_step]) * 0.001

        # ── Check SL/TP before applying the new action ──────────
        if self._position != 0:
            reward += self._check_sl_tp(current_close, current_atr)

        # ── Apply the new action ─────────────────────────────────
        desired_position = self._action_to_position(action)

        if desired_position != self._position:
            # Close the current position (if open)
            if self._position != 0:
                close_reward = self._close_position(current_close, current_atr)
                reward += close_reward

            # Open a new position (unless HOLD)
            if desired_position != 0:
                self._open_position(desired_position, current_close, current_atr)
        else:
            # Same position: bump the bar counter
            if self._position != 0:
                self._bars_in_trade += 1

        # ── Drawdown penalty ─────────────────────────────────────
        if self._capital > self._peak_capital:
            self._peak_capital = self._capital
        drawdown = (self._peak_capital - self._capital) / max(self._peak_capital, 1.0)
        if drawdown > 0:
            reward -= self.drawdown_penalty * drawdown

        # ── Advance one bar ──────────────────────────────────────
        self._current_step += 1
        terminated = self._current_step >= len(self.df) - 1
        truncated = False

        if terminated:
            # Force-close the position at end of episode
            if self._position != 0:
                final_close = float(self.df['close'].iloc[-1])
                final_atr = float(self._atr_series.iloc[-1])
                reward += self._close_position(final_close, final_atr)
            self._done = True

        obs = self._get_observation()

        # Record the reward for the Sharpe computation
        self._pnl_history.append(reward)

        info = {
            'position': self._position,
            'capital': self._capital,
            'drawdown': drawdown,
            'bars_in_trade': self._bars_in_trade,
            'step': self._current_step,
        }

        return obs, float(reward), terminated, truncated, info

    def render(self):
        """Minimal display of the current state."""
        pos_str = {0: 'FLAT', 1: 'LONG', -1: 'SHORT'}.get(self._position, '?')
        logger.debug(
            f"Step={self._current_step} | Pos={pos_str} | "
            f"Capital={self._capital:.2f} | PnL={self._total_pnl:.4f}"
        )

    # ──────────────────────────────────────────────────────────────────────────
    # Observation
    # ──────────────────────────────────────────────────────────────────────────

    def _get_observation(self) -> np.ndarray:
        """
        Build the 20-feature normalized observation vector.

        All price features are normalized by the current ATR so the
        observation is invariant to the price scale.

        Returns:
            np.ndarray of shape (20,) with dtype float32
        """
        i = self._current_step
        # Safety: keep the index inside bounds
        i = max(LOOKBACK, min(i, len(self.df) - 1))

        close = float(self.df['close'].iloc[i])
        open_ = float(self.df['open'].iloc[i])
        high = float(self.df['high'].iloc[i])
        low = float(self.df['low'].iloc[i])
        vol = float(self.df['volume'].iloc[i])

        atr = float(self._atr_series.iloc[i])
        if atr <= 0:
            atr = close * 0.001

        # Close at i-1 and i-5 for momentum
        close_1 = float(self.df['close'].iloc[max(0, i - 1)])
        close_5 = float(self.df['close'].iloc[max(0, i - 5)])

        rsi = float(self._rsi.iloc[i]) if not np.isnan(self._rsi.iloc[i]) else 50.0
        bb_upper = float(self._bb_upper.iloc[i]) if not np.isnan(self._bb_upper.iloc[i]) else close
        bb_lower = float(self._bb_lower.iloc[i]) if not np.isnan(self._bb_lower.iloc[i]) else close
        macd = float(self._macd.iloc[i]) if not np.isnan(self._macd.iloc[i]) else 0.0
        macd_signal = float(self._macd_sig.iloc[i]) if not np.isnan(self._macd_sig.iloc[i]) else 0.0
        vol_ma20 = float(self._vol_ma20.iloc[i])
        ema9 = float(self._ema9.iloc[i])
        ema21 = float(self._ema21.iloc[i])
        ema50 = float(self._ema50.iloc[i])
        ema200 = float(self._ema200.iloc[i])

        vol_ratio = vol / max(vol_ma20, 1.0)

        # Current unrealized PnL
        if self._position != 0 and self._entry_price > 0:
            unrealized = self._position * (close - self._entry_price)
        else:
            unrealized = 0.0

        features = np.array([
            # Relative prices normalized by ATR
            (close - open_) / atr,                 # 0 : candle body
            (high - low) / atr,                    # 1 : candle range (relative volatility)
            (close - close_1) / atr,               # 2 : 1-bar momentum
            (close - close_5) / atr,               # 3 : 5-bar momentum
            # Normalized technical indicators
            rsi / 100.0,                           # 4 : RSI [0..1]
            (close - bb_upper) / atr,              # 5 : position vs upper BB
            (bb_upper - bb_lower) / atr,           # 6 : Bollinger band width
            macd / atr,                            # 7 : normalized MACD
            macd_signal / atr,                     # 8 : normalized MACD signal
            # Volume
            np.clip(vol_ratio, 0.0, 5.0),          # 9 : volume / MA20 ratio (capped at 5)
            # Position and trade state
            float(self._position),                 # 10: current position (-1, 0, 1)
            unrealized / atr,                      # 11: unrealized PnL in ATR
            min(self._bars_in_trade / 50.0, 1.0),  # 12: trade duration (normalized)
            # Temporal (if datetime index)
            self._get_hour(i),                     # 13: normalized hour [0..1]
            self._get_dow(i),                      # 14: day of week [0..1]
            # Moving averages (close - EMA distance, normalized by ATR)
            (close - ema9) / atr,                  # 15: EMA9 distance
            (close - ema21) / atr,                 # 16: EMA21 distance
            (ema9 - ema21) / atr,                  # 17: EMA9/21 crossover
            (close - ema50) / atr,                 # 18: EMA50 distance
            (close - ema200) / atr,                # 19: EMA200 distance
        ], dtype=np.float32)

        # Clip to avoid extreme values (robustness protection)
        features = np.clip(features, -10.0, 10.0)
        # Replace any residual NaNs
        features = np.nan_to_num(features, nan=0.0, posinf=10.0, neginf=-10.0)

        return features

    # ──────────────────────────────────────────────────────────────────────────
    # Position management
    # ──────────────────────────────────────────────────────────────────────────

    def _open_position(self, direction: int, price: float, atr: float) -> None:
        """
        Open a position in the given direction.

        Args:
            direction: 1=LONG, -1=SHORT
            price: Entry price (close of the current bar)
            atr: Current ATR, recorded for SL/TP computation
        """
        self._position = direction
        self._entry_price = price
        self._entry_atr = atr
        self._bars_in_trade = 0

    def _close_position(self, price: float, atr: float) -> float:
        """
        Close the current position and compute the reward.

        The reward is the PnL in ATR multiples (normalized, unitless).
        This normalization makes rewards comparable across symbols and
        volatility regimes.

        Args:
            price: Exit price
            atr: Current ATR (for normalization)

        Returns:
            Reward (PnL normalized by ATR)
        """
        if self._position == 0 or self._entry_price <= 0:
            self._position = 0
            self._bars_in_trade = 0
            return 0.0

        raw_pnl = self._position * (price - self._entry_price)

        # Normalize by the entry ATR for a unitless reward
        entry_atr = max(self._entry_atr, price * 0.0001)
        reward = raw_pnl / entry_atr

        # Capital update (simplified simulation with one fixed lot)
        self._capital += raw_pnl
        self._total_pnl += raw_pnl

        self._position = 0
        self._entry_price = 0.0
        self._bars_in_trade = 0

        return float(reward)

    def _check_sl_tp(self, current_close: float, current_atr: float) -> float:
        """
        Check whether the SL or TP is hit for the current position.

        Uses the entry ATR for the SL/TP levels (stability).
        If the SL or TP is touched, the position is closed.

        Args:
            current_close: Close price of the current bar
            current_atr: Current ATR

        Returns:
            Reward if a forced close happened, 0.0 otherwise
        """
        if self._position == 0 or self._entry_price <= 0:
            return 0.0

        entry_atr = max(self._entry_atr, self._entry_price * 0.0001)
        sl_dist = self.sl_atr_mult * entry_atr
        tp_dist = self.tp_atr_mult * entry_atr

        if self._position == 1:  # LONG
            sl_level = self._entry_price - sl_dist
            tp_level = self._entry_price + tp_dist
            if current_close <= sl_level or current_close >= tp_level:
                return self._close_position(current_close, current_atr)

        elif self._position == -1:  # SHORT
            sl_level = self._entry_price + sl_dist
            tp_level = self._entry_price - tp_dist
            if current_close >= sl_level or current_close <= tp_level:
                return self._close_position(current_close, current_atr)

        return 0.0

    # ──────────────────────────────────────────────────────────────────────────
    # Helpers
    # ──────────────────────────────────────────────────────────────────────────

    @staticmethod
    def _action_to_position(action: int) -> int:
        """Convert the discrete action into a position direction."""
        return {ACTION_HOLD: 0, ACTION_LONG: 1, ACTION_SHORT: -1}.get(action, 0)

    def _get_hour(self, i: int) -> float:
        """Return the normalized hour [0..1] from the index, or 0.5 if not datetime."""
        try:
            idx = self.df.index[i]
            if hasattr(idx, 'hour'):
                return float(idx.hour) / 24.0
        except Exception:
            pass
        return 0.5

    def _get_dow(self, i: int) -> float:
        """Return the normalized day of week [0..1] from the index, or 0.5 if not datetime."""
        try:
            idx = self.df.index[i]
            if hasattr(idx, 'dayofweek'):
                return float(idx.dayofweek) / 5.0
        except Exception:
            pass
        return 0.5

    def _compute_atr_series(self) -> pd.Series:
        """Compute the ATR over the whole OHLCV series (EMA of the true range)."""
        h = self.df['high']
        l = self.df['low']
        pc = self.df['close'].shift(1)
        tr = pd.concat([h - l, (h - pc).abs(), (l - pc).abs()], axis=1).max(axis=1)
        atr = tr.ewm(span=self.atr_period, adjust=False).mean()
        # Fill the leading NaNs with the high-low range
        atr = atr.fillna(h - l)
        return atr

    def _compute_rsi(self, period: int = 14) -> pd.Series:
        """Compute the RSI over the whole close series (EMA-smoothed gains/losses)."""
        delta = self.df['close'].diff()
        gain = delta.clip(lower=0).ewm(span=period, adjust=False).mean()
        loss = (-delta.clip(upper=0)).ewm(span=period, adjust=False).mean()
        rs = gain / loss.replace(0, np.nan)
        rsi = 100.0 - (100.0 / (1.0 + rs))
        return rsi.fillna(50.0)

    def get_episode_stats(self) -> dict:
        """
        Return the statistics of the current episode.

        Useful for logging and PPO model evaluation.

        Returns:
            Dict with total_pnl, n_steps, capital, approximate Sharpe
        """
        rewards = np.array(self._pnl_history)
        if len(rewards) > 1 and rewards.std() > 0:
            sharpe = float(rewards.mean() / rewards.std() * np.sqrt(252))
        else:
            sharpe = 0.0

        return {
            'total_pnl': self._total_pnl,
            'capital': self._capital,
            'return_pct': (self._capital - self.initial_capital) / self.initial_capital,
            'n_steps': self._current_step,
            'sharpe_approx': sharpe,
        }
|
||||||
3
src/strategies/cnn_driven/__init__.py
Normal file
3
src/strategies/cnn_driven/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .cnn_strategy import CNNDrivenStrategy
|
||||||
|
|
||||||
|
__all__ = ['CNNDrivenStrategy']
|
||||||
249
src/strategies/cnn_driven/cnn_strategy.py
Normal file
249
src/strategies/cnn_driven/cnn_strategy.py
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
"""
|
||||||
|
CNN-Driven Strategy — Stratégie pilotée par un réseau convolutif 1D.
|
||||||
|
|
||||||
|
Cette stratégie utilise un CNN 1D entraîné sur des séquences OHLCV brutes
|
||||||
|
pour détecter des patterns visuels dans les bougies (double bottom, squeeze
|
||||||
|
Bollinger, alignements, etc.) sans features pré-calculées.
|
||||||
|
|
||||||
|
Fonctionnement :
|
||||||
|
1. Le modèle CNN est chargé depuis le disque (entraîné via POST /trading/train-cnn)
|
||||||
|
2. À chaque barre, la séquence OHLCV récente est passée au CNN
|
||||||
|
3. Le modèle prédit LONG / SHORT / NEUTRAL avec un score de confiance
|
||||||
|
4. Si confidence >= min_confidence, un signal est émis avec SL/TP basés sur ATR
|
||||||
|
|
||||||
|
Intégration :
|
||||||
|
- Compatible avec StrategyEngine (même interface que ScalpingStrategy / MLDrivenStrategy)
|
||||||
|
- Chargé automatiquement si un modèle entraîné existe pour le symbole/timeframe
|
||||||
|
- Le RiskManager applique les mêmes contrôles que pour les stratégies classiques
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from src.strategies.base_strategy import BaseStrategy, Signal, StrategyConfig
|
||||||
|
|
||||||
|
# Optional dependency: the CNN model requires PyTorch. When unavailable, the
# strategy degrades gracefully and simply emits no signals.
try:
    from src.ml.cnn import CNNStrategyModel
    CNN_AVAILABLE = True
except ImportError:
    CNN_AVAILABLE = False
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CNNDrivenStrategy(BaseStrategy):
    """
    Trading strategy driven by a pre-trained 1D CNN.

    The model learns candle patterns directly from raw OHLCV sequences:
    - Double bottom / double top
    - Bollinger squeeze + expansion
    - Moving-average alignments
    - Complex candlestick patterns
    - Multi-bar price structures

    A signal is emitted only when the model's confidence is at least
    ``min_confidence``; the threshold is enforced both in this strategy
    and propagated to the attached model.

    Args:
        config: Configuration dict (timeframe, risk_per_trade, symbol, etc.)

    Extra (optional) config keys:
        min_confidence: Minimum confidence threshold [0..1] (default: 0.55)
        tp_atr_mult: ATR multiplier for TP (default: 2.0)
        sl_atr_mult: ATR multiplier for SL (default: 1.0)
        seq_len: Input sequence length (default: 64)
        auto_load: Automatically load an existing model (default: True)
    """

    STRATEGY_NAME = 'cnn_driven'

    def __init__(self, config: Dict):
        super().__init__(config)

        self.symbol = config.get('symbol', 'EURUSD')
        self.min_confidence = config.get('min_confidence', 0.55)
        self.tp_atr_mult = config.get('tp_atr_mult', 2.0)
        self.sl_atr_mult = config.get('sl_atr_mult', 1.0)
        self.seq_len = config.get('seq_len', 64)

        self.cnn_model: Optional['CNNStrategyModel'] = None

        if not CNN_AVAILABLE:
            logger.warning("CNN non disponible (PyTorch requis)")
            return

        # Try to auto-load an existing trained model for this symbol/timeframe
        if config.get('auto_load', True):
            self._try_load_model()

    # -------------------------------------------------------------------------
    # BaseStrategy interface
    # -------------------------------------------------------------------------
    def analyze(self, market_data: pd.DataFrame) -> Optional[Signal]:
        """
        Generate a trading signal via the CNN model.

        Args:
            market_data: OHLCV DataFrame (at least seq_len bars)

        Returns:
            Signal when the model is confident enough, None otherwise
        """
        if not CNN_AVAILABLE:
            logger.debug("CNN Strategy : PyTorch non disponible, aucun signal")
            return None

        if self.cnn_model is None or not self.cnn_model.is_trained:
            logger.debug("CNN Strategy : modèle non chargé, aucun signal")
            return None

        if len(market_data) < self.seq_len:
            return None

        try:
            result = self.cnn_model.predict(market_data)
        except Exception as e:
            logger.warning(f"CNN Strategy predict error : {e}")
            return None

        if not result.get('tradeable', False):
            return None

        signal_dir = result['signal']  # 1 = LONG, -1 = SHORT
        confidence = result['confidence']

        # BUGFIX: enforce the strategy-level threshold explicitly. Previously
        # only the model's internal 'tradeable' flag was trusted, so a
        # min_confidence configured on the strategy (but not yet pushed to the
        # model) was silently ignored.
        if confidence < self.min_confidence:
            return None

        # Price and ATR for SL/TP levels
        last_close = float(market_data['close'].iloc[-1])
        atr = self._compute_atr(market_data)
        if atr <= 0:
            return None

        if signal_dir == 1:
            direction = 'LONG'
            stop_loss = last_close - self.sl_atr_mult * atr
            take_profit = last_close + self.tp_atr_mult * atr
        elif signal_dir == -1:
            direction = 'SHORT'
            stop_loss = last_close + self.sl_atr_mult * atr
            take_profit = last_close - self.tp_atr_mult * atr
        else:
            return None

        signal = Signal(
            symbol=self.symbol,
            direction=direction,
            entry_price=last_close,
            stop_loss=stop_loss,
            take_profit=take_profit,
            confidence=confidence,
            timestamp=datetime.now(timezone.utc),
            strategy=self.STRATEGY_NAME,
            metadata={
                'probas': result.get('probas', {}),
                'seq_len': self.seq_len,
                'atr': atr,
                'tp_atr_mult': self.tp_atr_mult,
                'sl_atr_mult': self.sl_atr_mult,
            },
        )

        logger.info(
            f"CNN Signal : {direction} {self.symbol} | "
            f"entry={last_close:.5f} SL={stop_loss:.5f} TP={take_profit:.5f} | "
            f"confidence={confidence:.2%}"
        )
        return signal

    def calculate_indicators(self, data: pd.DataFrame) -> pd.DataFrame:
        """Return the data unchanged — the CNN works on raw sequences."""
        return data

    def update_params(self, params: Dict) -> None:
        """Dynamic parameter update (from the API or Optuna)."""
        if 'min_confidence' in params:
            self.min_confidence = params['min_confidence']
            if self.cnn_model:
                self.cnn_model.min_confidence = params['min_confidence']
        if 'tp_atr_mult' in params:
            self.tp_atr_mult = params['tp_atr_mult']
        if 'sl_atr_mult' in params:
            self.sl_atr_mult = params['sl_atr_mult']
        if 'seq_len' in params:
            self.seq_len = params['seq_len']
        logger.info(f"CNN Strategy params mis à jour : {params}")

    # -------------------------------------------------------------------------
    # Model management
    # -------------------------------------------------------------------------
    def load_model(self, symbol: Optional[str] = None, timeframe: Optional[str] = None) -> bool:
        """
        Load a CNN model from disk.

        Args:
            symbol: Pair (default: self.symbol)
            timeframe: Timeframe (default: self.config.timeframe)

        Returns:
            True if loading succeeded
        """
        if not CNN_AVAILABLE:
            logger.warning("CNN non disponible (PyTorch requis)")
            return False

        sym = symbol or self.symbol
        tf = timeframe or self.config.timeframe
        try:
            self.cnn_model = CNNStrategyModel.load(sym, tf)
            # BUGFIX: propagate the strategy's threshold to the freshly loaded
            # model (previously only update_params() synchronized it).
            self.cnn_model.min_confidence = self.min_confidence
            logger.info(f"Modèle CNN chargé : {sym}/{tf}")
            return True
        except FileNotFoundError:
            logger.info(f"Aucun modèle CNN trouvé pour {sym}/{tf}")
            return False
        except Exception as e:
            logger.error(f"Erreur chargement modèle CNN : {e}")
            return False

    def attach_model(self, model: 'CNNStrategyModel') -> None:
        """Directly attach a CNN model (after training via the API)."""
        self.cnn_model = model
        self.symbol = model.symbol
        # Keep the model's threshold consistent with the strategy's
        self.cnn_model.min_confidence = self.min_confidence
        logger.info(f"Modèle CNN attaché : {model.symbol}/{model.timeframe}")

    def is_ready(self) -> bool:
        """Return True when the CNN model is loaded and trained."""
        if not CNN_AVAILABLE:
            return False
        return self.cnn_model is not None and self.cnn_model.is_trained

    def get_model_info(self) -> Dict:
        """Return the metadata of the active CNN model."""
        if not CNN_AVAILABLE:
            return {'status': 'PyTorch non disponible'}
        if not self.is_ready():
            return {'status': 'non entraîné'}
        meta = self.cnn_model.metadata.copy()
        meta['is_ready'] = True
        meta['seq_len'] = self.seq_len
        return meta

    # -------------------------------------------------------------------------
    # Helpers
    # -------------------------------------------------------------------------
    def _try_load_model(self) -> None:
        """Attempt a silent model load at startup."""
        try:
            self.load_model()
        except Exception:
            pass

    @staticmethod
    def _compute_atr(df: pd.DataFrame, period: int = 14) -> float:
        """Compute the average true range over the most recent bars."""
        if len(df) < period + 1:
            return float(df['high'].iloc[-1] - df['low'].iloc[-1])
        h, l, pc = df['high'], df['low'], df['close'].shift(1)
        tr = pd.concat([h - l, (h - pc).abs(), (l - pc).abs()], axis=1).max(axis=1)
        atr = tr.rolling(period).mean().iloc[-1]
        return float(atr) if not np.isnan(atr) else float(df['high'].iloc[-1] - df['low'].iloc[-1])
|
||||||
3
src/strategies/cnn_image_driven/__init__.py
Normal file
3
src/strategies/cnn_image_driven/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from src.strategies.cnn_image_driven.cnn_image_strategy import CNNImageDrivenStrategy
|
||||||
|
|
||||||
|
__all__ = ['CNNImageDrivenStrategy']
|
||||||
258
src/strategies/cnn_image_driven/cnn_image_strategy.py
Normal file
258
src/strategies/cnn_image_driven/cnn_image_strategy.py
Normal file
@@ -0,0 +1,258 @@
|
|||||||
|
"""
|
||||||
|
CNN Image-Driven Strategy — Stratégie pilotée par un CNN à vision par ordinateur.
|
||||||
|
|
||||||
|
Contrairement au CNN 1D qui analyse des séquences numériques brutes, cette
|
||||||
|
stratégie convertit les bougies OHLCV en images de graphiques (rendu mplfinance
|
||||||
|
128×128 RGB), puis utilise un réseau Conv2D pour reconnaître les patterns
|
||||||
|
visuels — exactement comme un trader devant TradingView.
|
||||||
|
|
||||||
|
Patterns détectés sans feature engineering explicite :
|
||||||
|
- Bougies marteau, étoile filante, doji
|
||||||
|
- Double top, double bottom
|
||||||
|
- Rebonds sur supports/résistances visuels
|
||||||
|
- Momentum : grande bougie pleine après consolidation
|
||||||
|
- Divergences prix/volume visibles
|
||||||
|
|
||||||
|
Fonctionnement :
|
||||||
|
1. Le modèle CNN-Image est chargé depuis le disque (entraîné via POST /trading/train-cnn-image)
|
||||||
|
2. À chaque barre, les seq_len dernières bougies sont rendues en image 128×128
|
||||||
|
3. Le modèle prédit LONG / SHORT / NEUTRAL avec un score de confiance
|
||||||
|
4. Si confidence >= min_confidence, un signal est émis avec SL/TP basés sur ATR
|
||||||
|
|
||||||
|
Intégration :
|
||||||
|
- Compatible avec StrategyEngine (même interface que CNNDrivenStrategy)
|
||||||
|
- Chargé automatiquement si un modèle entraîné existe pour le symbole/timeframe
|
||||||
|
- Requiert PyTorch + mplfinance + Pillow
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from src.strategies.base_strategy import BaseStrategy, Signal, StrategyConfig
|
||||||
|
|
||||||
|
try:
|
||||||
|
from src.ml.cnn_image import CNNImageStrategyModel
|
||||||
|
CNN_IMAGE_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
CNN_IMAGE_AVAILABLE = False
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CNNImageDrivenStrategy(BaseStrategy):
    """Trading strategy driven by a pre-trained Conv2D CNN over rendered chart images.

    The underlying model classifies candlestick charts (128x128 RGB renders of
    the last ``seq_len`` bars) into LONG / SHORT / NEUTRAL with a confidence
    score — no explicit feature engineering is involved.

    Args:
        config: configuration dict (timeframe, risk_per_trade, symbol, ...).

    Extra config keys (all optional):
        min_confidence: minimum confidence threshold in [0, 1] (default 0.55)
        tp_atr_mult:    ATR multiplier for the take-profit (default 2.0)
        sl_atr_mult:    ATR multiplier for the stop-loss (default 1.0)
        seq_len:        number of candles rendered per image (default 64)
        auto_load:      load an existing model at startup (default True)
    """

    STRATEGY_NAME = 'cnn_image_driven'

    def __init__(self, config: Dict):
        super().__init__(config)

        self.symbol = config.get('symbol', 'EURUSD')
        self.min_confidence = config.get('min_confidence', 0.55)
        self.tp_atr_mult = config.get('tp_atr_mult', 2.0)
        self.sl_atr_mult = config.get('sl_atr_mult', 1.0)
        self.seq_len = config.get('seq_len', 64)

        # Set lazily: either auto-loaded below or attached after API training.
        self.cnn_image_model: Optional['CNNImageStrategyModel'] = None

        if not CNN_IMAGE_AVAILABLE:
            logger.warning("CNN Image non disponible (PyTorch/mplfinance requis)")
            return

        if config.get('auto_load', True):
            self._try_load_model()

    # ------------------------------------------------------------------
    # BaseStrategy interface
    # ------------------------------------------------------------------

    def analyze(self, market_data: pd.DataFrame) -> Optional[Signal]:
        """Produce a trading signal from the CNN-image model.

        Args:
            market_data: OHLCV DataFrame with at least ``seq_len`` bars.

        Returns:
            A Signal when the model is confident enough, otherwise None.
        """
        if not CNN_IMAGE_AVAILABLE:
            logger.debug("CNN Image Strategy : PyTorch/mplfinance non disponible, aucun signal")
            return None
        if self.cnn_image_model is None or not self.cnn_image_model.is_trained:
            logger.debug("CNN Image Strategy : modèle non chargé, aucun signal")
            return None
        if len(market_data) < self.seq_len:
            return None

        try:
            prediction = self.cnn_image_model.predict(market_data)
        except Exception as exc:
            logger.warning(f"CNN Image Strategy predict error : {exc}")
            return None

        if not prediction.get('tradeable', False):
            return None

        direction_code = prediction['signal']  # 1 = LONG, -1 = SHORT
        confidence = prediction['confidence']

        # Price reference and ATR for the SL/TP levels.
        entry = float(market_data['close'].iloc[-1])
        atr = self._compute_atr(market_data)
        if atr <= 0:
            return None
        if direction_code not in (1, -1):
            return None

        side = 1.0 if direction_code == 1 else -1.0
        direction = 'LONG' if direction_code == 1 else 'SHORT'
        stop_loss = entry - side * self.sl_atr_mult * atr
        take_profit = entry + side * self.tp_atr_mult * atr

        signal = Signal(
            symbol=self.symbol,
            direction=direction,
            entry_price=entry,
            stop_loss=stop_loss,
            take_profit=take_profit,
            confidence=confidence,
            timestamp=datetime.now(timezone.utc),
            strategy=self.STRATEGY_NAME,
            metadata={
                'probas': prediction.get('probas', {}),
                'seq_len': self.seq_len,
                'atr': atr,
                'tp_atr_mult': self.tp_atr_mult,
                'sl_atr_mult': self.sl_atr_mult,
            },
        )

        logger.info(
            f"CNN Image Signal : {direction} {self.symbol} | "
            f"entry={entry:.5f} SL={stop_loss:.5f} TP={take_profit:.5f} | "
            f"confidence={confidence:.2%}"
        )
        return signal

    def calculate_indicators(self, data: pd.DataFrame) -> pd.DataFrame:
        """No-op: the CNN consumes rendered chart images, not computed indicators."""
        return data

    def update_params(self, params: Dict) -> None:
        """Apply dynamic parameter updates (from the API or Optuna)."""
        if 'min_confidence' in params:
            self.min_confidence = params['min_confidence']
            if self.cnn_image_model:
                # Keep the model's own gating threshold in sync.
                self.cnn_image_model.min_confidence = params['min_confidence']
        for key in ('tp_atr_mult', 'sl_atr_mult', 'seq_len'):
            if key in params:
                setattr(self, key, params[key])
        logger.info(f"CNN Image Strategy params mis à jour : {params}")

    # ------------------------------------------------------------------
    # Model management
    # ------------------------------------------------------------------

    def load_model(self, symbol: Optional[str] = None, timeframe: Optional[str] = None) -> bool:
        """Load a persisted CNN-image model from disk.

        Args:
            symbol: pair to load (defaults to ``self.symbol``)
            timeframe: timeframe to load (defaults to ``self.config.timeframe``)

        Returns:
            True on success; False when support is missing, no model exists,
            or loading fails.
        """
        if not CNN_IMAGE_AVAILABLE:
            logger.warning("CNN Image non disponible (PyTorch/mplfinance requis)")
            return False

        sym = symbol or self.symbol
        tf = timeframe or self.config.timeframe
        try:
            self.cnn_image_model = CNNImageStrategyModel.load(sym, tf)
        except FileNotFoundError:
            logger.info(f"Aucun modèle CNN Image trouvé pour {sym}/{tf}")
            return False
        except Exception as exc:
            logger.error(f"Erreur chargement modèle CNN Image : {exc}")
            return False
        logger.info(f"Modèle CNN Image chargé : {sym}/{tf}")
        return True

    def attach_model(self, model: 'CNNImageStrategyModel') -> None:
        """Attach an already-trained CNN-image model (e.g. right after API training)."""
        self.cnn_image_model = model
        self.symbol = model.symbol
        logger.info(f"Modèle CNN Image attaché : {model.symbol}/{model.timeframe}")

    def is_ready(self) -> bool:
        """True when support libraries are present and a trained model is attached."""
        if not CNN_IMAGE_AVAILABLE:
            return False
        model = self.cnn_image_model
        return model is not None and model.is_trained

    def get_model_info(self) -> Dict:
        """Return the active model's metadata, or a status dict when unusable."""
        if not CNN_IMAGE_AVAILABLE:
            return {'status': 'PyTorch/mplfinance non disponible'}
        if not self.is_ready():
            return {'status': 'non entraîné'}
        info = self.cnn_image_model.metadata.copy()
        info['is_ready'] = True
        info['seq_len'] = self.seq_len
        return info

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _try_load_model(self) -> None:
        """Best-effort model load at startup; failures are deliberately ignored."""
        try:
            self.load_model()
        except Exception:
            pass

    @staticmethod
    def _compute_atr(df: pd.DataFrame, period: int = 14) -> float:
        """Mean True Range over the last *period* bars (fallback: last bar's range)."""
        last_range = float(df['high'].iloc[-1] - df['low'].iloc[-1])
        if len(df) < period + 1:
            return last_range
        prev_close = df['close'].shift(1)
        true_range = pd.concat(
            [
                df['high'] - df['low'],
                (df['high'] - prev_close).abs(),
                (df['low'] - prev_close).abs(),
            ],
            axis=1,
        ).max(axis=1)
        atr = true_range.rolling(period).mean().iloc[-1]
        return last_range if np.isnan(atr) else float(atr)
|
||||||
3
src/strategies/ensemble/__init__.py
Normal file
3
src/strategies/ensemble/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .ensemble_strategy import EnsembleStrategy
|
||||||
|
|
||||||
|
__all__ = ['EnsembleStrategy']
|
||||||
190
src/strategies/ensemble/ensemble_strategy.py
Normal file
190
src/strategies/ensemble/ensemble_strategy.py
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
"""
|
||||||
|
Ensemble Strategy — Stratégie combinant XGBoost + CNN (+ RL futur).
|
||||||
|
|
||||||
|
Cette stratégie utilise l'EnsembleModel pour agréger les signaux de plusieurs
|
||||||
|
modèles ML. Un signal n'est émis que si les modèles sont en accord et que
|
||||||
|
le score pondéré dépasse le seuil configuré.
|
||||||
|
|
||||||
|
Config keys :
|
||||||
|
weights: dict poids par modèle (défaut: XGB=0.40, CNN=0.60)
|
||||||
|
min_confidence: seuil score pondéré (défaut: 0.60)
|
||||||
|
require_agreement: exiger accord entre modèles (défaut: True)
|
||||||
|
tp_atr_mult: TP en multiples d'ATR (défaut: 2.0)
|
||||||
|
sl_atr_mult: SL en multiples d'ATR (défaut: 1.0)
|
||||||
|
auto_load: charger modèles existants au démarrage (défaut: True)
|
||||||
|
symbol: paire tradée (défaut: 'EURUSD')
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from src.strategies.base_strategy import BaseStrategy, Signal
|
||||||
|
from src.ml.ensemble.ensemble_model import EnsembleModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EnsembleStrategy(BaseStrategy):
    """Trading strategy aggregating several ML models through EnsembleModel.

    At least two trained, attached models are required before any signal is
    emitted. SL/TP levels are derived from the ATR of the incoming data.

    Config keys (all optional unless noted):
        weights:           per-model weight dict (EnsembleModel defaults apply)
        min_confidence:    weighted-score threshold (default 0.60)
        require_agreement: require models to agree (default True)
        tp_atr_mult:       take-profit in ATR multiples (default 2.0)
        sl_atr_mult:       stop-loss in ATR multiples (default 1.0)
        auto_load:         load existing models at startup (default True)
        symbol:            traded pair (default 'EURUSD')
    """

    STRATEGY_NAME = 'ensemble'

    def __init__(self, config: Dict):
        # Make sure the strategy name is always set before BaseStrategy init.
        config.setdefault('name', self.STRATEGY_NAME)
        super().__init__(config)

        self.symbol = config.get('symbol', 'EURUSD')
        self.tp_atr_mult = config.get('tp_atr_mult', 2.0)
        self.sl_atr_mult = config.get('sl_atr_mult', 1.0)

        self.ensemble = EnsembleModel(
            weights=config.get('weights'),
            min_confidence=config.get('min_confidence', 0.60),
            require_agreement=config.get('require_agreement', True),
        )

        if config.get('auto_load', True):
            self._try_load_models()

    # ------------------------------------------------------------------
    # BaseStrategy interface
    # ------------------------------------------------------------------

    def analyze(self, market_data: pd.DataFrame) -> Optional[Signal]:
        """Produce a trading signal from the model ensemble.

        Args:
            market_data: OHLCV DataFrame (minimum 50 bars).

        Returns:
            A Signal when the ensemble is confident and in agreement, else None.
        """
        if not self.ensemble.is_ready():
            logger.debug("Ensemble Strategy : ensemble non prêt (< 2 modèles)")
            return None
        if len(market_data) < 50:
            return None

        try:
            prediction = self.ensemble.predict(market_data)
        except Exception as exc:
            logger.warning(f"Ensemble Strategy predict error : {exc}")
            return None

        if not prediction.get('tradeable', False):
            return None

        direction_code = prediction['signal']  # 1 = LONG, -1 = SHORT
        confidence = prediction['confidence']

        # Price reference and ATR for the SL/TP levels.
        entry = float(market_data['close'].iloc[-1])
        atr = self._compute_atr(market_data)
        if atr <= 0:
            return None
        if direction_code not in (1, -1):
            return None

        side = 1.0 if direction_code == 1 else -1.0
        direction = 'LONG' if direction_code == 1 else 'SHORT'
        stop_loss = entry - side * self.sl_atr_mult * atr
        take_profit = entry + side * self.tp_atr_mult * atr

        signal = Signal(
            symbol=self.symbol,
            direction=direction,
            entry_price=entry,
            stop_loss=stop_loss,
            take_profit=take_profit,
            confidence=confidence,
            timestamp=datetime.now(timezone.utc),
            strategy=self.STRATEGY_NAME,
            metadata={
                'ensemble_agreement': prediction.get('agreement', False),
                'components': prediction.get('components', {}),
                'atr': atr,
                'tp_atr_mult': self.tp_atr_mult,
                'sl_atr_mult': self.sl_atr_mult,
            },
        )

        logger.info(
            f"Ensemble Signal : {direction} {self.symbol} | "
            f"entry={entry:.5f} SL={stop_loss:.5f} TP={take_profit:.5f} | "
            f"confidence={confidence:.2%} | accord={prediction.get('agreement')}"
        )
        return signal

    def calculate_indicators(self, data: pd.DataFrame) -> pd.DataFrame:
        """No-op: feature computation happens inside the ensemble's predict()."""
        return data

    # ------------------------------------------------------------------
    # Model management
    # ------------------------------------------------------------------

    def attach_xgboost(self, model) -> None:
        """Attach an XGBoost model to the ensemble."""
        self.ensemble.attach_xgboost(model)

    def attach_cnn(self, model) -> None:
        """Attach a CNN model to the ensemble."""
        self.ensemble.attach_cnn(model)

    def is_ready(self) -> bool:
        """True when the ensemble holds at least two trained models."""
        return self.ensemble.is_ready()

    def get_status(self) -> Dict:
        """Return the ensemble's status dict."""
        return self.ensemble.get_status()

    # ------------------------------------------------------------------
    # Automatic loading
    # ------------------------------------------------------------------

    def _try_load_models(self) -> None:
        """Try to load any persisted component models at startup (best effort)."""
        # XGBoost via MLStrategyModel.
        try:
            from src.ml.ml_strategy_model import MLStrategyModel
            xgb = MLStrategyModel.load(self.symbol, self.config.timeframe, 'xgboost')
            self.ensemble.attach_xgboost(xgb)
            logger.info(f"Ensemble : modèle XGBoost chargé pour {self.symbol}")
        except Exception:
            logger.debug(f"Ensemble : pas de modèle XGBoost pour {self.symbol}")

        # CNN via CNNStrategyModel (may not exist yet).
        try:
            from src.ml.cnn import CNNStrategyModel
            cnn = CNNStrategyModel.load(self.symbol, self.config.timeframe)
            self.ensemble.attach_cnn(cnn)
            logger.info(f"Ensemble : modèle CNN chargé pour {self.symbol}")
        except Exception:
            logger.debug(f"Ensemble : pas de modèle CNN pour {self.symbol}")

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _compute_atr(df: pd.DataFrame, period: int = 14) -> float:
        """Mean True Range over the last *period* bars (fallback: last bar's range)."""
        last_range = float(df['high'].iloc[-1] - df['low'].iloc[-1])
        if len(df) < period + 1:
            return last_range
        prev_close = df['close'].shift(1)
        true_range = pd.concat(
            [
                df['high'] - df['low'],
                (df['high'] - prev_close).abs(),
                (df['low'] - prev_close).abs(),
            ],
            axis=1,
        ).max(axis=1)
        atr = true_range.rolling(period).mean().iloc[-1]
        return last_range if np.isnan(atr) else float(atr)
|
||||||
10
src/strategies/rl_driven/__init__.py
Normal file
10
src/strategies/rl_driven/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
Module RL-Driven Strategy — Stratégie de trading pilotée par un agent PPO.
|
||||||
|
|
||||||
|
L'agent PPO apprend à trader par renforcement : il maximise directement
|
||||||
|
le PnL sans supervision explicite (pas de labels LONG/SHORT/NEUTRAL).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.strategies.rl_driven.rl_strategy import RLDrivenStrategy, RL_AVAILABLE
|
||||||
|
|
||||||
|
__all__ = ['RLDrivenStrategy', 'RL_AVAILABLE']
|
||||||
259
src/strategies/rl_driven/rl_strategy.py
Normal file
259
src/strategies/rl_driven/rl_strategy.py
Normal file
@@ -0,0 +1,259 @@
|
|||||||
|
"""
|
||||||
|
RL-Driven Strategy — Stratégie pilotée par un agent de Reinforcement Learning (PPO).
|
||||||
|
|
||||||
|
Contrairement aux stratégies supervisées (XGBoost, CNN) qui apprennent depuis des
|
||||||
|
labels pré-calculés, cette stratégie utilise un agent PPO (Proximal Policy
|
||||||
|
Optimization) entraîné via interaction directe avec un environnement de trading
|
||||||
|
(TradingEnv). L'agent apprend à maximiser la récompense cumulative (profit ajusté
|
||||||
|
par le risque) sans avoir besoin de labels explicites.
|
||||||
|
|
||||||
|
Fonctionnement :
|
||||||
|
1. Le modèle RLStrategyModel est chargé depuis le disque
|
||||||
|
(entraîné via POST /trading/train-rl)
|
||||||
|
2. À chaque barre, les seq_len dernières bougies sont fournies à l'agent
|
||||||
|
3. L'agent PPO retourne une action : LONG (1) / SHORT (-1) / NEUTRAL (0)
|
||||||
|
avec un score de confiance basé sur les probabilités de la politique
|
||||||
|
4. Si confidence >= min_confidence, un signal est émis avec SL/TP basés sur ATR
|
||||||
|
|
||||||
|
Intégration :
|
||||||
|
- Compatible avec StrategyEngine (même interface que CNNImageDrivenStrategy)
|
||||||
|
- Chargé automatiquement si un modèle entraîné existe pour le symbole/timeframe
|
||||||
|
- Requiert PyTorch (CPU ou CUDA)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from src.strategies.base_strategy import BaseStrategy, Signal, StrategyConfig
|
||||||
|
|
||||||
|
# Import conditionnel — RL_AVAILABLE reste False si PyTorch ou stable-baselines3
|
||||||
|
# ne sont pas installés dans le container
|
||||||
|
try:
|
||||||
|
from src.ml.rl.rl_strategy_model import RLStrategyModel
|
||||||
|
RL_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
RLStrategyModel = None
|
||||||
|
RL_AVAILABLE = False
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RLDrivenStrategy(BaseStrategy):
    """Trading strategy driven by a PPO reinforcement-learning agent.

    The agent learns to maximise cumulative profit by interacting with a
    simulated trading environment (TradingEnv); no supervised labels are
    needed — the reward comes from realised P&L and risk penalties
    (drawdown, over-trading).

    Args:
        config: configuration dict (timeframe, risk_per_trade, symbol, ...).

    Extra config keys (all optional):
        min_confidence: minimum confidence threshold in [0, 1] (default 0.55)
        tp_atr_mult:    ATR multiplier for the take-profit (default 2.0)
        sl_atr_mult:    ATR multiplier for the stop-loss (default 1.0)
        seq_len:        RL observation window in bars (default 20)
        auto_load:      load an existing model at startup (default True)
    """

    STRATEGY_NAME = 'rl_driven'

    def __init__(self, config: Dict):
        super().__init__(config)

        self.symbol = config.get('symbol', 'EURUSD')
        self.min_confidence = config.get('min_confidence', 0.55)
        self.tp_atr_mult = config.get('tp_atr_mult', 2.0)
        self.sl_atr_mult = config.get('sl_atr_mult', 1.0)
        self.seq_len = config.get('seq_len', 20)

        # Set lazily: either auto-loaded below or attached after API training.
        self.rl_model: Optional['RLStrategyModel'] = None

        if not RL_AVAILABLE:
            logger.warning("RL Strategy non disponible (PyTorch / stable-baselines3 requis)")
            return

        if config.get('auto_load', True):
            self._try_load_model()

    # ------------------------------------------------------------------
    # BaseStrategy interface
    # ------------------------------------------------------------------

    def analyze(self, market_data: pd.DataFrame) -> Optional[Signal]:
        """Produce a trading signal from the PPO agent.

        Args:
            market_data: OHLCV DataFrame with at least ``seq_len`` bars.

        Returns:
            A Signal when the agent is confident enough, otherwise None.
        """
        if not RL_AVAILABLE:
            logger.debug("RL Strategy : PyTorch / stable-baselines3 non disponible, aucun signal")
            return None
        if self.rl_model is None or not self.rl_model.is_trained:
            logger.debug("RL Strategy : modèle non chargé, aucun signal")
            return None
        if len(market_data) < self.seq_len:
            return None

        try:
            prediction = self.rl_model.predict(market_data)
        except Exception as exc:
            logger.warning(f"RL Strategy predict error : {exc}")
            return None

        if not prediction.get('tradeable', False):
            return None

        direction_code = prediction['signal']  # 1 = LONG, -1 = SHORT
        confidence = prediction['confidence']

        # Price reference and ATR for the SL/TP levels.
        entry = float(market_data['close'].iloc[-1])
        atr = self._compute_atr(market_data)
        if atr <= 0:
            return None
        if direction_code not in (1, -1):
            return None

        side = 1.0 if direction_code == 1 else -1.0
        direction = 'LONG' if direction_code == 1 else 'SHORT'
        stop_loss = entry - side * self.sl_atr_mult * atr
        take_profit = entry + side * self.tp_atr_mult * atr

        signal = Signal(
            symbol=self.symbol,
            direction=direction,
            entry_price=entry,
            stop_loss=stop_loss,
            take_profit=take_profit,
            confidence=confidence,
            timestamp=datetime.now(timezone.utc),
            strategy=self.STRATEGY_NAME,
            metadata={
                'probas': prediction.get('probas', {}),
                'seq_len': self.seq_len,
                'atr': atr,
                'tp_atr_mult': self.tp_atr_mult,
                'sl_atr_mult': self.sl_atr_mult,
                'avg_reward': prediction.get('avg_reward'),
                'total_timesteps': prediction.get('total_timesteps'),
            },
        )

        logger.info(
            f"RL Signal : {direction} {self.symbol} | "
            f"entry={entry:.5f} SL={stop_loss:.5f} TP={take_profit:.5f} | "
            f"confidence={confidence:.2%}"
        )
        return signal

    def calculate_indicators(self, data: pd.DataFrame) -> pd.DataFrame:
        """No-op: the RL agent consumes raw OHLCV observations."""
        return data

    def update_params(self, params: Dict) -> None:
        """Apply dynamic parameter updates (from the API or Optuna)."""
        if 'min_confidence' in params:
            self.min_confidence = params['min_confidence']
            if self.rl_model:
                # Keep the model's own gating threshold in sync.
                self.rl_model.min_confidence = params['min_confidence']
        for key in ('tp_atr_mult', 'sl_atr_mult', 'seq_len'):
            if key in params:
                setattr(self, key, params[key])
        logger.info(f"RL Strategy params mis à jour : {params}")

    # ------------------------------------------------------------------
    # Model management
    # ------------------------------------------------------------------

    def load_model(self, symbol: Optional[str] = None, timeframe: Optional[str] = None) -> bool:
        """Load a persisted RL model from disk.

        Args:
            symbol: pair to load (defaults to ``self.symbol``)
            timeframe: timeframe to load (defaults to ``self.config.timeframe``)

        Returns:
            True on success; False when support is missing, no model exists,
            or loading fails.
        """
        if not RL_AVAILABLE:
            logger.warning("RL Strategy non disponible (PyTorch / stable-baselines3 requis)")
            return False

        sym = symbol or self.symbol
        tf = timeframe or self.config.timeframe
        try:
            self.rl_model = RLStrategyModel.load(sym, tf)
        except FileNotFoundError:
            logger.info(f"Aucun modèle RL trouvé pour {sym}/{tf}")
            return False
        except Exception as exc:
            logger.error(f"Erreur chargement modèle RL : {exc}")
            return False
        logger.info(f"Modèle RL chargé : {sym}/{tf}")
        return True

    def attach_model(self, model: 'RLStrategyModel') -> None:
        """Attach an already-trained RL model (e.g. right after API training)."""
        self.rl_model = model
        self.symbol = model.symbol
        logger.info(f"Modèle RL attaché : {model.symbol}/{model.timeframe}")

    def is_ready(self) -> bool:
        """True when support libraries are present and a trained model is attached."""
        if not RL_AVAILABLE:
            return False
        model = self.rl_model
        return model is not None and model.is_trained

    def get_model_info(self) -> Dict:
        """Return the active RL model's metadata, or a status dict when unusable."""
        if not RL_AVAILABLE:
            return {'status': 'PyTorch / stable-baselines3 non disponible'}
        if not self.is_ready():
            return {'status': 'non entraîné'}
        info = self.rl_model.metadata.copy()
        info['is_ready'] = True
        info['seq_len'] = self.seq_len
        return info

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _try_load_model(self) -> None:
        """Best-effort model load at startup; failures are deliberately ignored."""
        try:
            self.load_model()
        except Exception:
            pass

    @staticmethod
    def _compute_atr(df: pd.DataFrame, period: int = 14) -> float:
        """Mean True Range over the last *period* bars (fallback: last bar's range)."""
        last_range = float(df['high'].iloc[-1] - df['low'].iloc[-1])
        if len(df) < period + 1:
            return last_range
        prev_close = df['close'].shift(1)
        true_range = pd.concat(
            [
                df['high'] - df['low'],
                (df['high'] - prev_close).abs(),
                (df['low'] - prev_close).abs(),
            ],
            axis=1,
        ).max(axis=1)
        atr = true_range.rolling(period).mean().iloc[-1]
        return last_range if np.isnan(atr) else float(atr)
|
||||||
Reference in New Issue
Block a user