Compare commits
9 Commits
8f3b026f82
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ff8d58e1aa | ||
|
|
80e1308a1e | ||
|
|
e9d4c440d9 | ||
|
|
6fd68af47a | ||
|
|
7af7248b4d | ||
|
|
d245d7d8f4 | ||
|
|
acc3338213 | ||
|
|
8732acf3d0 | ||
|
|
daea333555 |
@@ -20,3 +20,16 @@ prometheus-client==0.19.0
|
||||
|
||||
# Notifications
|
||||
python-telegram-bot==20.7
|
||||
|
||||
# ML — requis pour MLDrivenStrategy (entraînement et prédiction dans l'API)
|
||||
scikit-learn==1.3.2
|
||||
xgboost==2.0.3
|
||||
lightgbm==4.1.0
|
||||
joblib>=1.3.0
|
||||
|
||||
# ML — Deep Learning (CNN pour patterns chandeliers)
|
||||
torch>=2.0.0
|
||||
|
||||
# Visualisation — graphiques financiers et traitement d'images
|
||||
mplfinance>=0.12.10b0
|
||||
Pillow>=10.0.0
|
||||
|
||||
227
docs/CNN_ENSEMBLE_PLAN.md
Normal file
227
docs/CNN_ENSEMBLE_PLAN.md
Normal file
@@ -0,0 +1,227 @@
|
||||
# Plan : CNN + Ensemble Multi-Signal
|
||||
|
||||
**Créé** : 2026-03-08
|
||||
**Statut** : Phase 4c — En cours de développement
|
||||
|
||||
---
|
||||
|
||||
## Concept
|
||||
|
||||
Coupler trois modèles complémentaires pour produire un signal de trading robuste :
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────┐
|
||||
│ Signal final pondéré │
|
||||
│ trade si score > seuil (ex: 0.60) │
|
||||
│ │
|
||||
│ score = w1×XGB_conf + w2×CNN_conf (+ w3×RL_conf) │
|
||||
│ (ex: 0.40 × 0.72 + 0.60 × 0.68 = 0.70) │
|
||||
└──────────┬──────────────────────┬───────────────────┘
|
||||
│ │
|
||||
┌──────▼──────┐ ┌──────▼──────┐
|
||||
│ XGBoost │ │ CNN │
|
||||
│ (Phase 4b) │ │ (Phase 4c) │
|
||||
│ │ │ │
|
||||
│ 50 features│ │ Fenêtre │
|
||||
│ TA calculés│ │ 64 bougies │
|
||||
│ (RSI, MACD,│ │ OHLCV → │
|
||||
│ pivots...)│ │ séquence │
|
||||
│ │ │ 1D CNN │
|
||||
│ "indicat." │ │ "visuel" │
|
||||
└─────────────┘ └─────────────┘
|
||||
│ (Phase 4d)
|
||||
┌──────▼──────┐
|
||||
│ RL │
|
||||
│ (futur) │
|
||||
│ Récompense │
|
||||
│ = PnL réel │
|
||||
│ Apprend par│
|
||||
│ essai/erreur│
|
||||
└─────────────┘
|
||||
```
|
||||
|
||||
### Complémentarité des modèles
|
||||
|
||||
| Composant | Ce qu'il voit | Ce qu'il détecte | Limite |
|
||||
|---|---|---|---|
|
||||
| XGBoost | Indicateurs calculés | Combinaisons règles TA | Dépend des features choisies |
|
||||
| CNN | Séquence brute OHLCV | Patterns visuels (double bottom, squeeze, H&S...) | Besoin de beaucoup de data |
|
||||
| RL | Historique de ses trades | Ce qui rapporte sans règles | Instable, lent à converger |
|
||||
|
||||
**Principe de l'ensemble** : un signal qui passe deux (ou trois) filtres indépendants a une probabilité nettement plus élevée d'être correct qu'un signal issu d'un seul modèle.
|
||||
|
||||
---
|
||||
|
||||
## Phase 4c — CNN (priorité immédiate)
|
||||
|
||||
### Architecture CNN
|
||||
|
||||
```
|
||||
Entrée : dernières 64 bougies OHLCV
|
||||
→ normalisation z-score par fenêtre
|
||||
→ shape : (batch, 64, 5) # seq_len=64, features=5 (OHLCV)
|
||||
|
||||
Conv1D(filters=32, kernel=3) → ReLU → MaxPool(2)
|
||||
Conv1D(filters=64, kernel=3) → ReLU → MaxPool(2)
|
||||
Conv1D(filters=128, kernel=3) → ReLU → GlobalAvgPool
|
||||
Dense(128) → Dropout(0.3)
|
||||
Dense(3) → Softmax # [LONG, SHORT, NEUTRAL]
|
||||
```
|
||||
|
||||
Pas de conversion en image 2D — la CNN 1D sur séquences OHLCV est plus naturelle
|
||||
et plus performante pour les séries temporelles financières.
|
||||
|
||||
### Dépendance PyTorch
|
||||
|
||||
Le CNN requiert PyTorch. Il faut :
|
||||
1. Ajouter `torch==2.1.0+cpu` dans `docker/requirements/api.txt`
|
||||
*(CPU only — pas de GPU requis pour l'inférence en trading)*
|
||||
2. Rebuilder l'image : `docker compose build --no-cache trading-api`
|
||||
|
||||
### Fichiers à créer — CNN
|
||||
|
||||
```
|
||||
src/ml/cnn/
|
||||
├── __init__.py
|
||||
├── candlestick_encoder.py # Normalisation + préparation séquences OHLCV
|
||||
├── cnn_model.py # Architecture PyTorch (1D CNN)
|
||||
└── cnn_strategy_model.py # Wrapper train/predict/save/load (comme MLStrategyModel)
|
||||
|
||||
src/strategies/cnn_driven/
|
||||
├── __init__.py
|
||||
└── cnn_strategy.py # CNNDrivenStrategy (hérite BaseStrategy)
|
||||
|
||||
models/cnn_strategy/ # Sauvegardes .pt + _meta.json
|
||||
```
|
||||
|
||||
### Routes API à ajouter — CNN
|
||||
|
||||
| Méthode | Route | Description |
|
||||
|---|---|---|
|
||||
| POST | `/trading/train-cnn` | Lance entraînement CNN async |
|
||||
| GET | `/trading/train-cnn/{job_id}` | Statut + métriques |
|
||||
| GET | `/trading/cnn-models` | Liste modèles disponibles |
|
||||
|
||||
### Métriques cibles CNN
|
||||
|
||||
- `wf_accuracy > 0.52` (plus difficile que XGBoost — séquences brutes)
|
||||
- Distribution LONG/SHORT équilibrée (même méthode labels que XGBoost)
|
||||
- `wf_precision > 0.48` sur signaux directionnels
|
||||
|
||||
---
|
||||
|
||||
## Phase 4c — Ensemble XGBoost + CNN
|
||||
|
||||
### Logique de combinaison
|
||||
|
||||
```python
|
||||
# Les deux modèles prédisent indépendamment
|
||||
xgb_result = xgb_model.predict(df) # {'signal': 1, 'confidence': 0.72}
|
||||
cnn_result = cnn_model.predict(df) # {'signal': 1, 'confidence': 0.68}
|
||||
|
||||
# Score pondéré (seulement si même direction)
|
||||
if xgb_result['signal'] == cnn_result['signal']:
|
||||
score = w_xgb * xgb_result['confidence'] + w_cnn * cnn_result['confidence']
|
||||
if score >= min_confidence:
|
||||
→ signal validé (beaucoup plus fiable)
|
||||
else:
|
||||
→ NEUTRAL (désaccord entre modèles)
|
||||
```
|
||||
|
||||
### Fichiers à créer — Ensemble
|
||||
|
||||
```
|
||||
src/ml/ensemble/
|
||||
├── __init__.py
|
||||
├── ensemble_model.py # Combine XGBoost + CNN (+ RL futur)
|
||||
└── ensemble_config.py # Poids configurables par défaut
|
||||
|
||||
src/strategies/ensemble/
|
||||
├── __init__.py
|
||||
└── ensemble_strategy.py # EnsembleStrategy (hérite BaseStrategy)
|
||||
```
|
||||
|
||||
### Routes API à ajouter — Ensemble
|
||||
|
||||
| Méthode | Route | Description |
|
||||
|---|---|---|
|
||||
| POST | `/trading/ensemble/configure` | Définir les poids (xgb/cnn) |
|
||||
| GET | `/trading/ensemble/signal` | Signal combiné en temps réel |
|
||||
| GET | `/trading/ensemble/status` | Statut de chaque modèle de l'ensemble |
|
||||
|
||||
### Configuration par défaut
|
||||
|
||||
```json
|
||||
{
|
||||
"weights": {
|
||||
"xgboost": 0.40,
|
||||
"cnn": 0.60
|
||||
},
|
||||
"min_confidence": 0.60,
|
||||
"require_agreement": true
|
||||
}
|
||||
```
|
||||
|
||||
*Poids CNN légèrement supérieurs car il voit les patterns bruts sans nos biais de feature engineering.*
|
||||
|
||||
---
|
||||
|
||||
## Phase 4d — RL (après 4c validée)
|
||||
|
||||
Implémentation après validation CNN + Ensemble en paper trading (≥ 2 semaines).
|
||||
|
||||
### Environnement RL
|
||||
|
||||
```
|
||||
Framework : gymnasium (OpenAI Gym successor)
|
||||
Algorithme : PPO (Proximal Policy Optimization) — stable et adapté au trading
|
||||
|
||||
État : [dernières 64 bougies OHLCV + features XGBoost + signal CNN]
|
||||
Action : {HOLD, LONG, SHORT, CLOSE}
|
||||
Récompense : PnL réel net de frais, avec pénalité sur drawdown
|
||||
|
||||
Entraînement : sur données historiques 3 ans (simulation)
|
||||
Validation : walk-forward + paper trading
|
||||
```
|
||||
|
||||
### Intégration dans l'ensemble
|
||||
|
||||
```python
|
||||
# Phase 4d : triplet
|
||||
score = 0.30 * xgb_conf + 0.40 * cnn_conf + 0.30 * rl_conf
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## TODO — Phase 4c (par ordre)
|
||||
|
||||
### Étape 1 : Dépendance PyTorch
|
||||
- [ ] Ajouter `torch==2.1.0+cpu` dans `docker/requirements/api.txt`
|
||||
- [ ] Tester `docker compose build --no-cache trading-api` (peut prendre 10-15 min)
|
||||
|
||||
### Étape 2 : CNN core
|
||||
- [ ] `src/ml/cnn/candlestick_encoder.py` — normalisation z-score, padding, output shape (N, 64, 5)
|
||||
- [ ] `src/ml/cnn/cnn_model.py` — architecture PyTorch, forward(), train_epoch(), eval_epoch()
|
||||
- [ ] `src/ml/cnn/cnn_strategy_model.py` — train(), predict(), save(), load(), walk-forward eval
|
||||
|
||||
### Étape 3 : CNN strategy + API
|
||||
- [ ] `src/strategies/cnn_driven/cnn_strategy.py` — CNNDrivenStrategy
|
||||
- [ ] Routes `POST /trading/train-cnn`, `GET /trading/train-cnn/{job_id}`, `GET /trading/cnn-models`
|
||||
|
||||
### Étape 4 : Ensemble
|
||||
- [ ] `src/ml/ensemble/ensemble_model.py` — combine XGBoost + CNN
|
||||
- [ ] `src/strategies/ensemble/ensemble_strategy.py` — EnsembleStrategy
|
||||
- [ ] Routes `/trading/ensemble/*`
|
||||
|
||||
### Étape 5 : Validation
|
||||
- [ ] Entraîner CNN sur EURUSD/1h (2 ans)
|
||||
- [ ] Comparer backtest : Scalping vs XGBoost seul vs CNN seul vs Ensemble
|
||||
- [ ] Si Ensemble Sharpe > 0.8, démarrer paper trading ensemble
|
||||
|
||||
---
|
||||
|
||||
## Historique
|
||||
|
||||
| Date | Version | Description |
|
||||
|---|---|---|
|
||||
| 2026-03-08 | v0.1 | Plan initial — architecture CNN + Ensemble + RL (futur) |
|
||||
182
docs/CNN_IMAGE_PLAN.md
Normal file
182
docs/CNN_IMAGE_PLAN.md
Normal file
@@ -0,0 +1,182 @@
|
||||
# Phase 4c-bis : CNN Image-Based — Analyse Visuelle de Graphiques
|
||||
|
||||
**Date** : 2026-03-10
|
||||
**Statut** : 🟡 En développement
|
||||
**Prérequis** : Phase 4c (CNN 1D + Ensemble) ✅
|
||||
|
||||
---
|
||||
|
||||
## Concept
|
||||
|
||||
Contrairement au CNN 1D qui analyse des séquences numériques, le CNN image-based
|
||||
convertit les bougies OHLCV en **vraies images de graphiques** (rendu matplotlib),
|
||||
puis utilise un réseau de neurones de **vision par ordinateur** (Conv2D) pour
|
||||
reconnaître les patterns visuels — exactement comme un trader devant TradingView.
|
||||
|
||||
### Ce que le modèle apprend sans qu'on le programme
|
||||
- Bougies marteau, étoile filante, doji
|
||||
- Double top, double bottom
|
||||
- Rebonds sur support/résistance (zones de consolidation visibles)
|
||||
- Momentum : grande bougie pleine après consolidation
|
||||
- Divergences visibles (mouvement de prix vs volume en bas de l'image)
|
||||
|
||||
---
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
OHLCV DataFrame (64 bougies)
|
||||
│
|
||||
▼
|
||||
CandlestickImageRenderer
|
||||
- Rendu mplfinance (sans axes, sans texte)
|
||||
- Bougies vertes (hausse) / rouges (baisse)
|
||||
- Volume en bas (20% de l'image)
|
||||
- Image 128×128 pixels, RGB
|
||||
- Normalisation [0..1]
|
||||
│
|
||||
▼
|
||||
CandlestickCNN (Conv2D — vision)
|
||||
Input : (batch, 3, 128, 128)
|
||||
|
||||
Bloc 1 : Conv2d(3→32, 3×3) + BatchNorm + ReLU + MaxPool(2,2) → (32, 64, 64)
|
||||
Bloc 2 : Conv2d(32→64, 3×3) + BatchNorm + ReLU + MaxPool(2,2) → (64, 32, 32)
|
||||
Bloc 3 : Conv2d(64→128, 3×3) + BatchNorm + ReLU + MaxPool(2,2) → (128, 16, 16)
|
||||
Bloc 4 : Conv2d(128→256, 3×3) + BatchNorm + ReLU + AdaptiveAvgPool(1) → (256,)
|
||||
|
||||
Classifieur :
|
||||
Linear(256→128) + Dropout(0.4) + ReLU
|
||||
Linear(128→3) → softmax → [SHORT, NEUTRAL, LONG]
|
||||
│
|
||||
▼
|
||||
Signal : 1 (LONG) / -1 (SHORT) / 0 (NEUTRAL)
|
||||
Confidence : max(P_LONG, P_SHORT)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Fichiers à créer
|
||||
|
||||
### src/ml/cnn_image/
|
||||
|
||||
| Fichier | Rôle |
|
||||
|---|---|
|
||||
| `__init__.py` | Export CNNImageStrategyModel |
|
||||
| `chart_renderer.py` | CandlestickImageRenderer : encode(df, seq_len=64) → (N, 3, 128, 128) |
|
||||
| `cnn_image_model.py` | CandlestickCNN(nn.Module) : Conv2D 4-blocs + Dense |
|
||||
| `cnn_image_strategy_model.py` | CNNImageStrategyModel : même interface que MLStrategyModel |
|
||||
|
||||
### src/strategies/cnn_image_driven/
|
||||
|
||||
| Fichier | Rôle |
|
||||
|---|---|
|
||||
| `__init__.py` | Export CNNImageDrivenStrategy |
|
||||
| `cnn_image_strategy.py` | CNNImageDrivenStrategy(BaseStrategy), SL/TP ATR-based |
|
||||
|
||||
### docker/requirements/api.txt
|
||||
- Ajouter `mplfinance>=0.12.10b0`
|
||||
- `Pillow>=10.0.0` (probablement déjà présent)
|
||||
- Rebuild trading-api
|
||||
|
||||
### src/api/routers/trading.py
|
||||
- POST `/trading/train-cnn-image`
|
||||
- GET `/trading/train-cnn-image/{job_id}`
|
||||
- GET `/trading/cnn-image-models`
|
||||
|
||||
### src/ml/ensemble/ensemble_model.py
|
||||
- Ajouter `attach_cnn_image(model)` comme 3ème slot
|
||||
- Mettre à jour `DEFAULT_WEIGHTS = {xgboost: 0.30, cnn: 0.30, cnn_image: 0.40, rl: 0.00}`
|
||||
- Mettre à jour `predict()` pour inclure le 3ème modèle
|
||||
|
||||
---
|
||||
|
||||
## CandlestickImageRenderer — Détails
|
||||
|
||||
```python
|
||||
class CandlestickImageRenderer:
|
||||
"""
|
||||
Convertit des données OHLCV en images de graphiques en chandeliers.
|
||||
|
||||
Paramètres d'image :
|
||||
- Taille : 128×128 pixels
|
||||
- Canaux : RGB (3)
|
||||
- Fond : noir (#0d1117)
|
||||
- Hausse : vert (#26a69a), Baisse : rouge (#ef5350)
|
||||
- Volume : dégradé alpha en bas (20% de l'image)
|
||||
- Pas d'axes, pas de labels, pas de titre
|
||||
"""
|
||||
|
||||
def encode(df, seq_len=64) -> np.ndarray:
|
||||
# Retourne (N, 3, 128, 128), float32, normalisé [0,1]
|
||||
# N = len(df) - seq_len + 1 fenêtres glissantes
|
||||
|
||||
def encode_last(df, seq_len=64) -> np.ndarray:
|
||||
# Retourne (1, 3, 128, 128) — dernière fenêtre uniquement
|
||||
# Utilisé pour la prédiction en temps réel
|
||||
|
||||
def _render_single(df_window) -> PIL.Image:
|
||||
# Rendu mplfinance en mémoire (BytesIO)
|
||||
# style custom : fond noir, pas d'axes
|
||||
# Retourne PIL.Image 128×128 RGB
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## CNNImageStrategyModel — Interface
|
||||
|
||||
Identique à `MLStrategyModel` et `CNNStrategyModel` :
|
||||
|
||||
```python
|
||||
model = CNNImageStrategyModel(symbol='EURUSD', timeframe='1h')
|
||||
result = model.train(df_ohlcv) # walk-forward 2 folds
|
||||
signal = model.predict(df) # {signal, confidence, probas, tradeable}
|
||||
model.save() # models/cnn_image_strategy/EURUSD_1h.pt + .json
|
||||
model.load(symbol, timeframe) # chargement depuis disque
|
||||
model.list_trained_models() # liste des modèles disponibles
|
||||
model.get_feature_importance() # retourne [] (CNN = boîte noire)
|
||||
```
|
||||
|
||||
Paramètres entraînement :
|
||||
- `seq_len` : 64 bougies par image
|
||||
- `epochs` : 50 max, early stopping patience=7
|
||||
- `batch_size` : 32
|
||||
- `lr` : 1e-3 (Adam)
|
||||
- Labels : via LabelGenerator partagé (ATR-based, même que XGBoost et CNN 1D)
|
||||
|
||||
---
|
||||
|
||||
## Intégration Ensemble
|
||||
|
||||
Après cette phase, l'EnsembleModel aura 3 composants :
|
||||
|
||||
```
|
||||
Signal Ensemble =
|
||||
0.30 × XGBoost(TA features)
|
||||
+ 0.30 × CNN 1D(séquences OHLCV)
|
||||
+ 0.40 × CNN Image(graphiques visuels)
|
||||
|
||||
Condition de trade :
|
||||
- Score pondéré ≥ min_confidence (défaut 0.55)
|
||||
- Au moins 2 modèles en accord sur la direction
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## TODOs
|
||||
|
||||
- [ ] chart_renderer.py — CandlestickImageRenderer avec mplfinance
|
||||
- [ ] cnn_image_model.py — CandlestickCNN Conv2D 4-blocs
|
||||
- [ ] cnn_image_strategy_model.py — train/predict/save/load
|
||||
- [ ] cnn_image_strategy.py — CNNImageDrivenStrategy(BaseStrategy)
|
||||
- [ ] trading.py — routes /train-cnn-image, /cnn-image-models
|
||||
- [ ] ensemble_model.py — attach_cnn_image(), poids mis à jour
|
||||
- [ ] requirements — mplfinance + rebuild Docker
|
||||
- [ ] Entraînement validé sur EURUSD/1h
|
||||
- [ ] Intégration EnsembleStrategy avec les 3 modèles
|
||||
|
||||
---
|
||||
|
||||
## Phase 4d — RL (après validation CNN image)
|
||||
|
||||
Agent PPO (Proximal Policy Optimization) via `gymnasium` + `stable-baselines3`.
|
||||
Voir docs/CNN_ENSEMBLE_PLAN.md section Phase 4d.
|
||||
@@ -1,7 +1,7 @@
|
||||
# État d'Avancement du Projet — Trading AI Secure
|
||||
|
||||
**Dernière mise à jour** : 2026-03-08
|
||||
**Version** : 0.5.0-beta
|
||||
**Dernière mise à jour** : 2026-03-10
|
||||
**Version** : 0.6.0-beta
|
||||
**Statut Global** : 🟡 En Développement Actif
|
||||
|
||||
---
|
||||
@@ -15,6 +15,8 @@
|
||||
| Phase 3 : Stratégies & Backtesting | ✅ Terminé | 100% |
|
||||
| Phase 4 : Interface & Dashboard | ✅ Terminé | 100% |
|
||||
| Phase 4b : ML-Driven Strategy | ✅ Terminé | 100% |
|
||||
| Phase 4c : CNN + Ensemble | ✅ Terminé | 100% |
|
||||
| Phase 4d : RL (PPO) | ⚪ Planifié | 0% |
|
||||
| Phase 5 : IG Markets (Live) | ⚪ Planifié | 0% |
|
||||
|
||||
---
|
||||
@@ -110,7 +112,7 @@ Voir [docs/ML_STRATEGY_GUIDE.md](ML_STRATEGY_GUIDE.md) pour la documentation com
|
||||
| Composant | Fichier | Statut |
|
||||
|---|---|---|
|
||||
| TechnicalFeatureBuilder (~50 features) | `src/ml/features/technical_features.py` | ✅ |
|
||||
| LabelGenerator (forward simulation) | `src/ml/features/label_generator.py` | ✅ |
|
||||
| LabelGenerator (forward simulation) | `src/ml/features/label_generator.py` | ✅ fix bug SHORT (2026-03-08) |
|
||||
| MLStrategyModel (XGBoost/LightGBM) | `src/ml/ml_strategy_model.py` | ✅ |
|
||||
| MLDrivenStrategy (hérite BaseStrategy) | `src/strategies/ml_driven/ml_strategy.py` | ✅ |
|
||||
| Route POST /trading/train | `src/api/routers/trading.py` | ✅ |
|
||||
@@ -121,6 +123,52 @@ Voir [docs/ML_STRATEGY_GUIDE.md](ML_STRATEGY_GUIDE.md) pour la documentation com
|
||||
|
||||
---
|
||||
|
||||
## Phase 4c — CNN + Ensemble ✅ (2026-03-10)
|
||||
|
||||
CNN 1D sur séquences brutes OHLCV + combinaison pondérée avec XGBoost.
|
||||
Voir [docs/CNN_ENSEMBLE_PLAN.md](CNN_ENSEMBLE_PLAN.md) pour l'architecture complète.
|
||||
|
||||
| Composant | Fichier | Statut |
|
||||
|---|---|---|
|
||||
| PyTorch 2.10 dans requirements | `docker/requirements/api.txt` | ✅ Installé |
|
||||
| CandlestickEncoder (normalisation séquences) | `src/ml/cnn/candlestick_encoder.py` | ✅ |
|
||||
| CNNModel (1D Conv PyTorch) | `src/ml/cnn/cnn_model.py` | ✅ |
|
||||
| CNNStrategyModel (train/predict/save/load) | `src/ml/cnn/cnn_strategy_model.py` | ✅ |
|
||||
| CNNDrivenStrategy (hérite BaseStrategy) | `src/strategies/cnn_driven/cnn_strategy.py` | ✅ |
|
||||
| Routes API CNN (POST /train-cnn, GET /cnn-models) | `src/api/routers/trading.py` | ✅ |
|
||||
| EnsembleModel (XGBoost + CNN pondérés) | `src/ml/ensemble/ensemble_model.py` | ✅ |
|
||||
| EnsembleStrategy (hérite BaseStrategy) | `src/strategies/ensemble/ensemble_strategy.py` | ✅ |
|
||||
| Routes API Ensemble (POST /configure, GET /status) | `src/api/routers/trading.py` | ✅ |
|
||||
| Entraînement XGBoost validé (EURUSD/1h) | — | ✅ wf_prec=21.7% |
|
||||
| Entraînement CNN validé (EURUSD/1h) | — | ✅ wf_prec=32.7% |
|
||||
| Backtest comparatif (Scalping vs XGBoost vs CNN vs Ensemble) | — | ⏸️ À faire |
|
||||
|
||||
### Résultats comparatifs des modèles (EURUSD/1h, 2 ans, ~12k barres)
|
||||
| Modèle | WF Accuracy | WF Precision | Labels (L/S/N) |
|
||||
|---|---|---|---|
|
||||
| XGBoost (features TA) | 33.6% | 21.7% | 3992/3943/4338 ✅ équilibré |
|
||||
| CNN 1D (séquences OHLCV) | 31.9% | **32.7%** | 3971/3936/4303 ✅ équilibré |
|
||||
| Ensemble (XGB+CNN 0.40/0.60) | — | — | Non encore testé |
|
||||
|
||||
**Observation** : CNN > XGBoost sur la précision directionnelle (32.7% vs 21.7%). Les deux modèles sont complémentaires pour l'ensemble.
|
||||
|
||||
### Bugs corrigés (2026-03-10 — session agents)
|
||||
- `trading.py` : `_get_data_service()` inexistant → instanciation directe DataService
|
||||
- `trading.py` : `logger` non défini dans except handler → ajout import logging
|
||||
- `trading.py` : `period` string mal converti → period_map identique à _run_optimize_task
|
||||
- `strategy_engine.py` : `ml_driven` non supporté dans `load_strategy()` → cas ajouté
|
||||
- `docker/requirements/api.txt` : dépendances ML (scikit-learn, xgboost, lightgbm) manquantes dans trading-api
|
||||
- `ml_strategy_model.py` : labels [-1,0,1] → encodage +1 pour XGBoost ≥ 2.x
|
||||
|
||||
---
|
||||
|
||||
## Phase 4d — RL (Planifié, après 4c validée)
|
||||
|
||||
Agent RL (PPO via gymnasium) intégré à l'ensemble comme troisième signal.
|
||||
Voir [docs/CNN_ENSEMBLE_PLAN.md](CNN_ENSEMBLE_PLAN.md) section Phase 4d.
|
||||
|
||||
---
|
||||
|
||||
## Routes API — État Complet
|
||||
|
||||
| Méthode | Route | Statut |
|
||||
|
||||
236
docs/RL_STRATEGY_GUIDE.md
Normal file
236
docs/RL_STRATEGY_GUIDE.md
Normal file
@@ -0,0 +1,236 @@
|
||||
# RL Strategy Guide — Phase 4d (Reinforcement Learning)
|
||||
|
||||
## Vue d'ensemble
|
||||
|
||||
La Phase 4d introduit une stratégie pilotée par un agent de **Reinforcement Learning (RL)**
|
||||
basé sur l'algorithme **PPO** (Proximal Policy Optimization). Contrairement aux stratégies
|
||||
supervisées (XGBoost, CNN) qui s'entraînent sur des labels pré-calculés, l'agent RL apprend
|
||||
directement par interaction avec un environnement de trading simulé (`TradingEnv`), sans avoir
|
||||
besoin d'annotations explicites LONG/SHORT/NEUTRAL.
|
||||
|
||||
---
|
||||
|
||||
## Architecture
|
||||
|
||||
### Composants (implémentés par l'agent ml/rl/)
|
||||
|
||||
```
|
||||
src/ml/rl/
|
||||
├── trading_env.py # TradingEnv (Gymnasium) — environnement de simulation
|
||||
├── ppo_model.py # PPOModel — réseau Actor-Critic (MLP)
|
||||
└── rl_strategy_model.py # RLStrategyModel — interface train/predict/save/load
|
||||
```
|
||||
|
||||
### Composants (cette phase)
|
||||
|
||||
```
|
||||
src/strategies/rl_driven/
|
||||
├── __init__.py # Export RLDrivenStrategy
|
||||
└── rl_strategy.py # RLDrivenStrategy (hérite BaseStrategy)
|
||||
```
|
||||
|
||||
### Pipeline d'entraînement PPO
|
||||
|
||||
```
|
||||
Données OHLCV
|
||||
│
|
||||
▼
|
||||
TradingEnv (Gymnasium)
|
||||
├── Observation space : fenêtre glissante seq_len=20 barres (OHLCV normalisé)
|
||||
├── Action space : {0=HOLD/NEUTRAL, 1=LONG, 2=SHORT}
|
||||
└── Reward : P&L réalisé − pénalité drawdown − pénalité over-trading
|
||||
│
|
||||
▼
|
||||
PPO Actor-Critic (MLP)
|
||||
├── Actor : policy π(a|s) → distribution sur les actions
|
||||
└── Critic : value function V(s) → estimation de la récompense future
|
||||
│
|
||||
▼
|
||||
Optimisation sur total_timesteps pas
|
||||
│
|
||||
▼
|
||||
RLStrategyModel.save()
|
||||
├── models/rl_strategy/EURUSD_1h.zip (politique PPO)
|
||||
└── models/rl_strategy/EURUSD_1h_meta.json
|
||||
```
|
||||
|
||||
### Architecture Actor-Critic
|
||||
|
||||
L'agent PPO utilise un réseau MLP (Multi-Layer Perceptron) à deux têtes :
|
||||
|
||||
- **Actor** : prédit la distribution de probabilité sur les 3 actions (LONG/SHORT/NEUTRAL).
|
||||
La confiance du signal correspond à `max(probas)`.
|
||||
- **Critic** : estime la valeur d'état V(s) pour calculer l'avantage (advantage) utilisé
|
||||
lors de la mise à jour de la politique.
|
||||
|
||||
Hyperparamètres PPO typiques :
|
||||
| Paramètre | Valeur par défaut | Description |
|
||||
|---------------|-------------------|----------------------------------------------|
|
||||
| `gamma` | 0.99 | Facteur d'actualisation des récompenses |
|
||||
| `clip_range` | 0.2 | Clipping ratio PPO (stabilité) |
|
||||
| `n_steps` | 2048 | Nombre de pas par rollout |
|
||||
| `batch_size` | 64 | Taille des mini-batches SGD |
|
||||
| `n_epochs` | 10 | Passes sur chaque rollout |
|
||||
| `ent_coef` | 0.01 | Coefficient d'entropie (exploration) |
|
||||
|
||||
---
|
||||
|
||||
## Lancer l'entraînement
|
||||
|
||||
### Via curl
|
||||
|
||||
```bash
|
||||
# Lancer l'entraînement (tâche de fond)
|
||||
curl -X POST http://localhost:8100/trading/train-rl \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"symbol": "EURUSD",
|
||||
"timeframe": "1h",
|
||||
"period": "2y",
|
||||
"total_timesteps": 50000
|
||||
}'
|
||||
|
||||
# Réponse
|
||||
{
|
||||
"job_id": "a1b2c3d4-...",
|
||||
"status": "pending",
|
||||
"symbol": "EURUSD",
|
||||
"timeframe": "1h"
|
||||
}
|
||||
```
|
||||
|
||||
```bash
|
||||
# Suivre l'avancement
|
||||
curl http://localhost:8100/trading/train-rl/a1b2c3d4-...
|
||||
|
||||
# Réponse quand terminé
|
||||
{
|
||||
"job_id": "a1b2c3d4-...",
|
||||
"status": "completed",
|
||||
"symbol": "EURUSD",
|
||||
"timeframe": "1h",
|
||||
"avg_reward": 0.042,
|
||||
"sharpe_env": 1.23,
|
||||
"total_timesteps": 50000,
|
||||
"trained_at": "2026-03-10T14:32:00"
|
||||
}
|
||||
```
|
||||
|
||||
```bash
|
||||
# Lister les modèles disponibles
|
||||
curl http://localhost:8100/trading/rl-models
|
||||
|
||||
# Réponse
|
||||
{
|
||||
"models": [
|
||||
{
|
||||
"symbol": "EURUSD",
|
||||
"timeframe": "1h",
|
||||
"avg_reward": 0.042,
|
||||
"sharpe_env": 1.23,
|
||||
"total_timesteps": 50000,
|
||||
"trained_at": "2026-03-10T14:32:00"
|
||||
}
|
||||
],
|
||||
"count": 1
|
||||
}
|
||||
```
|
||||
|
||||
### Paramètres de la requête
|
||||
|
||||
| Champ | Type | Défaut | Description |
|
||||
|-------------------|-------|----------|---------------------------------------------------------|
|
||||
| `symbol` | str | EURUSD | Paire de trading (ex: EURUSD, BTCUSDT) |
|
||||
| `timeframe` | str | 1h | Timeframe (1m, 5m, 15m, 1h, 4h, 1d) |
|
||||
| `period` | str | 2y | Historique d'entraînement (ex: 6m, 1y, 2y) |
|
||||
| `total_timesteps` | int | 50000 | Nombre total de pas de simulation PPO |
|
||||
|
||||
**Recommandations `total_timesteps` :**
|
||||
- 20 000 : entraînement rapide (test, ~5 min CPU)
|
||||
- 50 000 : entraînement standard (défaut, ~15 min CPU)
|
||||
- 200 000 : entraînement long, meilleure convergence (~1h CPU)
|
||||
- 1 000 000 : entraînement poussé si GPU disponible
|
||||
|
||||
---
|
||||
|
||||
## Interpréter les métriques
|
||||
|
||||
### `avg_reward`
|
||||
Récompense moyenne par pas de simulation sur les derniers rollouts d'évaluation.
|
||||
|
||||
- **< 0** : l'agent perd de l'argent en simulation → entraîner plus longtemps ou revoir la fonction de récompense
|
||||
- **0 à 0.02** : agent neutre, légèrement profitable
|
||||
- **> 0.05** : bon signal → tester en backtest réel
|
||||
- **> 0.1** : excellent (attention au sur-apprentissage, vérifier out-of-sample)
|
||||
|
||||
### `sharpe_env`
|
||||
Ratio de Sharpe calculé sur les épisodes de simulation (récompenses / écart-type des récompenses).
|
||||
|
||||
- **< 0.5** : insuffisant pour paper trading
|
||||
- **0.5 – 1.0** : acceptable, à valider en backtest
|
||||
- **> 1.5** : cible pour activation paper trading (conforme aux seuils du projet)
|
||||
|
||||
### Interprétation combinée
|
||||
|
||||
| `avg_reward` | `sharpe_env` | Interprétation |
|
||||
|--------------|--------------|--------------------------------------------------------|
|
||||
| négatif | quelconque | Agent non convergé — relancer avec plus de timesteps |
|
||||
| 0 – 0.02 | < 1.0 | Apprentissage partiel — augmenter total_timesteps |
|
||||
| > 0.03 | > 1.0 | Bon candidat — valider via POST /trading/backtest |
|
||||
| > 0.05 | > 1.5 | Prêt pour paper trading (30 jours minimum) |
|
||||
|
||||
---
|
||||
|
||||
## Intégration avec le paper trading
|
||||
|
||||
Après l'entraînement, si un paper trading avec la stratégie `rl_driven` est actif,
|
||||
le modèle est **automatiquement attaché** sans redémarrage.
|
||||
|
||||
Pour démarrer un paper trading RL :
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8100/trading/paper/start \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"strategy": "rl_driven", "symbol": "EURUSD"}'
|
||||
```
|
||||
|
||||
La stratégie `RLDrivenStrategy` :
|
||||
1. Charge le dernier modèle entraîné pour le symbole/timeframe
|
||||
2. À chaque barre, fournit les 20 dernières bougies à l'agent (`seq_len=20`)
|
||||
3. Si la confiance de l'agent >= `min_confidence` (défaut: 0.55), émet un signal
|
||||
4. SL = `sl_atr_mult × ATR` (défaut: 1×ATR), TP = `tp_atr_mult × ATR` (défaut: 2×ATR)
|
||||
|
||||
---
|
||||
|
||||
## Notes techniques
|
||||
|
||||
### Threads PyTorch
|
||||
L'entraînement fixe `torch.set_num_threads(4)` pour éviter la contention CPU dans
|
||||
le container Docker. Adapter dans `docker-compose.yml` si le container dispose de plus de cœurs.
|
||||
|
||||
### Sauvegarde des modèles
|
||||
Les modèles sont sauvegardés dans `models/rl_strategy/` (volume Docker monté) :
|
||||
- `EURUSD_1h.zip` : politique PPO (format stable-baselines3)
|
||||
- `EURUSD_1h_meta.json` : métadonnées (métriques, hyperparamètres, date)
|
||||
|
||||
### Import conditionnel
|
||||
Le flag `RL_AVAILABLE` est `False` si PyTorch ou stable-baselines3 ne sont pas
|
||||
installés. La stratégie dégrade gracieusement (aucun signal, aucune exception).
|
||||
Pour activer : ajouter `stable-baselines3` et `gymnasium` dans `docker/requirements/api.txt`
|
||||
puis reconstruire le container (`docker compose build --no-cache trading-api`).
|
||||
|
||||
---
|
||||
|
||||
## Seuils de validation (conforme au projet)
|
||||
|
||||
Avant activation du live trading, la stratégie RL doit satisfaire :
|
||||
|
||||
| Métrique | Seuil minimum |
|
||||
|-----------------|---------------|
|
||||
| Sharpe Ratio | ≥ 1.5 |
|
||||
| Max Drawdown | ≤ 10% |
|
||||
| Win Rate | ≥ 55% |
|
||||
| Paper Trading | ≥ 30 jours |
|
||||
|
||||
Ces seuils s'appliquent au **backtest out-of-sample** et au **paper trading**, pas
|
||||
aux métriques de simulation RL (`sharpe_env`) qui sont indicatives uniquement.
|
||||
365
scripts/compare_strategies.py
Executable file
365
scripts/compare_strategies.py
Executable file
@@ -0,0 +1,365 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
compare_strategies.py — Backtest comparatif automatisé des stratégies de trading.
|
||||
|
||||
Lance POST /trading/backtest pour chaque stratégie (scalping, ml_driven, cnn_driven)
|
||||
sur le même dataset (EURUSD / 1h / 1y), poll jusqu'à completion, affiche un tableau
|
||||
comparatif et sauvegarde le rapport JSON dans reports/.
|
||||
|
||||
Usage :
|
||||
python scripts/compare_strategies.py
|
||||
python scripts/compare_strategies.py --symbol GBPUSD --period 2y --capital 20000
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Vérification de la dépendance httpx (ou fallback vers requests)
|
||||
# ---------------------------------------------------------------------------
|
||||
try:
|
||||
import httpx
|
||||
_HTTP_BACKEND = "httpx"
|
||||
except ImportError:
|
||||
try:
|
||||
import requests as _requests_module
|
||||
_HTTP_BACKEND = "requests"
|
||||
except ImportError:
|
||||
print("ERREUR : Ni httpx ni requests ne sont installés. Exécutez : pip install httpx")
|
||||
sys.exit(1)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
API_BASE_URL = "http://localhost:8100"
|
||||
POLL_INTERVAL_SEC = 5 # Intervalle de polling (secondes)
|
||||
TIMEOUT_SEC = 300 # Timeout maximum par job (5 minutes)
|
||||
REPORTS_DIR = Path(__file__).parent.parent / "reports"
|
||||
|
||||
# Stratégies à comparer — dans l'ordre d'affichage.
|
||||
# NOTE : l'API (POST /trading/backtest) accepte : scalping | intraday | swing | ml_driven
|
||||
# La stratégie cnn_driven n'est pas encore intégrée dans le pipeline de backtest
|
||||
# (elle est disponible uniquement via paper trading). Elle sera incluse automatiquement
|
||||
# si l'API évolue pour la supporter.
|
||||
STRATEGIES = [
|
||||
"scalping",
|
||||
"ml_driven",
|
||||
"cnn_driven", # Retournera une erreur 400 si non supporté par l'API — géré gracieusement
|
||||
]
|
||||
|
||||
# Colonnes métriques à collecter et afficher
|
||||
METRICS_COLS = [
|
||||
("total_return", "Retour total", "{:.2%}"),
|
||||
("sharpe_ratio", "Sharpe ratio", "{:.3f}"),
|
||||
("max_drawdown", "Max drawdown", "{:.2%}"),
|
||||
("win_rate", "Win rate", "{:.2%}"),
|
||||
("total_trades", "Trades totaux", "{:d}"),
|
||||
("profit_factor", "Profit factor", "{:.2f}"),
|
||||
]
|
||||
|
||||
# Seuils de validation (identiques à ceux de l'API)
|
||||
SEUIL_SHARPE = 1.5
|
||||
SEUIL_DRAWDOWN_MAX = 0.10
|
||||
SEUIL_WIN_RATE = 0.55
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Couche HTTP (httpx ou requests selon disponibilité)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _post(url: str, payload: dict, timeout: int = 30) -> dict:
|
||||
"""Effectue un POST JSON et retourne la réponse parsée."""
|
||||
if _HTTP_BACKEND == "httpx":
|
||||
resp = httpx.post(url, json=payload, timeout=timeout)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
else:
|
||||
resp = _requests_module.post(url, json=payload, timeout=timeout)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
def _get(url: str, timeout: int = 15) -> dict:
|
||||
"""Effectue un GET et retourne la réponse parsée."""
|
||||
if _HTTP_BACKEND == "httpx":
|
||||
resp = httpx.get(url, timeout=timeout)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
else:
|
||||
resp = _requests_module.get(url, timeout=timeout)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fonctions backtest
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def lancer_backtest(strategy: str, symbol: str, period: str, capital: float) -> str:
|
||||
"""
|
||||
Lance un job de backtest via POST /trading/backtest.
|
||||
|
||||
Retourne le job_id ou lève une exception en cas d'erreur.
|
||||
"""
|
||||
payload = {
|
||||
"strategy": strategy,
|
||||
"symbol": symbol,
|
||||
"period": period,
|
||||
"initial_capital": capital,
|
||||
}
|
||||
url = f"{API_BASE_URL}/trading/backtest"
|
||||
data = _post(url, payload)
|
||||
return data["job_id"]
|
||||
|
||||
|
||||
def attendre_resultat(job_id: str, strategy: str) -> dict:
|
||||
"""
|
||||
Poll GET /trading/backtest/{job_id} toutes les POLL_INTERVAL_SEC secondes
|
||||
jusqu'à ce que le statut soit 'completed' ou 'failed'.
|
||||
|
||||
Retourne le dict de résultat ou lève une exception si le job échoue / dépasse le timeout.
|
||||
"""
|
||||
url = f"{API_BASE_URL}/trading/backtest/{job_id}"
|
||||
debut = time.time()
|
||||
|
||||
while True:
|
||||
elapsed = time.time() - debut
|
||||
if elapsed > TIMEOUT_SEC:
|
||||
raise TimeoutError(
|
||||
f"[{strategy}] Timeout atteint ({TIMEOUT_SEC}s). Job {job_id} toujours en cours."
|
||||
)
|
||||
|
||||
data = _get(url)
|
||||
statut = data.get("status", "?")
|
||||
|
||||
print(f" [{strategy}] Statut : {statut} ({elapsed:.0f}s)")
|
||||
|
||||
if statut == "completed":
|
||||
return data
|
||||
elif statut == "failed":
|
||||
erreur = data.get("error", "raison inconnue")
|
||||
raise RuntimeError(f"[{strategy}] Job échoué : {erreur}")
|
||||
|
||||
# Toujours en cours — on attend
|
||||
time.sleep(POLL_INTERVAL_SEC)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Affichage du tableau comparatif
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def afficher_tableau(resultats: list[dict]) -> None:
|
||||
"""
|
||||
Affiche un tableau comparatif dans le terminal.
|
||||
Tente d'utiliser 'tabulate' pour un meilleur rendu ; fallback vers print simple.
|
||||
"""
|
||||
# Construction des données tabulaires
|
||||
headers = ["Stratégie"] + [label for _, label, _ in METRICS_COLS] + ["Valide ?"]
|
||||
rows = []
|
||||
|
||||
for res in resultats:
|
||||
nom = res["strategy"]
|
||||
ligne = [nom]
|
||||
for key, _, fmt in METRICS_COLS:
|
||||
val = res.get(key)
|
||||
if val is None:
|
||||
ligne.append("N/A")
|
||||
else:
|
||||
try:
|
||||
if key == "total_trades":
|
||||
ligne.append(fmt.format(int(val)))
|
||||
else:
|
||||
ligne.append(fmt.format(float(val)))
|
||||
except (ValueError, TypeError):
|
||||
ligne.append(str(val))
|
||||
# Indicateur de validation
|
||||
valide = res.get("is_valid_for_paper")
|
||||
if valide is True:
|
||||
ligne.append("OUI")
|
||||
elif valide is False:
|
||||
ligne.append("NON")
|
||||
else:
|
||||
ligne.append("?")
|
||||
rows.append(ligne)
|
||||
|
||||
print()
|
||||
print("=" * 70)
|
||||
print("RAPPORT COMPARATIF DES STRATÉGIES")
|
||||
print("=" * 70)
|
||||
|
||||
try:
|
||||
from tabulate import tabulate
|
||||
print(tabulate(rows, headers=headers, tablefmt="rounded_outline"))
|
||||
except ImportError:
|
||||
# Fallback : affichage simple sans tabulate
|
||||
col_widths = [max(len(str(h)), max((len(str(r[i])) for r in rows), default=0))
|
||||
for i, h in enumerate(headers)]
|
||||
sep = " ".join("-" * w for w in col_widths)
|
||||
header_line = " ".join(str(h).ljust(w) for h, w in zip(headers, col_widths))
|
||||
print(header_line)
|
||||
print(sep)
|
||||
for row in rows:
|
||||
print(" ".join(str(c).ljust(w) for c, w in zip(row, col_widths)))
|
||||
|
||||
print()
|
||||
|
||||
# Identification de la meilleure stratégie (selon Sharpe ratio)
|
||||
valides = [r for r in resultats if r.get("sharpe_ratio") is not None]
|
||||
if valides:
|
||||
meilleure = max(valides, key=lambda r: r.get("sharpe_ratio") or -999)
|
||||
print(f"Meilleure stratégie (Sharpe) : {meilleure['strategy']} "
|
||||
f"(Sharpe={meilleure.get('sharpe_ratio', 'N/A'):.3f})")
|
||||
|
||||
# Rappel des seuils
|
||||
print()
|
||||
print(f"Seuils de validation : Sharpe >= {SEUIL_SHARPE} | "
|
||||
f"Max drawdown <= {SEUIL_DRAWDOWN_MAX:.0%} | "
|
||||
f"Win rate >= {SEUIL_WIN_RATE:.0%}")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sauvegarde du rapport JSON
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def sauvegarder_rapport(resultats: list[dict], symbol: str, period: str) -> Path:
|
||||
"""
|
||||
Sauvegarde le rapport de comparaison en JSON dans reports/.
|
||||
|
||||
Retourne le chemin du fichier créé.
|
||||
"""
|
||||
REPORTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
horodatage = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
nom_fichier = f"backtest_comparison_{horodatage}.json"
|
||||
chemin = REPORTS_DIR / nom_fichier
|
||||
|
||||
rapport = {
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
"parametres": {
|
||||
"symbol": symbol,
|
||||
"period": period,
|
||||
"api_base_url": API_BASE_URL,
|
||||
},
|
||||
"seuils_validation": {
|
||||
"sharpe_ratio_min": SEUIL_SHARPE,
|
||||
"max_drawdown_max": SEUIL_DRAWDOWN_MAX,
|
||||
"win_rate_min": SEUIL_WIN_RATE,
|
||||
},
|
||||
"resultats": resultats,
|
||||
}
|
||||
|
||||
with open(chemin, "w", encoding="utf-8") as f:
|
||||
json.dump(rapport, f, indent=2, ensure_ascii=False, default=str)
|
||||
|
||||
return chemin
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Point d'entrée principal
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""Parse les arguments en ligne de commande."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Compare les stratégies de trading via backtest API."
|
||||
)
|
||||
parser.add_argument("--symbol", default="EURUSD", help="Paire de trading (défaut: EURUSD)")
|
||||
parser.add_argument("--period", default="1y", help="Période historique : 6m | 1y | 2y (défaut: 1y)")
|
||||
parser.add_argument("--capital", default=10000.0, type=float, help="Capital initial (défaut: 10000)")
|
||||
parser.add_argument(
|
||||
"--strategies",
|
||||
nargs="+",
|
||||
default=STRATEGIES,
|
||||
help="Liste des stratégies à comparer (défaut: scalping ml_driven cnn_driven)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
print()
|
||||
print("=" * 70)
|
||||
print("BACKTEST COMPARATIF DES STRATÉGIES")
|
||||
print(f"Symbol: {args.symbol} | Période: {args.period} | Capital: {args.capital:,.0f} EUR")
|
||||
print(f"Stratégies: {', '.join(args.strategies)}")
|
||||
print(f"API: {API_BASE_URL}")
|
||||
print("=" * 70)
|
||||
|
||||
# Vérification de la disponibilité de l'API
|
||||
try:
|
||||
_get(f"{API_BASE_URL}/health")
|
||||
print("API disponible.")
|
||||
except Exception as e:
|
||||
print(f"ERREUR : Impossible de joindre l'API ({e})")
|
||||
print("Vérifiez que le container trading-api tourne sur le port 8100.")
|
||||
sys.exit(1)
|
||||
|
||||
print()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Phase 1 : lancement de tous les backtests
|
||||
# -----------------------------------------------------------------------
|
||||
jobs: dict[str, str] = {} # {strategy: job_id}
|
||||
stratégies_échouées: list[str] = []
|
||||
|
||||
for strat in args.strategies:
|
||||
print(f"Lancement backtest [{strat}]...")
|
||||
try:
|
||||
job_id = lancer_backtest(strat, args.symbol, args.period, args.capital)
|
||||
jobs[strat] = job_id
|
||||
print(f" -> Job ID : {job_id}")
|
||||
except Exception as e:
|
||||
print(f" ERREUR lancement [{strat}] : {e}")
|
||||
stratégies_échouées.append(strat)
|
||||
|
||||
if not jobs:
|
||||
print("Aucun job lancé. Arrêt.")
|
||||
sys.exit(1)
|
||||
|
||||
print()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Phase 2 : attente et collecte des résultats (séquentielle)
|
||||
# -----------------------------------------------------------------------
|
||||
resultats: list[dict] = []
|
||||
|
||||
for strat, job_id in jobs.items():
|
||||
print(f"Attente résultat [{strat}] (job={job_id})...")
|
||||
try:
|
||||
res = attendre_resultat(job_id, strat)
|
||||
resultats.append(res)
|
||||
print(f" -> Terminé. Sharpe={res.get('sharpe_ratio', 'N/A')}, "
|
||||
f"Retour={res.get('total_return', 'N/A')}")
|
||||
except Exception as e:
|
||||
print(f" ERREUR [{strat}] : {e}")
|
||||
# On ajoute quand même un résultat partiel pour la lisibilité du rapport
|
||||
resultats.append({
|
||||
"strategy": strat,
|
||||
"symbol": args.symbol,
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
})
|
||||
|
||||
print()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Phase 3 : affichage et sauvegarde
|
||||
# -----------------------------------------------------------------------
|
||||
afficher_tableau(resultats)
|
||||
|
||||
chemin_rapport = sauvegarder_rapport(resultats, args.symbol, args.period)
|
||||
print(f"Rapport JSON sauvegardé : {chemin_rapport}")
|
||||
print()
|
||||
|
||||
# Résumé des échecs éventuels
|
||||
if stratégies_échouées:
|
||||
print(f"Stratégies non lancées (erreur API) : {', '.join(stratégies_échouées)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
327
scripts/quick_benchmark.py
Executable file
327
scripts/quick_benchmark.py
Executable file
@@ -0,0 +1,327 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
quick_benchmark.py — Benchmark rapide des modèles ML entraînés.
|
||||
|
||||
Lit les fichiers *_meta.json dans les répertoires de modèles :
|
||||
- models/ml_strategy/ (XGBoost / LightGBM / RandomForest)
|
||||
- models/cnn_strategy/ (CNN 1D — CandlestickEncoder)
|
||||
- models/cnn_image_strategy/ (CNN Image — chandeliers en image)
|
||||
|
||||
Affiche un tableau comparatif : wf_accuracy, wf_precision, n_samples, etc.
|
||||
Indique quel modèle obtient la meilleure précision walk-forward.
|
||||
|
||||
Usage :
|
||||
python scripts/quick_benchmark.py
|
||||
python scripts/quick_benchmark.py --models-root /chemin/vers/models
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Répertoires de modèles (relatifs à la racine du projet)
|
||||
# ---------------------------------------------------------------------------
|
||||
PROJECT_ROOT = Path(__file__).parent.parent
|
||||
|
||||
# Mapping : nom affiché -> chemin relatif au projet
|
||||
MODEL_DIRS = {
|
||||
"ML-Strategy (XGBoost/LightGBM)": PROJECT_ROOT / "models" / "ml_strategy",
|
||||
"CNN-Strategy (1D)": PROJECT_ROOT / "models" / "cnn_strategy",
|
||||
"CNN-Image-Strategy": PROJECT_ROOT / "models" / "cnn_image_strategy",
|
||||
}
|
||||
|
||||
# Colonnes de métriques affichées dans le tableau
|
||||
COLONNES = [
|
||||
("type", "Type", "{:<28}"),
|
||||
("symbol", "Symbol", "{:<8}"),
|
||||
("timeframe", "TF", "{:<5}"),
|
||||
("model_type", "Modèle", "{:<12}"),
|
||||
("n_samples", "Échantillons", "{:>12}"),
|
||||
("wf_accuracy", "WF Accuracy", "{:>12}"),
|
||||
("wf_precision", "WF Precision", "{:>13}"),
|
||||
("trained_at", "Entraîné le", "{:<20}"),
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lecture des méta-fichiers JSON
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def charger_modeles(model_dirs: dict[str, Path]) -> list[dict]:
|
||||
"""
|
||||
Parcourt les répertoires de modèles et charge les métadonnées JSON.
|
||||
|
||||
Retourne une liste de dicts prêts pour l'affichage.
|
||||
"""
|
||||
tous = []
|
||||
|
||||
for type_label, dossier in model_dirs.items():
|
||||
if not dossier.exists():
|
||||
# Répertoire absent — pas encore de modèles entraînés pour ce type
|
||||
continue
|
||||
|
||||
fichiers_meta = sorted(dossier.glob("*_meta.json"))
|
||||
if not fichiers_meta:
|
||||
continue
|
||||
|
||||
for f in fichiers_meta:
|
||||
try:
|
||||
with open(f, encoding="utf-8") as fp:
|
||||
meta = json.load(fp)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
print(f"[WARN] Impossible de lire {f} : {e}", file=sys.stderr)
|
||||
continue
|
||||
|
||||
wf = meta.get("wf_metrics", {})
|
||||
|
||||
# Extraction des champs — avec valeurs par défaut
|
||||
# Le nom du fichier (ex : EURUSD_1h_xgboost_meta.json) est utilisé
|
||||
# comme fallback si les champs ne sont pas dans le JSON.
|
||||
stem_parts = f.stem.replace("_meta", "").split("_")
|
||||
symbol = meta.get("symbol", stem_parts[0] if len(stem_parts) > 0 else "?")
|
||||
timeframe = meta.get("timeframe", stem_parts[1] if len(stem_parts) > 1 else "?")
|
||||
model_type = meta.get(
|
||||
"model_type",
|
||||
stem_parts[2] if len(stem_parts) > 2 else dossier.name
|
||||
)
|
||||
|
||||
wf_accuracy = wf.get("avg_accuracy", None)
|
||||
wf_precision = wf.get("avg_precision", None)
|
||||
|
||||
trained_at = meta.get("trained_at", "?")
|
||||
# Tronque la date ISO à 19 caractères pour la lisibilité
|
||||
if isinstance(trained_at, str) and len(trained_at) > 19:
|
||||
trained_at = trained_at[:19]
|
||||
|
||||
tous.append({
|
||||
"type": type_label,
|
||||
"symbol": symbol,
|
||||
"timeframe": timeframe,
|
||||
"model_type": model_type,
|
||||
"n_samples": meta.get("n_samples", 0),
|
||||
"wf_accuracy": wf_accuracy,
|
||||
"wf_precision": wf_precision,
|
||||
"trained_at": trained_at,
|
||||
"_meta_path": str(f), # pour le rapport détaillé éventuel
|
||||
"_raw_meta": meta,
|
||||
})
|
||||
|
||||
return tous
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Affichage du tableau
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def formater_val(val, fmt: str) -> str:
|
||||
"""Formate une valeur numérique ou retourne 'N/A'."""
|
||||
if val is None:
|
||||
return "N/A"
|
||||
try:
|
||||
if "%" in fmt or "f" in fmt:
|
||||
return f"{float(val):.2%}" if "%" in fmt else fmt.format(float(val))
|
||||
return fmt.format(val)
|
||||
except (ValueError, TypeError):
|
||||
return str(val)
|
||||
|
||||
|
||||
def afficher_tableau(modeles: list[dict]) -> None:
|
||||
"""Affiche le tableau comparatif des modèles dans le terminal."""
|
||||
if not modeles:
|
||||
print("Aucun modèle entraîné trouvé.")
|
||||
print("Entraînez d'abord un modèle via POST /trading/train (ou /train-cnn).")
|
||||
return
|
||||
|
||||
# Entêtes
|
||||
headers = [label for _, label, _ in COLONNES]
|
||||
col_fmts = [fmt for _, _, fmt in COLONNES]
|
||||
|
||||
# Construction des lignes
|
||||
rows = []
|
||||
for m in modeles:
|
||||
ligne = []
|
||||
for key, _, fmt in COLONNES:
|
||||
val = m.get(key)
|
||||
if val is None:
|
||||
ligne.append("N/A")
|
||||
elif key in ("wf_accuracy", "wf_precision"):
|
||||
# Affichage en pourcentage
|
||||
try:
|
||||
ligne.append(f"{float(val):.2%}")
|
||||
except (ValueError, TypeError):
|
||||
ligne.append("N/A")
|
||||
elif key == "n_samples":
|
||||
try:
|
||||
ligne.append(f"{int(val):,}")
|
||||
except (ValueError, TypeError):
|
||||
ligne.append(str(val))
|
||||
else:
|
||||
ligne.append(str(val))
|
||||
rows.append(ligne)
|
||||
|
||||
print()
|
||||
print("=" * 80)
|
||||
print("BENCHMARK DES MODÈLES ML ENTRAÎNÉS")
|
||||
print("=" * 80)
|
||||
|
||||
try:
|
||||
from tabulate import tabulate
|
||||
print(tabulate(rows, headers=headers, tablefmt="rounded_outline"))
|
||||
except ImportError:
|
||||
# Fallback : affichage simple
|
||||
col_widths = [
|
||||
max(len(str(h)), max((len(str(r[i])) for r in rows), default=0))
|
||||
for i, h in enumerate(headers)
|
||||
]
|
||||
sep = " ".join("-" * w for w in col_widths)
|
||||
header_line = " ".join(str(h).ljust(w) for h, w in zip(headers, col_widths))
|
||||
print(header_line)
|
||||
print(sep)
|
||||
for row in rows:
|
||||
print(" ".join(str(c).ljust(w) for c, w in zip(row, col_widths)))
|
||||
|
||||
print()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Identification du meilleur modèle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def identifier_meilleur(modeles: list[dict]) -> None:
|
||||
"""Indique quel modèle est le meilleur selon wf_accuracy et wf_precision."""
|
||||
candidats = [m for m in modeles if m.get("wf_accuracy") is not None]
|
||||
if not candidats:
|
||||
print("Impossible de déterminer le meilleur modèle (aucune métrique wf_accuracy).")
|
||||
return
|
||||
|
||||
# Critère principal : wf_accuracy ; critère secondaire : wf_precision
|
||||
meilleur = max(
|
||||
candidats,
|
||||
key=lambda m: (
|
||||
m.get("wf_accuracy") or 0.0,
|
||||
m.get("wf_precision") or 0.0,
|
||||
),
|
||||
)
|
||||
|
||||
print(f"Meilleur modèle (WF Accuracy) :")
|
||||
print(f" Type : {meilleur['type']}")
|
||||
print(f" Symbol : {meilleur['symbol']} / {meilleur['timeframe']}")
|
||||
print(f" Modèle : {meilleur['model_type']}")
|
||||
|
||||
acc = meilleur.get("wf_accuracy")
|
||||
prec = meilleur.get("wf_precision")
|
||||
print(f" WF Accuracy : {acc:.2%}" if acc is not None else " WF Accuracy : N/A")
|
||||
print(f" WF Precision : {prec:.2%}" if prec is not None else " WF Precision : N/A")
|
||||
print(f" Entraîné le : {meilleur.get('trained_at', '?')}")
|
||||
print(f" Fichier meta : {meilleur['_meta_path']}")
|
||||
|
||||
# Avertissement si la précision est insuffisante pour le trading
|
||||
if acc is not None and acc < 0.40:
|
||||
print()
|
||||
print("[AVIS] WF Accuracy < 40% — ce modèle est insuffisant pour trader seul.")
|
||||
print(" Envisagez un re-entraînement avec plus de données ou de features.")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Statistiques globales par type de modèle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def afficher_stats_globales(modeles: list[dict]) -> None:
|
||||
"""Affiche des statistiques agrégées par type de modèle."""
|
||||
if not modeles:
|
||||
return
|
||||
|
||||
# Regroupement par type
|
||||
par_type: dict[str, list] = {}
|
||||
for m in modeles:
|
||||
par_type.setdefault(m["type"], []).append(m)
|
||||
|
||||
print("Résumé par type de modèle :")
|
||||
print("-" * 50)
|
||||
for type_label, groupe in par_type.items():
|
||||
accs = [m["wf_accuracy"] for m in groupe if m.get("wf_accuracy") is not None]
|
||||
precs = [m["wf_precision"] for m in groupe if m.get("wf_precision") is not None]
|
||||
n_tot = sum(m.get("n_samples", 0) for m in groupe)
|
||||
|
||||
acc_moy = sum(accs) / len(accs) if accs else None
|
||||
prec_moy = sum(precs) / len(precs) if precs else None
|
||||
|
||||
acc_str = f"{acc_moy:.2%}" if acc_moy is not None else "N/A"
|
||||
prec_str = f"{prec_moy:.2%}" if prec_moy is not None else "N/A"
|
||||
|
||||
print(f" {type_label}")
|
||||
print(f" Modèles entraînés : {len(groupe)}")
|
||||
print(f" WF Accuracy moy. : {acc_str}")
|
||||
print(f" WF Precision moy. : {prec_str}")
|
||||
print(f" Échantillons tot. : {n_tot:,}")
|
||||
print()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Point d'entrée principal
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""Parse les arguments en ligne de commande."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark rapide des modèles ML entraînés (lecture des méta-fichiers JSON)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models-root",
|
||||
type=Path,
|
||||
default=PROJECT_ROOT / "models",
|
||||
help=f"Répertoire racine des modèles (défaut: {PROJECT_ROOT / 'models'})",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
# Mise à jour des chemins si --models-root est spécifié
|
||||
dirs_effectifs = {
|
||||
label: args.models_root / dossier.name
|
||||
for label, dossier in MODEL_DIRS.items()
|
||||
}
|
||||
|
||||
print()
|
||||
print("Lecture des modèles dans :")
|
||||
for label, chemin in dirs_effectifs.items():
|
||||
existe = "OK" if chemin.exists() else "absent"
|
||||
print(f" [{existe}] {chemin}")
|
||||
print()
|
||||
|
||||
# Chargement de tous les modèles
|
||||
modeles = charger_modeles(dirs_effectifs)
|
||||
|
||||
if not modeles:
|
||||
print("Aucun modèle trouvé dans les répertoires indiqués.")
|
||||
print()
|
||||
print("Pour entraîner un modèle XGBoost/LightGBM :")
|
||||
print(" POST http://localhost:8100/trading/train")
|
||||
print(" {\"symbol\": \"EURUSD\", \"timeframe\": \"1h\", \"model_type\": \"xgboost\"}")
|
||||
print()
|
||||
print("Pour entraîner un modèle CNN 1D :")
|
||||
print(" POST http://localhost:8100/trading/train-cnn")
|
||||
print(" {\"symbol\": \"EURUSD\", \"timeframe\": \"1h\"}")
|
||||
sys.exit(0)
|
||||
|
||||
# Affichage du tableau
|
||||
afficher_tableau(modeles)
|
||||
|
||||
# Statistiques par type
|
||||
afficher_stats_globales(modeles)
|
||||
|
||||
# Meilleur modèle
|
||||
identifier_meilleur(modeles)
|
||||
|
||||
print(f"Total : {len(modeles)} modèle(s) trouvé(s).")
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -84,6 +84,9 @@ class StrategyEngine:
|
||||
elif strategy_name == 'swing':
|
||||
from src.strategies.swing.swing_strategy import SwingStrategy
|
||||
strategy_class = SwingStrategy
|
||||
elif strategy_name == 'ml_driven':
|
||||
from src.strategies.ml_driven.ml_strategy import MLDrivenStrategy
|
||||
strategy_class = MLDrivenStrategy
|
||||
else:
|
||||
raise ValueError(f"Unknown strategy: {strategy_name}")
|
||||
|
||||
|
||||
3
src/ml/cnn/__init__.py
Normal file
3
src/ml/cnn/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .cnn_strategy_model import CNNStrategyModel
|
||||
|
||||
__all__ = ['CNNStrategyModel']
|
||||
129
src/ml/cnn/candlestick_encoder.py
Normal file
129
src/ml/cnn/candlestick_encoder.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""
|
||||
Encodeur de bougies OHLCV en séquences normalisées pour CNN 1D.
|
||||
|
||||
Transforme un DataFrame OHLCV en tenseurs (N, seq_len, 5) prêts pour le CNN.
|
||||
Chaque séquence est normalisée indépendamment (z-score glissant) pour que
|
||||
le modèle apprenne des patterns relatifs, pas des niveaux de prix absolus.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CandlestickEncoder:
|
||||
"""
|
||||
Encode les données OHLCV brutes en séquences normalisées pour le CNN.
|
||||
|
||||
La normalisation par fenêtre glissante garantit que le CNN voit des
|
||||
patterns de forme (doji, engulfing, etc.) indépendamment du niveau de prix.
|
||||
Le volume est normalisé séparément (ratio vs moyenne).
|
||||
|
||||
Args:
|
||||
seq_len: Longueur des séquences (nombre de bougies par échantillon)
|
||||
"""
|
||||
|
||||
def __init__(self, seq_len: int = 64):
|
||||
self.seq_len = seq_len
|
||||
|
||||
def encode(self, df_ohlcv: pd.DataFrame, seq_len: int = None) -> np.ndarray:
|
||||
"""
|
||||
Encode toutes les séquences glissantes depuis le DataFrame OHLCV.
|
||||
|
||||
Args:
|
||||
df_ohlcv: DataFrame avec colonnes open, high, low, close, volume
|
||||
seq_len: Override de la longueur de séquence (optionnel)
|
||||
|
||||
Returns:
|
||||
np.ndarray de shape (N_samples, seq_len, 5)
|
||||
Colonnes : [open, high, low, close, volume]
|
||||
"""
|
||||
seq_len = seq_len or self.seq_len
|
||||
df = self._prepare_df(df_ohlcv)
|
||||
|
||||
if len(df) < seq_len + 1:
|
||||
logger.warning(f"Pas assez de données : {len(df)} barres < {seq_len + 1} minimum")
|
||||
return np.empty((0, seq_len, 5))
|
||||
|
||||
n_samples = len(df) - seq_len + 1
|
||||
sequences = np.zeros((n_samples, seq_len, 5), dtype=np.float32)
|
||||
|
||||
for i in range(n_samples):
|
||||
window = df.iloc[i:i + seq_len]
|
||||
sequences[i] = self._normalize_window(window)
|
||||
|
||||
# Supprimer les séquences avec NaN
|
||||
valid_mask = ~np.isnan(sequences).any(axis=(1, 2))
|
||||
sequences = sequences[valid_mask]
|
||||
|
||||
logger.info(f"Encodage : {len(sequences)} séquences de {seq_len} bougies")
|
||||
return sequences
|
||||
|
||||
def encode_last(self, df_ohlcv: pd.DataFrame, seq_len: int = None) -> np.ndarray:
|
||||
"""
|
||||
Encode uniquement la dernière séquence (pour prédiction temps réel).
|
||||
|
||||
Args:
|
||||
df_ohlcv: DataFrame OHLCV (au moins seq_len barres)
|
||||
seq_len: Override de la longueur de séquence (optionnel)
|
||||
|
||||
Returns:
|
||||
np.ndarray de shape (1, seq_len, 5)
|
||||
"""
|
||||
seq_len = seq_len or self.seq_len
|
||||
df = self._prepare_df(df_ohlcv)
|
||||
|
||||
if len(df) < seq_len:
|
||||
logger.warning(f"Pas assez de données pour encode_last : {len(df)} < {seq_len}")
|
||||
return np.empty((0, seq_len, 5))
|
||||
|
||||
window = df.iloc[-seq_len:]
|
||||
normalized = self._normalize_window(window)
|
||||
return normalized.reshape(1, seq_len, 5)
|
||||
|
||||
def _prepare_df(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Prépare le DataFrame : colonnes en minuscules, sélection OHLCV."""
|
||||
df = df.copy()
|
||||
df.columns = [c.lower() for c in df.columns]
|
||||
|
||||
required = ['open', 'high', 'low', 'close', 'volume']
|
||||
missing = [c for c in required if c not in df.columns]
|
||||
if missing:
|
||||
raise ValueError(f"Colonnes manquantes : {missing}")
|
||||
|
||||
return df[required].reset_index(drop=True)
|
||||
|
||||
def _normalize_window(self, window: pd.DataFrame) -> np.ndarray:
|
||||
"""
|
||||
Normalise une fenêtre OHLCV.
|
||||
|
||||
Prix (OHLC) : z-score sur la fenêtre (moyenne et écart-type du close)
|
||||
Volume : ratio par rapport à la moyenne du volume sur la fenêtre
|
||||
"""
|
||||
result = np.zeros((len(window), 5), dtype=np.float32)
|
||||
|
||||
# Normalisation prix par z-score du close
|
||||
close_values = window['close'].values.astype(np.float64)
|
||||
mean_price = close_values.mean()
|
||||
std_price = close_values.std()
|
||||
|
||||
if std_price < 1e-10:
|
||||
# Prix constant → tout à zéro
|
||||
std_price = 1.0
|
||||
|
||||
result[:, 0] = (window['open'].values - mean_price) / std_price
|
||||
result[:, 1] = (window['high'].values - mean_price) / std_price
|
||||
result[:, 2] = (window['low'].values - mean_price) / std_price
|
||||
result[:, 3] = (window['close'].values - mean_price) / std_price
|
||||
|
||||
# Normalisation volume : ratio vs moyenne
|
||||
vol_values = window['volume'].values.astype(np.float64)
|
||||
mean_vol = vol_values.mean()
|
||||
if mean_vol < 1e-10:
|
||||
result[:, 4] = 0.0
|
||||
else:
|
||||
result[:, 4] = vol_values / mean_vol
|
||||
|
||||
return result
|
||||
113
src/ml/cnn/cnn_model.py
Normal file
113
src/ml/cnn/cnn_model.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""
|
||||
CNN 1D pour la détection de patterns dans les bougies OHLCV.
|
||||
|
||||
Architecture Conv1d → BatchNorm → ReLU → Pool, empilée sur 3 couches,
|
||||
suivie d'un classifieur linéaire pour prédire LONG / SHORT / NEUTRAL.
|
||||
|
||||
Conçu pour capturer des patterns visuels (doji, engulfing, head&shoulders...)
|
||||
que le réseau apprend directement depuis les données brutes normalisées,
|
||||
sans features pré-calculées (contrairement à XGBoost).
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
TORCH_AVAILABLE = True
|
||||
except ImportError:
|
||||
TORCH_AVAILABLE = False
|
||||
logger.warning("PyTorch non disponible — le CNN ne peut pas être utilisé")
|
||||
|
||||
|
||||
if TORCH_AVAILABLE:
|
||||
class TradingCNN(nn.Module):
|
||||
"""
|
||||
CNN 1D pour classification de séquences OHLCV.
|
||||
|
||||
Architecture :
|
||||
Input (batch, seq_len=64, 5)
|
||||
→ Permute → (batch, 5, 64)
|
||||
→ Conv1d(5→32, k=3) + BN + ReLU + MaxPool(2) → (batch, 32, 32)
|
||||
→ Conv1d(32→64, k=3) + BN + ReLU + MaxPool(2) → (batch, 64, 16)
|
||||
→ Conv1d(64→128, k=3) + BN + ReLU + AdaptiveAvgPool(1) → (batch, 128, 1)
|
||||
→ Flatten → Linear(128→64) + ReLU + Dropout(0.3)
|
||||
→ Linear(64→3) : logits [LONG, SHORT, NEUTRAL]
|
||||
|
||||
Args:
|
||||
n_features: Nombre de canaux d'entrée (5 = OHLCV)
|
||||
n_classes: Nombre de classes de sortie (3)
|
||||
dropout: Taux de dropout dans le classifieur
|
||||
"""
|
||||
|
||||
def __init__(self, n_features: int = 5, n_classes: int = 3, dropout: float = 0.3):
|
||||
super().__init__()
|
||||
|
||||
# Couches convolutives
|
||||
self.conv1 = nn.Conv1d(n_features, 32, kernel_size=3, padding=1)
|
||||
self.bn1 = nn.BatchNorm1d(32)
|
||||
self.pool1 = nn.MaxPool1d(2)
|
||||
|
||||
self.conv2 = nn.Conv1d(32, 64, kernel_size=3, padding=1)
|
||||
self.bn2 = nn.BatchNorm1d(64)
|
||||
self.pool2 = nn.MaxPool1d(2)
|
||||
|
||||
self.conv3 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
|
||||
self.bn3 = nn.BatchNorm1d(128)
|
||||
self.adaptive_pool = nn.AdaptiveAvgPool1d(1)
|
||||
|
||||
# Classifieur
|
||||
self.fc1 = nn.Linear(128, 64)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.fc2 = nn.Linear(64, n_classes)
|
||||
|
||||
def forward(self, x: 'torch.Tensor') -> 'torch.Tensor':
|
||||
"""
|
||||
Forward pass.
|
||||
|
||||
Args:
|
||||
x: Tensor de shape (batch, seq_len, n_features)
|
||||
|
||||
Returns:
|
||||
Logits de shape (batch, n_classes)
|
||||
"""
|
||||
# Permute pour Conv1d : (batch, n_features, seq_len)
|
||||
x = x.permute(0, 2, 1)
|
||||
|
||||
# Bloc 1
|
||||
x = self.pool1(F.relu(self.bn1(self.conv1(x))))
|
||||
# Bloc 2
|
||||
x = self.pool2(F.relu(self.bn2(self.conv2(x))))
|
||||
# Bloc 3
|
||||
x = self.adaptive_pool(F.relu(self.bn3(self.conv3(x))))
|
||||
|
||||
# Flatten + classifieur
|
||||
x = x.squeeze(-1) # (batch, 128)
|
||||
x = F.relu(self.fc1(x))
|
||||
x = self.dropout(x)
|
||||
x = self.fc2(x)
|
||||
|
||||
return x
|
||||
|
||||
def predict_proba(self, x: 'torch.Tensor') -> 'torch.Tensor':
|
||||
"""
|
||||
Retourne les probabilités de chaque classe via Softmax.
|
||||
|
||||
Args:
|
||||
x: Tensor de shape (batch, seq_len, n_features)
|
||||
|
||||
Returns:
|
||||
Probabilités de shape (batch, n_classes)
|
||||
"""
|
||||
logits = self.forward(x)
|
||||
return F.softmax(logits, dim=1)
|
||||
|
||||
else:
|
||||
# Placeholder si PyTorch non disponible
|
||||
class TradingCNN:
|
||||
"""Placeholder — PyTorch non disponible."""
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise RuntimeError("PyTorch non disponible — impossible de créer TradingCNN")
|
||||
585
src/ml/cnn/cnn_strategy_model.py
Normal file
585
src/ml/cnn/cnn_strategy_model.py
Normal file
@@ -0,0 +1,585 @@
|
||||
"""
|
||||
CNN Strategy Model — Modèle CNN 1D qui apprend des patterns de bougies.
|
||||
|
||||
Contrairement à MLStrategyModel (XGBoost sur features TA pré-calculées),
|
||||
ce modèle travaille directement sur les séquences OHLCV brutes normalisées.
|
||||
Le CNN détecte lui-même les patterns visuels pertinents (doji, engulfing, etc.).
|
||||
|
||||
Pipeline :
|
||||
1. Chargement données OHLCV
|
||||
2. Encodage séquences (CandlestickEncoder : z-score glissant)
|
||||
3. Génération labels (LabelGenerator — partagé avec MLStrategyModel)
|
||||
4. Entraînement CNN (PyTorch, Adam, CrossEntropy, early stopping)
|
||||
5. Walk-forward validation (2 folds temporels)
|
||||
6. Sauvegarde state_dict + métadonnées JSON
|
||||
|
||||
Usage:
|
||||
model = CNNStrategyModel(symbol='EURUSD', timeframe='1h')
|
||||
result = model.train(df_ohlcv)
|
||||
signal = model.predict(df_recent)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from src.ml.features.label_generator import LabelGenerator
|
||||
from src.ml.cnn.candlestick_encoder import CandlestickEncoder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Répertoire de sauvegarde des modèles CNN
|
||||
MODELS_DIR = Path(__file__).parent.parent.parent.parent / "models" / "cnn_strategy"
|
||||
|
||||
try:
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import TensorDataset, DataLoader
|
||||
from src.ml.cnn.cnn_model import TradingCNN
|
||||
TORCH_AVAILABLE = True
|
||||
except ImportError:
|
||||
TORCH_AVAILABLE = False
|
||||
logger.warning("PyTorch non disponible — CNNStrategyModel ne peut pas fonctionner")
|
||||
|
||||
try:
|
||||
from sklearn.metrics import precision_recall_fscore_support
|
||||
SKLEARN_AVAILABLE = True
|
||||
except ImportError:
|
||||
SKLEARN_AVAILABLE = False
|
||||
|
||||
|
||||
# Mapping indices CNN → signaux de trading
|
||||
# CNN : 0=LONG, 1=SHORT, 2=NEUTRAL
|
||||
# Trading : 1=LONG, -1=SHORT, 0=NEUTRAL
|
||||
CLASS_MAP = {0: 1, 1: -1, 2: 0}
|
||||
|
||||
# Mapping inverse : labels LabelGenerator → indices CNN
|
||||
# LabelGenerator : 1=LONG, -1=SHORT, 0=NEUTRAL
|
||||
LABEL_TO_INDEX = {1: 0, -1: 1, 0: 2}
|
||||
|
||||
|
||||
class CNNStrategyModel:
|
||||
"""
|
||||
Modèle CNN qui apprend les patterns visuels des bougies.
|
||||
|
||||
Le modèle :
|
||||
- Travaille sur les 64 dernières bougies OHLCV (données brutes normalisées)
|
||||
- Prédit LONG (1) / SHORT (-1) / NEUTRAL (0)
|
||||
- Donne un score de confiance [0..1] par prédiction
|
||||
- Se sauvegarde sur disque (state_dict PyTorch + métadonnées JSON)
|
||||
|
||||
Args:
|
||||
symbol: Paire tradée (ex: 'EURUSD')
|
||||
timeframe: Timeframe (ex: '1h', '15m')
|
||||
seq_len: Longueur des séquences d'entrée
|
||||
min_confidence: Seuil de confiance pour signal tradeable
|
||||
tp_atr_mult: Multiplicateur ATR pour TP (labels)
|
||||
sl_atr_mult: Multiplicateur ATR pour SL (labels)
|
||||
horizon: Nombre de barres pour évaluer TP/SL (labels)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
symbol: str = 'EURUSD',
|
||||
timeframe: str = '1h',
|
||||
seq_len: int = 64,
|
||||
min_confidence: float = 0.55,
|
||||
tp_atr_mult: float = 2.0,
|
||||
sl_atr_mult: float = 1.0,
|
||||
horizon: int = 30,
|
||||
):
|
||||
self.symbol = symbol
|
||||
self.timeframe = timeframe
|
||||
self.model_type = 'cnn'
|
||||
self.seq_len = seq_len
|
||||
self.min_confidence = min_confidence
|
||||
self.tp_atr_mult = tp_atr_mult
|
||||
self.sl_atr_mult = sl_atr_mult
|
||||
self.horizon = horizon
|
||||
|
||||
self.model = None
|
||||
self.is_trained = False
|
||||
self.metadata: Dict = {}
|
||||
self.encoder = CandlestickEncoder(seq_len=seq_len)
|
||||
|
||||
MODELS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Entraînement
|
||||
# -------------------------------------------------------------------------
|
||||
def train(self, data: pd.DataFrame) -> Dict:
|
||||
"""
|
||||
Entraîne le CNN sur les données OHLCV.
|
||||
|
||||
Utilise les mêmes labels que MLStrategyModel (LabelGenerator.generate_atr_based).
|
||||
Walk-forward validation sur 2 folds temporels.
|
||||
Sauvegarde automatiquement le modèle après entraînement.
|
||||
|
||||
Args:
|
||||
data: DataFrame OHLCV (au moins 200 barres recommandées)
|
||||
|
||||
Returns:
|
||||
Dict avec wf_metrics, label_dist, n_samples, trained_at
|
||||
"""
|
||||
if not TORCH_AVAILABLE:
|
||||
return {'error': 'PyTorch non disponible'}
|
||||
|
||||
logger.info(f"Début entraînement CNNStrategyModel pour {self.symbol}/{self.timeframe}")
|
||||
logger.info(f" Données : {len(data)} barres, seq_len={self.seq_len}")
|
||||
|
||||
# 1. Génération des labels (même méthode que MLStrategyModel)
|
||||
gen = LabelGenerator(horizon=self.horizon)
|
||||
labels_series = gen.generate_atr_based(
|
||||
data,
|
||||
atr_tp_mult=self.tp_atr_mult,
|
||||
atr_sl_mult=self.sl_atr_mult,
|
||||
)
|
||||
|
||||
# 2. Encodage des séquences OHLCV
|
||||
sequences = self.encoder.encode(data, seq_len=self.seq_len)
|
||||
|
||||
if len(sequences) == 0:
|
||||
return {'error': 'Pas assez de données pour encoder des séquences'}
|
||||
|
||||
# 3. Aligner labels sur les séquences
|
||||
# Chaque séquence[i] correspond aux barres [i : i+seq_len]
|
||||
# Le label associé est celui de la dernière barre de la séquence (i + seq_len - 1)
|
||||
n_samples = len(sequences)
|
||||
label_indices = []
|
||||
for i in range(n_samples):
|
||||
target_idx = i + self.seq_len - 1
|
||||
if target_idx < len(labels_series):
|
||||
label_indices.append(target_idx)
|
||||
else:
|
||||
label_indices.append(None)
|
||||
|
||||
# Filtrer les séquences valides (label disponible et pas en fin de horizon)
|
||||
valid_mask = []
|
||||
labels_aligned = []
|
||||
for i, idx in enumerate(label_indices):
|
||||
if idx is not None and idx < len(labels_series) - self.horizon:
|
||||
raw_label = labels_series.iloc[idx]
|
||||
labels_aligned.append(LABEL_TO_INDEX.get(raw_label, 2))
|
||||
valid_mask.append(i)
|
||||
|
||||
if len(valid_mask) < 50:
|
||||
return {'error': f'Trop peu de données valides : {len(valid_mask)} échantillons'}
|
||||
|
||||
X = sequences[valid_mask]
|
||||
y = np.array(labels_aligned, dtype=np.int64)
|
||||
|
||||
logger.info(f" {len(X)} échantillons après alignement")
|
||||
n_long = (y == 0).sum()
|
||||
n_short = (y == 1).sum()
|
||||
n_neutral = (y == 2).sum()
|
||||
logger.info(f" Distribution : LONG={n_long}, SHORT={n_short}, NEUTRAL={n_neutral}")
|
||||
|
||||
# 4. Walk-forward validation (2 folds)
|
||||
wf_metrics = self._walk_forward_eval(X, y, n_folds=2)
|
||||
|
||||
# 5. Entraînement final sur toutes les données
|
||||
self.model = TradingCNN(n_features=5, n_classes=3)
|
||||
class_weights = self._compute_class_weights(y)
|
||||
self._train_model(self.model, X, y, class_weights, max_epochs=100, patience=10)
|
||||
self.is_trained = True
|
||||
|
||||
# 6. Métadonnées
|
||||
self.metadata = {
|
||||
'symbol': self.symbol,
|
||||
'timeframe': self.timeframe,
|
||||
'model_type': self.model_type,
|
||||
'trained_at': datetime.utcnow().isoformat(),
|
||||
'n_samples': len(X),
|
||||
'seq_len': self.seq_len,
|
||||
'tp_atr_mult': self.tp_atr_mult,
|
||||
'sl_atr_mult': self.sl_atr_mult,
|
||||
'horizon': self.horizon,
|
||||
'label_dist': {
|
||||
'long': int(n_long),
|
||||
'short': int(n_short),
|
||||
'neutral': int(n_neutral),
|
||||
},
|
||||
'wf_metrics': wf_metrics,
|
||||
}
|
||||
|
||||
# 7. Sauvegarde
|
||||
self.save()
|
||||
|
||||
logger.info(f"Entraînement CNN terminé. WF accuracy={wf_metrics.get('avg_accuracy', 0):.2%}")
|
||||
return self.metadata
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Prédiction
|
||||
# -------------------------------------------------------------------------
|
||||
def predict(self, data: pd.DataFrame) -> Dict:
|
||||
"""
|
||||
Prédit le signal pour les dernières barres.
|
||||
|
||||
Args:
|
||||
data: DataFrame OHLCV récent (au moins seq_len barres)
|
||||
|
||||
Returns:
|
||||
Dict : {
|
||||
'signal': 1 (LONG) / -1 (SHORT) / 0 (NEUTRAL),
|
||||
'confidence': float [0..1],
|
||||
'probas': {'long': float, 'short': float, 'neutral': float},
|
||||
'tradeable': bool (confidence >= min_confidence et signal != 0)
|
||||
}
|
||||
"""
|
||||
if not self.is_trained or self.model is None:
|
||||
return {'signal': 0, 'confidence': 0.0, 'tradeable': False, 'error': 'Modèle non entraîné'}
|
||||
|
||||
if not TORCH_AVAILABLE:
|
||||
return {'signal': 0, 'confidence': 0.0, 'tradeable': False, 'error': 'PyTorch non disponible'}
|
||||
|
||||
try:
|
||||
# Encoder la dernière séquence
|
||||
seq = self.encoder.encode_last(data, seq_len=self.seq_len)
|
||||
if len(seq) == 0:
|
||||
return {'signal': 0, 'confidence': 0.0, 'tradeable': False, 'error': 'Données insuffisantes'}
|
||||
|
||||
# Prédiction
|
||||
self.model.eval()
|
||||
x_tensor = torch.FloatTensor(seq)
|
||||
|
||||
with torch.no_grad():
|
||||
probas_tensor = self.model.predict_proba(x_tensor)
|
||||
probas_np = probas_tensor.numpy()[0]
|
||||
|
||||
# Mapping vers signaux de trading
|
||||
pred_idx = int(np.argmax(probas_np))
|
||||
signal = CLASS_MAP[pred_idx]
|
||||
confidence = float(probas_np[pred_idx])
|
||||
|
||||
probas = {
|
||||
'long': float(probas_np[0]),
|
||||
'short': float(probas_np[1]),
|
||||
'neutral': float(probas_np[2]),
|
||||
}
|
||||
|
||||
tradeable = confidence >= self.min_confidence and signal != 0
|
||||
|
||||
return {
|
||||
'signal': signal,
|
||||
'confidence': confidence,
|
||||
'probas': probas,
|
||||
'tradeable': tradeable,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur prédiction CNN : {e}")
|
||||
return {'signal': 0, 'confidence': 0.0, 'tradeable': False, 'error': str(e)}
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Sauvegarde / Chargement
|
||||
# -------------------------------------------------------------------------
|
||||
def save(self) -> None:
|
||||
"""Sauvegarde le state_dict PyTorch + métadonnées JSON."""
|
||||
if not TORCH_AVAILABLE or not self.is_trained or self.model is None:
|
||||
raise RuntimeError("Modèle non entraîné")
|
||||
|
||||
MODELS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
model_id = f"{self.symbol}_{self.timeframe}_cnn"
|
||||
model_path = MODELS_DIR / f"{model_id}.pt"
|
||||
meta_path = MODELS_DIR / f"{model_id}_meta.json"
|
||||
|
||||
# Sauvegarder state_dict PyTorch
|
||||
torch.save({
|
||||
'state_dict': self.model.state_dict(),
|
||||
'config': {
|
||||
'symbol': self.symbol,
|
||||
'timeframe': self.timeframe,
|
||||
'seq_len': self.seq_len,
|
||||
'min_confidence': self.min_confidence,
|
||||
'tp_atr_mult': self.tp_atr_mult,
|
||||
'sl_atr_mult': self.sl_atr_mult,
|
||||
'horizon': self.horizon,
|
||||
},
|
||||
}, model_path)
|
||||
|
||||
# Sauvegarder métadonnées JSON
|
||||
with open(meta_path, 'w') as f:
|
||||
json.dump(self.metadata, f, indent=2, default=str)
|
||||
|
||||
logger.info(f"Modèle CNN sauvegardé : {model_path}")
|
||||
|
||||
@classmethod
|
||||
def load(cls, symbol: str, timeframe: str) -> 'CNNStrategyModel':
|
||||
"""
|
||||
Charge un modèle CNN depuis le disque.
|
||||
|
||||
Args:
|
||||
symbol: Paire (ex: 'EURUSD')
|
||||
timeframe: Timeframe (ex: '1h')
|
||||
|
||||
Returns:
|
||||
Instance CNNStrategyModel prête à prédire
|
||||
|
||||
Raises:
|
||||
FileNotFoundError si le modèle n'existe pas
|
||||
RuntimeError si PyTorch non disponible
|
||||
"""
|
||||
if not TORCH_AVAILABLE:
|
||||
raise RuntimeError("PyTorch non disponible")
|
||||
|
||||
model_id = f"{symbol}_{timeframe}_cnn"
|
||||
model_path = MODELS_DIR / f"{model_id}.pt"
|
||||
meta_path = MODELS_DIR / f"{model_id}_meta.json"
|
||||
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"Modèle CNN non trouvé : {model_path}")
|
||||
|
||||
# Charger le checkpoint
|
||||
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
|
||||
cfg = checkpoint.get('config', {})
|
||||
|
||||
instance = cls(
|
||||
symbol=cfg.get('symbol', symbol),
|
||||
timeframe=cfg.get('timeframe', timeframe),
|
||||
seq_len=cfg.get('seq_len', 64),
|
||||
min_confidence=cfg.get('min_confidence', 0.55),
|
||||
tp_atr_mult=cfg.get('tp_atr_mult', 2.0),
|
||||
sl_atr_mult=cfg.get('sl_atr_mult', 1.0),
|
||||
horizon=cfg.get('horizon', 30),
|
||||
)
|
||||
|
||||
# Reconstruire le modèle et charger les poids
|
||||
instance.model = TradingCNN(n_features=5, n_classes=3)
|
||||
instance.model.load_state_dict(checkpoint['state_dict'])
|
||||
instance.model.eval()
|
||||
instance.is_trained = True
|
||||
|
||||
# Charger métadonnées si disponibles
|
||||
if meta_path.exists():
|
||||
try:
|
||||
with open(meta_path) as f:
|
||||
instance.metadata = json.load(f)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info(f"Modèle CNN chargé depuis {model_path}")
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
def list_trained_models(cls) -> List[Dict]:
|
||||
"""Liste les modèles CNN entraînés disponibles."""
|
||||
if not MODELS_DIR.exists():
|
||||
return []
|
||||
|
||||
models = []
|
||||
for f in MODELS_DIR.glob("*_meta.json"):
|
||||
try:
|
||||
with open(f) as fp:
|
||||
meta = json.load(fp)
|
||||
models.append({
|
||||
'symbol': meta.get('symbol', '?'),
|
||||
'timeframe': meta.get('timeframe', '?'),
|
||||
'model_type': 'cnn',
|
||||
'trained_at': meta.get('trained_at', '?'),
|
||||
'n_samples': meta.get('n_samples', 0),
|
||||
'wf_accuracy': meta.get('wf_metrics', {}).get('avg_accuracy', 0),
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
return models
|
||||
|
||||
def get_feature_importance(self, top_n: int = 10) -> List[Dict]:
|
||||
"""
|
||||
Pour CNN : pas de feature importance classique.
|
||||
|
||||
Retourne une liste vide — les CNN n'ont pas d'importance par feature
|
||||
au sens des arbres de décision. Une analyse par gradient (GradCAM)
|
||||
serait possible mais hors scope pour l'instant.
|
||||
"""
|
||||
return []
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Walk-forward évaluation
|
||||
# -------------------------------------------------------------------------
|
||||
def _walk_forward_eval(self, X: np.ndarray, y: np.ndarray, n_folds: int = 2) -> Dict:
|
||||
"""
|
||||
Évalue le CNN en walk-forward validation temporelle.
|
||||
|
||||
Découpage : train 60%, test 20%, hold-out 20% (sur 2 folds).
|
||||
"""
|
||||
n = len(X)
|
||||
fold_size = n // (n_folds + 1)
|
||||
accuracies, precisions, recalls = [], [], []
|
||||
|
||||
for fold in range(n_folds):
|
||||
train_end = fold_size * (fold + 1)
|
||||
test_end = train_end + fold_size
|
||||
|
||||
if test_end > n:
|
||||
break
|
||||
|
||||
X_tr, y_tr = X[:train_end], y[:train_end]
|
||||
X_te, y_te = X[train_end:test_end], y[train_end:test_end]
|
||||
|
||||
if len(X_tr) < 30 or len(X_te) < 10:
|
||||
logger.warning(f" Fold {fold + 1} ignoré : pas assez de données")
|
||||
continue
|
||||
|
||||
# Entraîner un modèle temporaire
|
||||
model = TradingCNN(n_features=5, n_classes=3)
|
||||
class_weights = self._compute_class_weights(y_tr)
|
||||
self._train_model(model, X_tr, y_tr, class_weights, max_epochs=50, patience=5)
|
||||
|
||||
# Évaluer
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
x_tensor = torch.FloatTensor(X_te)
|
||||
logits = model(x_tensor)
|
||||
y_pred = logits.argmax(dim=1).numpy()
|
||||
|
||||
acc = (y_pred == y_te).mean()
|
||||
|
||||
if SKLEARN_AVAILABLE:
|
||||
prec, rec, _, _ = precision_recall_fscore_support(
|
||||
y_te, y_pred, average='macro', zero_division=0
|
||||
)
|
||||
else:
|
||||
prec, rec = 0.0, 0.0
|
||||
|
||||
accuracies.append(acc)
|
||||
precisions.append(prec)
|
||||
recalls.append(rec)
|
||||
logger.info(f" Fold {fold + 1}/{n_folds} : acc={acc:.2%}, prec={prec:.2%}, rec={rec:.2%}")
|
||||
|
||||
if not accuracies:
|
||||
return {'avg_accuracy': 0.0, 'avg_precision': 0.0, 'avg_recall': 0.0, 'fold_accuracies': []}
|
||||
|
||||
return {
|
||||
'avg_accuracy': float(np.mean(accuracies)),
|
||||
'avg_precision': float(np.mean(precisions)),
|
||||
'avg_recall': float(np.mean(recalls)),
|
||||
'fold_accuracies': [float(a) for a in accuracies],
|
||||
}
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Entraînement interne
|
||||
# -------------------------------------------------------------------------
|
||||
def _train_model(
|
||||
self,
|
||||
model: 'nn.Module',
|
||||
X: np.ndarray,
|
||||
y: np.ndarray,
|
||||
class_weights: 'torch.Tensor' = None,
|
||||
max_epochs: int = 100,
|
||||
patience: int = 10,
|
||||
lr: float = 1e-3,
|
||||
batch_size: int = 64,
|
||||
) -> None:
|
||||
"""
|
||||
Boucle d'entraînement PyTorch avec early stopping.
|
||||
|
||||
Args:
|
||||
model: Instance TradingCNN
|
||||
X: Séquences (N, seq_len, 5)
|
||||
y: Labels (N,) — indices 0, 1, 2
|
||||
class_weights: Poids par classe pour CrossEntropy
|
||||
max_epochs: Nombre max d'époques
|
||||
patience: Nombre d'époques sans amélioration avant arrêt
|
||||
lr: Learning rate
|
||||
batch_size: Taille des batchs
|
||||
"""
|
||||
# Séparer un petit set de validation (derniers 15%)
|
||||
val_size = max(int(len(X) * 0.15), 10)
|
||||
X_train, X_val = X[:-val_size], X[-val_size:]
|
||||
y_train, y_val = y[:-val_size], y[-val_size:]
|
||||
|
||||
# Tenseurs
|
||||
X_train_t = torch.FloatTensor(X_train)
|
||||
y_train_t = torch.LongTensor(y_train)
|
||||
X_val_t = torch.FloatTensor(X_val)
|
||||
y_val_t = torch.LongTensor(y_val)
|
||||
|
||||
dataset = TensorDataset(X_train_t, y_train_t)
|
||||
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
||||
|
||||
# Loss avec poids de classe
|
||||
if class_weights is not None:
|
||||
criterion = nn.CrossEntropyLoss(weight=class_weights)
|
||||
else:
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
optimizer = optim.Adam(model.parameters(), lr=lr)
|
||||
|
||||
# Early stopping
|
||||
best_val_loss = float('inf')
|
||||
best_state = None
|
||||
epochs_without_improvement = 0
|
||||
|
||||
model.train()
|
||||
for epoch in range(max_epochs):
|
||||
# Phase entraînement
|
||||
total_loss = 0.0
|
||||
n_batches = 0
|
||||
for x_batch, y_batch in loader:
|
||||
optimizer.zero_grad()
|
||||
logits = model(x_batch)
|
||||
loss = criterion(logits, y_batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
total_loss += loss.item()
|
||||
n_batches += 1
|
||||
|
||||
# Phase validation
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
val_logits = model(X_val_t)
|
||||
val_loss = criterion(val_logits, y_val_t).item()
|
||||
model.train()
|
||||
|
||||
# Early stopping
|
||||
if val_loss < best_val_loss:
|
||||
best_val_loss = val_loss
|
||||
best_state = {k: v.clone() for k, v in model.state_dict().items()}
|
||||
epochs_without_improvement = 0
|
||||
else:
|
||||
epochs_without_improvement += 1
|
||||
|
||||
if epochs_without_improvement >= patience:
|
||||
logger.info(f" Early stopping à l'époque {epoch + 1} (patience={patience})")
|
||||
break
|
||||
|
||||
if (epoch + 1) % 20 == 0:
|
||||
avg_loss = total_loss / max(n_batches, 1)
|
||||
logger.info(f" Époque {epoch + 1}/{max_epochs} — loss={avg_loss:.4f}, val_loss={val_loss:.4f}")
|
||||
|
||||
# Restaurer les meilleurs poids
|
||||
if best_state is not None:
|
||||
model.load_state_dict(best_state)
|
||||
|
||||
model.eval()
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Utilitaires
|
||||
# -------------------------------------------------------------------------
|
||||
@staticmethod
|
||||
def _compute_class_weights(y: np.ndarray) -> 'torch.Tensor':
|
||||
"""
|
||||
Calcule les poids inversement proportionnels à la fréquence des classes.
|
||||
|
||||
Permet de compenser le déséquilibre (ex: trop de NEUTRAL).
|
||||
"""
|
||||
if not TORCH_AVAILABLE:
|
||||
return None
|
||||
|
||||
classes, counts = np.unique(y, return_counts=True)
|
||||
n_samples = len(y)
|
||||
n_classes = 3
|
||||
|
||||
weights = np.ones(n_classes, dtype=np.float32)
|
||||
for cls, count in zip(classes, counts):
|
||||
if cls < n_classes:
|
||||
weights[int(cls)] = n_samples / (n_classes * count)
|
||||
|
||||
return torch.FloatTensor(weights)
|
||||
11
src/ml/cnn_image/__init__.py
Normal file
11
src/ml/cnn_image/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
Module CNN Image-Based — Analyse visuelle de graphiques en chandeliers.
|
||||
|
||||
Ce module implémente un CNN Conv2D qui analyse des graphiques de chandeliers
|
||||
rendus comme de vraies images (128×128 RGB), pour reconnaître les patterns
|
||||
visuels exactement comme un trader humain devant TradingView.
|
||||
"""
|
||||
|
||||
from src.ml.cnn_image.cnn_image_strategy_model import CNNImageStrategyModel
|
||||
|
||||
__all__ = ['CNNImageStrategyModel']
|
||||
472
src/ml/cnn_image/chart_renderer.py
Normal file
472
src/ml/cnn_image/chart_renderer.py
Normal file
@@ -0,0 +1,472 @@
|
||||
"""
|
||||
CandlestickImageRenderer — Convertit des données OHLCV en images de graphiques.
|
||||
|
||||
Ce module transforme des séquences de bougies OHLCV en images 128×128 RGB
|
||||
qui peuvent être passées à un CNN Conv2D pour l'analyse visuelle des patterns.
|
||||
|
||||
Rendu (pur numpy, très rapide) :
|
||||
- Fond noir (#0d1117), style TradingView
|
||||
- Bougies vertes (#26a69a) pour la hausse, rouges (#ef5350) pour la baisse
|
||||
- Mèches (high/low), corps (open/close), volume en bas (20% de hauteur)
|
||||
- Pas d'axes, pas de labels, pas de titre
|
||||
- Taille fixe : 128×128 pixels, 3 canaux RGB
|
||||
|
||||
Perf : ~5 000 images/s (vs ~5/s avec mplfinance).
|
||||
mplfinance reste utilisable via _render_with_mplfinance() pour l'affichage.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Couleurs TradingView (normalisées [0, 1])
|
||||
BG_COLOR = np.array([13 / 255, 17 / 255, 23 / 255], dtype=np.float32) # #0d1117
|
||||
GREEN_COLOR = np.array([38 / 255, 166 / 255, 154 / 255], dtype=np.float32) # #26a69a
|
||||
RED_COLOR = np.array([239 / 255, 83 / 255, 80 / 255], dtype=np.float32) # #ef5350
|
||||
|
||||
IMAGE_SIZE = 128
|
||||
VOLUME_RATIO = 0.20 # 20% de la hauteur pour le volume
|
||||
|
||||
# --- Détection optionnelle de mplfinance (pour affichage uniquement) ---
|
||||
try:
|
||||
import mplfinance as mpf
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
import io
|
||||
MPLFINANCE_AVAILABLE = True
|
||||
except ImportError:
|
||||
MPLFINANCE_AVAILABLE = False
|
||||
|
||||
try:
|
||||
from PIL import Image
|
||||
PIL_AVAILABLE = True
|
||||
except ImportError:
|
||||
PIL_AVAILABLE = False
|
||||
|
||||
|
||||
class CandlestickImageRenderer:
|
||||
"""
|
||||
Convertit des données OHLCV en images de graphiques en chandeliers.
|
||||
|
||||
Chaque image est un instantané visuel de `seq_len` bougies consécutives,
|
||||
rendu via pur numpy (rapide, ~5 000 images/s) ou mplfinance (lent, haute qualité).
|
||||
Le résultat est normalisé en float32 dans [0, 1].
|
||||
|
||||
Args:
|
||||
image_size: Taille carrée de l'image en pixels (défaut 128)
|
||||
use_mplfinance: Forcer mplfinance même pour encode() (lent, déconseillé)
|
||||
"""
|
||||
|
||||
def __init__(self, image_size: int = IMAGE_SIZE, use_mplfinance: bool = False):
|
||||
self.image_size = image_size
|
||||
self.use_mplfinance = use_mplfinance and MPLFINANCE_AVAILABLE and PIL_AVAILABLE
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Interface publique
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def encode(self, df: pd.DataFrame, seq_len: int = 64) -> np.ndarray:
|
||||
"""
|
||||
Encode toutes les fenêtres glissantes du DataFrame en images.
|
||||
|
||||
Produit N = len(df) - seq_len fenêtres, chacune rendue en image
|
||||
128×128 RGB normalisée [0, 1].
|
||||
|
||||
Utilise le renderer pur numpy (rapide) par défaut.
|
||||
|
||||
Args:
|
||||
df: DataFrame OHLCV avec colonnes open/high/low/close/volume
|
||||
seq_len: Nombre de bougies par fenêtre (défaut 64)
|
||||
|
||||
Returns:
|
||||
np.ndarray de forme (N, 3, 128, 128), dtype float32, valeurs [0, 1]
|
||||
Retourne un tableau vide (0, 3, 128, 128) si df est trop court.
|
||||
"""
|
||||
df = self._prepare_df(df)
|
||||
n_windows = len(df) - seq_len
|
||||
|
||||
if n_windows <= 0:
|
||||
logger.warning(
|
||||
f"DataFrame trop court ({len(df)} barres) pour seq_len={seq_len}"
|
||||
)
|
||||
return np.zeros((0, 3, self.image_size, self.image_size), dtype=np.float32)
|
||||
|
||||
# Extraction des valeurs OHLCV en numpy (une seule conversion)
|
||||
ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values.astype(np.float32)
|
||||
|
||||
if self.use_mplfinance:
|
||||
return self._encode_mplfinance(df, seq_len, n_windows)
|
||||
else:
|
||||
return self._encode_numpy(ohlcv, seq_len, n_windows)
|
||||
|
||||
def encode_last(self, df: pd.DataFrame, seq_len: int = 64) -> np.ndarray:
|
||||
"""
|
||||
Encode uniquement la dernière fenêtre du DataFrame.
|
||||
|
||||
Utilisé pour la prédiction en temps réel : retourne (1, 3, 128, 128).
|
||||
|
||||
Args:
|
||||
df: DataFrame OHLCV
|
||||
seq_len: Nombre de bougies pour la dernière fenêtre (défaut 64)
|
||||
|
||||
Returns:
|
||||
np.ndarray de forme (1, 3, 128, 128), dtype float32, valeurs [0, 1]
|
||||
|
||||
Raises:
|
||||
ValueError si df est trop court pour seq_len bougies
|
||||
"""
|
||||
df = self._prepare_df(df)
|
||||
|
||||
if len(df) < seq_len:
|
||||
raise ValueError(
|
||||
f"DataFrame insuffisant : {len(df)} barres < seq_len={seq_len}"
|
||||
)
|
||||
|
||||
window = df.iloc[-seq_len:]
|
||||
ohlcv = window[['open', 'high', 'low', 'close', 'volume']].values.astype(np.float32)
|
||||
img = self._render_numpy(ohlcv)
|
||||
return img[np.newaxis, ...] # (1, 3, H, W)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Renderer pur numpy (rapide)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def _encode_numpy(
|
||||
self,
|
||||
ohlcv: np.ndarray,
|
||||
seq_len: int,
|
||||
n_windows: int,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Encode toutes les fenêtres en batch vectorisé (sans boucle Python sur N).
|
||||
|
||||
Utilise sliding_window_view pour créer toutes les fenêtres d'un coup,
|
||||
puis boucle sur les seq_len bougies (64) en traitant toutes les N fenêtres
|
||||
simultanément. Numpy libère le GIL pendant les opérations vectorisées,
|
||||
ce qui préserve la réactivité de l'event loop FastAPI.
|
||||
|
||||
Args:
|
||||
ohlcv: (T, 5) float32 [open, high, low, close, volume]
|
||||
seq_len: Longueur de chaque fenêtre
|
||||
n_windows: Nombre total de fenêtres à générer
|
||||
|
||||
Returns:
|
||||
(N, 3, H, W) float32
|
||||
"""
|
||||
H = W = self.image_size
|
||||
|
||||
# --- Toutes les fenêtres d'un coup : (N, seq_len, 5) ---
|
||||
# sliding_window_view est O(1) en mémoire (vue sans copie)
|
||||
windows = np.lib.stride_tricks.sliding_window_view(ohlcv, (seq_len, 5))
|
||||
windows = windows[:n_windows, 0, :, :].astype(np.float32) # (N, seq_len, 5)
|
||||
|
||||
opens = windows[:, :, 0] # (N, seq_len)
|
||||
highs = windows[:, :, 1]
|
||||
lows = windows[:, :, 2]
|
||||
closes = windows[:, :, 3]
|
||||
vols = windows[:, :, 4]
|
||||
|
||||
# --- Fond de toutes les images ---
|
||||
images = np.empty((n_windows, 3, H, W), dtype=np.float32)
|
||||
images[:, 0, :, :] = BG_COLOR[0]
|
||||
images[:, 1, :, :] = BG_COLOR[1]
|
||||
images[:, 2, :, :] = BG_COLOR[2]
|
||||
|
||||
# --- Normalisation des prix par fenêtre ---
|
||||
price_min = lows.min(axis=1, keepdims=True) # (N, 1)
|
||||
price_max = highs.max(axis=1, keepdims=True) # (N, 1)
|
||||
price_rng = np.maximum(price_max - price_min, 1e-8)
|
||||
|
||||
vol_h = max(1, int(H * VOLUME_RATIO))
|
||||
chart_h = H - vol_h
|
||||
|
||||
def to_row(prices: np.ndarray) -> np.ndarray:
|
||||
"""Convertit prix → rangée pixel (0 = haut, chart_h-1 = bas)."""
|
||||
norm = (prices - price_min) / price_rng
|
||||
rows = ((1.0 - norm) * (chart_h - 1)).astype(np.int32)
|
||||
return np.clip(rows, 0, chart_h - 1)
|
||||
|
||||
row_h = to_row(highs) # (N, seq_len)
|
||||
row_l = to_row(lows)
|
||||
row_o = to_row(opens)
|
||||
row_c = to_row(closes)
|
||||
|
||||
body_top = np.minimum(row_o, row_c) # (N, seq_len)
|
||||
body_bot = np.maximum(row_o, row_c)
|
||||
|
||||
is_bull = (closes >= opens) # (N, seq_len) bool
|
||||
|
||||
# Volume normalisé [0, 1] par fenêtre
|
||||
vol_max = np.maximum(vols.max(axis=1, keepdims=True), 1e-8)
|
||||
vol_norm = vols / vol_max # (N, seq_len)
|
||||
|
||||
# Positions X des bougies
|
||||
candle_w = max(1, W // seq_len)
|
||||
half_w = max(1, candle_w // 2)
|
||||
x_centers = ((np.arange(seq_len) + 0.5) * W / seq_len).astype(np.int32)
|
||||
|
||||
row_idx = np.arange(chart_h) # (chart_h,) pour les masques
|
||||
vol_idx = np.arange(vol_h) # (vol_h,)
|
||||
|
||||
# --- Boucle sur les bougies (64 itérations, GIL libéré entre chaque) ---
|
||||
for i in range(seq_len):
|
||||
x_c = int(x_centers[i])
|
||||
x0 = max(0, x_c - half_w)
|
||||
x1 = min(W, x_c + half_w + 1)
|
||||
wick_x = min(x_c, W - 1)
|
||||
|
||||
rh = row_h[:, i] # (N,)
|
||||
rl = row_l[:, i]
|
||||
bt = body_top[:, i]
|
||||
bb = body_bot[:, i]
|
||||
bull = is_bull[:, i] # (N,) bool
|
||||
|
||||
# Masques vectorisés sur toutes les N fenêtres
|
||||
# wick_mask[w, r] = True si rh[w] <= r <= rl[w]
|
||||
wick_mask = (
|
||||
(row_idx[None, :] >= rh[:, None]) &
|
||||
(row_idx[None, :] <= rl[:, None])
|
||||
) # (N, chart_h)
|
||||
|
||||
body_mask = (
|
||||
(row_idx[None, :] >= bt[:, None]) &
|
||||
(row_idx[None, :] <= bb[:, None])
|
||||
) # (N, chart_h)
|
||||
|
||||
# Volume : barre du bas
|
||||
vol_bar_h = (vol_norm[:, i] * vol_h).astype(np.int32)
|
||||
vol_thresh = np.maximum(vol_h - vol_bar_h, 0)
|
||||
vol_mask = vol_idx[None, :] >= vol_thresh[:, None] # (N, vol_h)
|
||||
|
||||
for c in range(3):
|
||||
g = float(GREEN_COLOR[c])
|
||||
r = float(RED_COLOR[c])
|
||||
# Couleur par fenêtre : vert si haussière, rouge sinon
|
||||
color_w = np.where(bull, g, r).astype(np.float32) # (N,)
|
||||
|
||||
# Mèche
|
||||
images[:, c, :chart_h, wick_x] = np.where(
|
||||
wick_mask,
|
||||
color_w[:, None],
|
||||
images[:, c, :chart_h, wick_x],
|
||||
)
|
||||
|
||||
# Corps
|
||||
images[:, c, :chart_h, x0:x1] = np.where(
|
||||
body_mask[:, :, None],
|
||||
color_w[:, None, None],
|
||||
images[:, c, :chart_h, x0:x1],
|
||||
)
|
||||
|
||||
# Volume (opacité 60%)
|
||||
images[:, c, H - vol_h:H, x0:x1] = np.where(
|
||||
vol_mask[:, :, None],
|
||||
color_w[:, None, None] * 0.6,
|
||||
images[:, c, H - vol_h:H, x0:x1],
|
||||
)
|
||||
|
||||
return images
|
||||
|
||||
def _render_numpy(self, ohlcv: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Rend une seule fenêtre OHLCV en image (3, H, W) via pur numpy.
|
||||
|
||||
Dessin :
|
||||
- Fond #0d1117
|
||||
- Mèche : ligne verticale high→low (1 pixel de large)
|
||||
- Corps : rectangle open→close (largeur proportionnelle)
|
||||
- Volume : barre en bas (20% hauteur), opacité 60%
|
||||
|
||||
Args:
|
||||
ohlcv: (N, 5) float32 [open, high, low, close, volume]
|
||||
|
||||
Returns:
|
||||
(3, image_size, image_size) float32 [0, 1]
|
||||
"""
|
||||
H = W = self.image_size
|
||||
n = len(ohlcv)
|
||||
|
||||
if n == 0:
|
||||
return np.zeros((3, H, W), dtype=np.float32)
|
||||
|
||||
# --- Fond ---
|
||||
img = np.empty((3, H, W), dtype=np.float32)
|
||||
for c in range(3):
|
||||
img[c] = BG_COLOR[c]
|
||||
|
||||
opens = ohlcv[:, 0]
|
||||
highs = ohlcv[:, 1]
|
||||
lows = ohlcv[:, 2]
|
||||
closes = ohlcv[:, 3]
|
||||
vols = ohlcv[:, 4]
|
||||
|
||||
# --- Normalisation des prix ---
|
||||
price_min = lows.min()
|
||||
price_max = highs.max()
|
||||
if price_max <= price_min:
|
||||
return img # données dégénérées
|
||||
|
||||
vol_h = max(1, int(H * VOLUME_RATIO))
|
||||
chart_h = H - vol_h # hauteur zone prix
|
||||
|
||||
def price_to_row(price: float) -> int:
|
||||
"""Convertit un prix en ligne pixel (0 = haut, chart_h-1 = bas)."""
|
||||
norm = (price - price_min) / (price_max - price_min)
|
||||
row = int((1.0 - norm) * (chart_h - 1))
|
||||
return max(0, min(chart_h - 1, row))
|
||||
|
||||
# --- Largeur des bougies ---
|
||||
candle_w = max(1, W // n)
|
||||
half_w = max(1, candle_w // 2)
|
||||
|
||||
# --- Volume ---
|
||||
vol_max = vols.max()
|
||||
|
||||
for i in range(n):
|
||||
x_c = int((i + 0.5) * W / n)
|
||||
x0 = max(0, x_c - half_w)
|
||||
x1 = min(W, x_c + half_w + 1)
|
||||
|
||||
is_bull = closes[i] >= opens[i]
|
||||
color = GREEN_COLOR if is_bull else RED_COLOR
|
||||
|
||||
row_h = price_to_row(highs[i])
|
||||
row_l = price_to_row(lows[i])
|
||||
row_o = price_to_row(opens[i])
|
||||
row_c = price_to_row(closes[i])
|
||||
|
||||
body_top = min(row_o, row_c)
|
||||
body_bot = max(row_o, row_c)
|
||||
|
||||
# Mèche (1 pixel de large, centré)
|
||||
wick_x = min(x_c, W - 1)
|
||||
for c in range(3):
|
||||
img[c, row_h: row_l + 1, wick_x] = color[c]
|
||||
|
||||
# Corps
|
||||
if body_bot >= body_top:
|
||||
for c in range(3):
|
||||
img[c, body_top: body_bot + 1, x0: x1] = color[c]
|
||||
|
||||
# Volume (bas de l'image)
|
||||
if vol_max > 0:
|
||||
bar_h = int((vols[i] / vol_max) * vol_h)
|
||||
if bar_h > 0:
|
||||
row_v0 = H - bar_h
|
||||
for c in range(3):
|
||||
img[c, row_v0: H, x0: x1] = color[c] * 0.6
|
||||
|
||||
return img
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Renderer mplfinance (lent, haute qualité, pour affichage uniquement)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def _encode_mplfinance(
|
||||
self,
|
||||
df: pd.DataFrame,
|
||||
seq_len: int,
|
||||
n_windows: int,
|
||||
) -> np.ndarray:
|
||||
"""Encode via mplfinance (lent — ne pas utiliser pour l'entraînement)."""
|
||||
H = W = self.image_size
|
||||
images = np.zeros((n_windows, 3, H, W), dtype=np.float32)
|
||||
import warnings
|
||||
for i in range(n_windows):
|
||||
window = df.iloc[i: i + seq_len]
|
||||
try:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
images[i] = self._render_with_mplfinance(window)
|
||||
except Exception as e:
|
||||
logger.debug(f"Erreur mplfinance fenêtre {i} : {e}")
|
||||
return images
|
||||
|
||||
def _render_with_mplfinance(self, df_window: pd.DataFrame) -> np.ndarray:
|
||||
"""
|
||||
Rendu haute qualité via mplfinance (pour affichage / debug).
|
||||
|
||||
Raises:
|
||||
ImportError si mplfinance ou PIL non disponible.
|
||||
"""
|
||||
if not MPLFINANCE_AVAILABLE or not PIL_AVAILABLE:
|
||||
raise ImportError("mplfinance ou PIL non disponible")
|
||||
|
||||
style = mpf.make_mpf_style(
|
||||
base_mpf_style='nightclouds',
|
||||
marketcolors=mpf.make_marketcolors(
|
||||
up='#26a69a', down='#ef5350',
|
||||
wick={'up': '#26a69a', 'down': '#ef5350'},
|
||||
edge={'up': '#26a69a', 'down': '#ef5350'},
|
||||
volume={'up': '#26a69a', 'down': '#ef5350'},
|
||||
),
|
||||
facecolor='#0d1117',
|
||||
figcolor='#0d1117',
|
||||
gridcolor='#0d1117',
|
||||
)
|
||||
|
||||
buf = io.BytesIO()
|
||||
import warnings
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
fig, axes = mpf.plot(
|
||||
df_window,
|
||||
type='candle',
|
||||
style=style,
|
||||
volume=True,
|
||||
axisoff=True,
|
||||
tight_layout=True,
|
||||
returnfig=True,
|
||||
figsize=(1.28, 1.28),
|
||||
)
|
||||
|
||||
for ax in axes:
|
||||
ax.set_axis_off()
|
||||
ax.set_facecolor('#0d1117')
|
||||
for spine in ax.spines.values():
|
||||
spine.set_visible(False)
|
||||
|
||||
fig.savefig(buf, format='png', dpi=100, bbox_inches='tight',
|
||||
pad_inches=0, facecolor='#0d1117')
|
||||
plt.close(fig)
|
||||
|
||||
buf.seek(0)
|
||||
pil_img = Image.open(buf).convert('RGB')
|
||||
pil_img = pil_img.resize((self.image_size, self.image_size), Image.LANCZOS)
|
||||
|
||||
arr = np.array(pil_img, dtype=np.float32) / 255.0
|
||||
return arr.transpose(2, 0, 1) # HWC → CHW
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Utilitaires
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _prepare_df(df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Normalise le DataFrame : colonnes en minuscules, index DatetimeIndex.
|
||||
"""
|
||||
df = df.copy()
|
||||
df.columns = [c.lower() for c in df.columns]
|
||||
|
||||
if not isinstance(df.index, pd.DatetimeIndex):
|
||||
try:
|
||||
df.index = pd.to_datetime(df.index)
|
||||
except Exception:
|
||||
df.index = pd.date_range(
|
||||
start='2020-01-01', periods=len(df), freq='1h'
|
||||
)
|
||||
|
||||
required = ['open', 'high', 'low', 'close', 'volume']
|
||||
for col in required:
|
||||
if col not in df.columns:
|
||||
df[col] = 0.0
|
||||
|
||||
df = df[required].dropna(subset=['open', 'high', 'low', 'close'])
|
||||
df = df.ffill().bfill()
|
||||
return df
|
||||
163
src/ml/cnn_image/cnn_image_model.py
Normal file
163
src/ml/cnn_image/cnn_image_model.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
CandlestickCNN — Réseau de neurones convolutif pour l'analyse visuelle de graphiques.
|
||||
|
||||
Architecture Conv2D 4-blocs conçue pour reconnaître les patterns visuels dans
|
||||
des images de chandeliers 128×128 RGB (bougies, volumes, supports/résistances).
|
||||
|
||||
Input : (batch, 3, 128, 128) — images RGB normalisées [0, 1]
|
||||
Output : (batch, 3) — logits pour 3 classes : SHORT(0), NEUTRAL(1), LONG(2)
|
||||
|
||||
Classes encodées :
|
||||
0 → SHORT (-1)
|
||||
1 → NEUTRAL ( 0)
|
||||
2 → LONG (+1)
|
||||
|
||||
L'encodage +1 est cohérent avec MLStrategyModel et CNNStrategyModel.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- Détection optionnelle de PyTorch ---
|
||||
try:
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
TORCH_AVAILABLE = True
|
||||
logger.debug("PyTorch disponible — CandlestickCNN activé")
|
||||
except ImportError:
|
||||
TORCH_AVAILABLE = False
|
||||
# Stubs pour permettre l'import du module sans PyTorch
|
||||
nn = None
|
||||
torch = None
|
||||
logger.warning("PyTorch non disponible — CandlestickCNN désactivé")
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Définition conditionnelle du modèle
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
if TORCH_AVAILABLE:
|
||||
|
||||
class CandlestickCNN(nn.Module):
|
||||
"""
|
||||
CNN vision pour graphiques chandeliers 128×128 RGB.
|
||||
|
||||
Architecture :
|
||||
Bloc 1 : Conv2d(3→32, 3×3) + BatchNorm2d(32) + ReLU + MaxPool2d(2) → (32, 64, 64)
|
||||
Bloc 2 : Conv2d(32→64, 3×3) + BatchNorm2d(64) + ReLU + MaxPool2d(2) → (64, 32, 32)
|
||||
Bloc 3 : Conv2d(64→128,3×3) + BatchNorm2d(128) + ReLU + MaxPool2d(2) → (128, 16, 16)
|
||||
Bloc 4 : Conv2d(128→256,3×3)+ BatchNorm2d(256) + ReLU + AdaptiveAvgPool2d(1) → (256,)
|
||||
|
||||
Classifieur :
|
||||
Linear(256→128) + Dropout(0.4) + ReLU
|
||||
Linear(128→3)
|
||||
|
||||
Output : logits bruts (3 classes) — appliquer softmax pour les probabilités.
|
||||
|
||||
Args:
|
||||
n_classes: Nombre de classes de sortie (défaut 3 : SHORT/NEUTRAL/LONG)
|
||||
dropout: Taux de dropout dans le classifieur (défaut 0.4)
|
||||
"""
|
||||
|
||||
def __init__(self, n_classes: int = 3, dropout: float = 0.4):
|
||||
super(CandlestickCNN, self).__init__()
|
||||
|
||||
# --- Bloc convolutif 1 : (3, 128, 128) → (32, 64, 64) ---
|
||||
self.bloc1 = nn.Sequential(
|
||||
nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(32),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
)
|
||||
|
||||
# --- Bloc convolutif 2 : (32, 64, 64) → (64, 32, 32) ---
|
||||
self.bloc2 = nn.Sequential(
|
||||
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
)
|
||||
|
||||
# --- Bloc convolutif 3 : (64, 32, 32) → (128, 16, 16) ---
|
||||
self.bloc3 = nn.Sequential(
|
||||
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
)
|
||||
|
||||
# --- Bloc convolutif 4 : (128, 16, 16) → (256, 1, 1) → (256,) ---
|
||||
self.bloc4 = nn.Sequential(
|
||||
nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(256),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.AdaptiveAvgPool2d(output_size=1), # Global average pooling → (256, 1, 1)
|
||||
)
|
||||
|
||||
# --- Classifieur dense ---
|
||||
self.classifieur = nn.Sequential(
|
||||
nn.Linear(256, 128),
|
||||
nn.Dropout(p=dropout),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(128, n_classes),
|
||||
)
|
||||
|
||||
def forward(self, x: 'torch.Tensor') -> 'torch.Tensor':
|
||||
"""
|
||||
Passe avant du réseau.
|
||||
|
||||
Args:
|
||||
x: Tensor (batch, 3, 128, 128) — images RGB normalisées
|
||||
|
||||
Returns:
|
||||
Tensor (batch, 3) — logits bruts
|
||||
"""
|
||||
x = self.bloc1(x) # (batch, 32, 64, 64)
|
||||
x = self.bloc2(x) # (batch, 64, 32, 32)
|
||||
x = self.bloc3(x) # (batch, 128, 16, 16)
|
||||
x = self.bloc4(x) # (batch, 256, 1, 1)
|
||||
x = x.flatten(1) # (batch, 256)
|
||||
x = self.classifieur(x) # (batch, 3)
|
||||
return x
|
||||
|
||||
def predict_proba(self, x: 'torch.Tensor') -> np.ndarray:
|
||||
"""
|
||||
Calcule les probabilités softmax pour chaque classe.
|
||||
|
||||
Args:
|
||||
x: Tensor (batch, 3, 128, 128) — images RGB normalisées
|
||||
|
||||
Returns:
|
||||
np.ndarray de forme (batch, 3), valeurs dans [0, 1], somme = 1
|
||||
Ordre des colonnes : [P_SHORT, P_NEUTRAL, P_LONG]
|
||||
"""
|
||||
self.eval()
|
||||
with torch.no_grad():
|
||||
logits = self.forward(x) # (batch, 3)
|
||||
probas = F.softmax(logits, dim=1) # (batch, 3)
|
||||
return probas.cpu().numpy().astype(np.float32)
|
||||
|
||||
else:
|
||||
# Stub si PyTorch n'est pas disponible
|
||||
class CandlestickCNN:
|
||||
"""
|
||||
Stub de CandlestickCNN — PyTorch non disponible.
|
||||
|
||||
Toutes les méthodes lèvent une RuntimeError explicite.
|
||||
"""
|
||||
|
||||
def __init__(self, n_classes: int = 3, dropout: float = 0.4):
|
||||
raise RuntimeError(
|
||||
"PyTorch non disponible. Installer torch>=2.0.0 pour utiliser CandlestickCNN."
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
raise RuntimeError("PyTorch non disponible.")
|
||||
|
||||
def predict_proba(self, x):
|
||||
raise RuntimeError("PyTorch non disponible.")
|
||||
657
src/ml/cnn_image/cnn_image_strategy_model.py
Normal file
657
src/ml/cnn_image/cnn_image_strategy_model.py
Normal file
@@ -0,0 +1,657 @@
|
||||
"""
|
||||
CNNImageStrategyModel — Modèle CNN image-based pour la prédiction de signaux de trading.
|
||||
|
||||
Ce module entraîne un CNN Conv2D (CandlestickCNN) sur des images de graphiques
|
||||
en chandeliers pour prédire LONG (+1), NEUTRAL (0) ou SHORT (-1).
|
||||
|
||||
Pipeline :
|
||||
1. Génération de labels ATR-based (LabelGenerator partagé)
|
||||
2. Rendu des fenêtres OHLCV en images 128×128 RGB (CandlestickImageRenderer)
|
||||
3. Walk-forward (2 folds temporels) pour évaluation out-of-sample
|
||||
4. Entraînement sur toutes les données (Adam, CrossEntropyLoss weighted)
|
||||
5. Sauvegarde : EURUSD_1h.pt (state_dict) + EURUSD_1h_meta.json
|
||||
|
||||
Interface identique à MLStrategyModel et CNNStrategyModel :
|
||||
model = CNNImageStrategyModel(symbol='EURUSD', timeframe='1h')
|
||||
result = model.train(df_ohlcv)
|
||||
signal = model.predict(df) # {signal, confidence, probas, tradeable}
|
||||
model.save()
|
||||
model.load(symbol, timeframe)
|
||||
model.list_trained_models()
|
||||
model.get_feature_importance() # retourne [] (CNN = boîte noire)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from src.ml.features.label_generator import LabelGenerator
|
||||
from src.ml.cnn_image.chart_renderer import CandlestickImageRenderer, MPLFINANCE_AVAILABLE
|
||||
from src.ml.cnn_image.cnn_image_model import CandlestickCNN, TORCH_AVAILABLE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Répertoire de sauvegarde des modèles CNN image
|
||||
MODELS_DIR = Path(__file__).parent.parent.parent.parent / "models" / "cnn_image_strategy"
|
||||
|
||||
# Décodage des classes encodées vers les signaux de trading
|
||||
# Classe 0 → SHORT (-1), Classe 1 → NEUTRAL (0), Classe 2 → LONG (+1)
|
||||
CLASS_MAP = {0: -1, 1: 0, 2: 1}
|
||||
|
||||
# Imports conditionnels PyTorch
|
||||
if TORCH_AVAILABLE:
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
from sklearn.model_selection import TimeSeriesSplit
|
||||
from sklearn.metrics import precision_recall_fscore_support
|
||||
SKLEARN_AVAILABLE = True
|
||||
else:
|
||||
torch = None
|
||||
nn = None
|
||||
DataLoader = None
|
||||
TensorDataset = None
|
||||
TimeSeriesSplit = None
|
||||
precision_recall_fscore_support = None
|
||||
SKLEARN_AVAILABLE = False
|
||||
|
||||
try:
|
||||
from sklearn.model_selection import TimeSeriesSplit
|
||||
from sklearn.metrics import precision_recall_fscore_support
|
||||
SKLEARN_AVAILABLE = True
|
||||
except ImportError:
|
||||
SKLEARN_AVAILABLE = False
|
||||
|
||||
|
||||
class CNNImageStrategyModel:
|
||||
"""
|
||||
Modèle CNN image-based qui reconnaît les patterns visuels de trading.
|
||||
|
||||
Le modèle :
|
||||
- Convertit des bougies OHLCV en images 128×128 RGB (mplfinance)
|
||||
- Reconnaît visuellement : marteaux, doji, double top/bottom, supports/résistances
|
||||
- Prédit LONG (1) / SHORT (-1) / NEUTRAL (0)
|
||||
- Donne un score de confiance [0..1] par prédiction
|
||||
- Se sauvegarde en format PyTorch (.pt) pour rechargement sans ré-entraînement
|
||||
|
||||
Args:
|
||||
symbol: Paire tradée (ex: 'EURUSD')
|
||||
timeframe: Timeframe (ex: '1h', '15m')
|
||||
seq_len: Nombre de bougies par image (défaut 64)
|
||||
tp_atr_mult: Multiplicateur ATR pour le TP (génération labels)
|
||||
sl_atr_mult: Multiplicateur ATR pour le SL (génération labels)
|
||||
horizon: Nombre de barres pour évaluer TP/SL
|
||||
min_confidence: Seuil de confiance pour le signal tradeable
|
||||
epochs: Nombre maximum d'époques d'entraînement
|
||||
batch_size: Taille du batch (défaut 32)
|
||||
lr: Taux d'apprentissage Adam (défaut 1e-3)
|
||||
patience: Patience pour l'early stopping (défaut 7)
|
||||
"""
|
||||
|
||||
MODELS_DIR = MODELS_DIR
|
||||
CLASS_MAP = CLASS_MAP
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
symbol: str = 'EURUSD',
|
||||
timeframe: str = '1h',
|
||||
seq_len: int = 64,
|
||||
tp_atr_mult: float = 2.0,
|
||||
sl_atr_mult: float = 1.0,
|
||||
horizon: int = 30,
|
||||
min_confidence: float = 0.55,
|
||||
epochs: int = 50,
|
||||
batch_size: int = 32,
|
||||
lr: float = 1e-3,
|
||||
patience: int = 7,
|
||||
):
|
||||
self.symbol = symbol
|
||||
self.timeframe = timeframe
|
||||
self.seq_len = seq_len
|
||||
self.tp_atr_mult = tp_atr_mult
|
||||
self.sl_atr_mult = sl_atr_mult
|
||||
self.horizon = horizon
|
||||
self.min_confidence = min_confidence
|
||||
self.epochs = epochs
|
||||
self.batch_size = batch_size
|
||||
self.lr = lr
|
||||
self.patience = patience
|
||||
|
||||
self.model: Optional['CandlestickCNN'] = None
|
||||
self.is_trained = False
|
||||
self.metadata: Dict = {}
|
||||
|
||||
self.MODELS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Entraînement
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def train(self, data: pd.DataFrame) -> Dict:
|
||||
"""
|
||||
Entraîne le modèle CNN image sur les données OHLCV fournies.
|
||||
|
||||
Étapes :
|
||||
1. Génération des labels ATR-based (LabelGenerator)
|
||||
2. Encodage des fenêtres en images (CandlestickImageRenderer)
|
||||
3. Walk-forward évaluation (2 folds temporels)
|
||||
4. Entraînement final sur toutes les données
|
||||
5. Sauvegarde automatique
|
||||
|
||||
Args:
|
||||
data: DataFrame OHLCV (minimum 200 barres recommandées,
|
||||
colonnes : open/high/low/close/volume)
|
||||
|
||||
Returns:
|
||||
Dict avec métriques : wf_accuracy, wf_precision, label_dist,
|
||||
n_samples, trained_at, et error si échec.
|
||||
"""
|
||||
# Vérifications de disponibilité
|
||||
if not TORCH_AVAILABLE:
|
||||
return {'error': 'PyTorch non disponible — installer torch>=2.0.0'}
|
||||
if not SKLEARN_AVAILABLE:
|
||||
return {'error': 'scikit-learn non disponible'}
|
||||
|
||||
logger.info(
|
||||
f"Début entraînement CNNImageStrategyModel pour "
|
||||
f"{self.symbol}/{self.timeframe} — seq_len={self.seq_len}"
|
||||
)
|
||||
|
||||
# 1. Génération des labels (ATR-based, cohérent avec MLStrategyModel)
|
||||
gen = LabelGenerator(horizon=self.horizon)
|
||||
labels = gen.generate_atr_based(
|
||||
data,
|
||||
atr_tp_mult=self.tp_atr_mult,
|
||||
atr_sl_mult=self.sl_atr_mult,
|
||||
)
|
||||
|
||||
# Encodage [-1,0,1] → [0,1,2] pour CrossEntropyLoss
|
||||
y_enc = (labels + 1).values # np.ndarray
|
||||
|
||||
# 2. Encodage des images (fenêtres glissantes)
|
||||
logger.info(f" Rendu des images (seq_len={self.seq_len})...")
|
||||
renderer = CandlestickImageRenderer()
|
||||
X = renderer.encode(data, seq_len=self.seq_len) # (N, 3, 128, 128)
|
||||
|
||||
# 3. Alignement X et y
|
||||
# encode() produit N = len(data) - seq_len images
|
||||
# La i-ème image correspond à data.iloc[i : i + seq_len]
|
||||
# Le label associé est y_enc[i + seq_len - 1] (dernier point de la fenêtre)
|
||||
n_images = len(X)
|
||||
if n_images == 0:
|
||||
return {
|
||||
'error': f'Pas assez de données : {len(data)} barres < seq_len={self.seq_len}'
|
||||
}
|
||||
|
||||
# Index des labels pour chaque image : fenêtre i → label à l'indice i + seq_len - 1
|
||||
label_indices = np.arange(self.seq_len - 1, self.seq_len - 1 + n_images)
|
||||
# Supprimer les barres de fin (labels non fiables, horizon tronqué)
|
||||
valid_mask = label_indices < (len(y_enc) - self.horizon)
|
||||
|
||||
X = X[valid_mask]
|
||||
y = y_enc[label_indices[valid_mask]]
|
||||
|
||||
n_samples = len(X)
|
||||
if n_samples < 50:
|
||||
return {'error': f'Trop peu d\'échantillons valides : {n_samples}'}
|
||||
|
||||
logger.info(
|
||||
f" {n_samples} images valides — "
|
||||
f"LONG={(y==2).sum()}, SHORT={(y==0).sum()}, NEUTRAL={(y==1).sum()}"
|
||||
)
|
||||
|
||||
# 4. Calcul des class_weights (inversement proportionnels à la fréquence)
|
||||
class_weights = self._compute_class_weights(y)
|
||||
|
||||
# 5. Walk-forward évaluation (2 folds)
|
||||
wf_metrics = self._walk_forward_eval(X, y, class_weights, n_splits=2)
|
||||
|
||||
# 6. Entraînement final sur toutes les données
|
||||
logger.info(" Entraînement final sur toutes les données...")
|
||||
self.model = CandlestickCNN(n_classes=3)
|
||||
self._train_model(self.model, X, y, class_weights)
|
||||
self.is_trained = True
|
||||
|
||||
# Distribution des labels (décodage pour lisibilité)
|
||||
label_dist = {
|
||||
'long': int((y == 2).sum()),
|
||||
'short': int((y == 0).sum()),
|
||||
'neutral': int((y == 1).sum()),
|
||||
}
|
||||
|
||||
self.metadata = {
|
||||
'symbol': self.symbol,
|
||||
'timeframe': self.timeframe,
|
||||
'seq_len': self.seq_len,
|
||||
'trained_at': datetime.utcnow().isoformat(),
|
||||
'n_samples': n_samples,
|
||||
'tp_atr_mult': self.tp_atr_mult,
|
||||
'sl_atr_mult': self.sl_atr_mult,
|
||||
'horizon': self.horizon,
|
||||
'label_dist': label_dist,
|
||||
'wf_metrics': wf_metrics,
|
||||
'mplfinance': MPLFINANCE_AVAILABLE,
|
||||
}
|
||||
|
||||
# 7. Sauvegarde
|
||||
self.save()
|
||||
|
||||
logger.info(
|
||||
f"Entraînement terminé. WF accuracy={wf_metrics.get('avg_accuracy', 0):.2%}"
|
||||
)
|
||||
return self.metadata
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Prédiction
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def predict(self, data: pd.DataFrame) -> Dict:
|
||||
"""
|
||||
Prédit le signal pour la dernière fenêtre disponible.
|
||||
|
||||
Args:
|
||||
data: DataFrame OHLCV récent (minimum seq_len barres)
|
||||
|
||||
Returns:
|
||||
Dict : {
|
||||
'signal': 1 (LONG) / -1 (SHORT) / 0 (NEUTRAL),
|
||||
'confidence': float [0..1] — max(P_LONG, P_SHORT),
|
||||
'probas': {'long': float, 'short': float, 'neutral': float},
|
||||
'tradeable': bool — confidence >= min_confidence et signal != 0,
|
||||
}
|
||||
"""
|
||||
if not TORCH_AVAILABLE:
|
||||
return {
|
||||
'signal': 0, 'confidence': 0.0, 'tradeable': False,
|
||||
'error': 'PyTorch non disponible'
|
||||
}
|
||||
if not self.is_trained or self.model is None:
|
||||
return {
|
||||
'signal': 0, 'confidence': 0.0, 'tradeable': False,
|
||||
'error': 'Modèle non entraîné'
|
||||
}
|
||||
|
||||
try:
|
||||
renderer = CandlestickImageRenderer()
|
||||
img = renderer.encode_last(data, seq_len=self.seq_len) # (1, 3, 128, 128)
|
||||
except ValueError as e:
|
||||
return {
|
||||
'signal': 0, 'confidence': 0.0, 'tradeable': False,
|
||||
'error': str(e)
|
||||
}
|
||||
|
||||
# Conversion en Tensor
|
||||
x_tensor = torch.from_numpy(img).float() # (1, 3, 128, 128)
|
||||
|
||||
# Prédiction
|
||||
proba_arr = self.model.predict_proba(x_tensor) # (1, 3)
|
||||
proba = proba_arr[0] # (3,) : [P_SHORT, P_NEUTRAL, P_LONG]
|
||||
|
||||
pred_class = int(np.argmax(proba))
|
||||
signal = self.CLASS_MAP[pred_class] # Décodage : 0→-1, 1→0, 2→1
|
||||
|
||||
probas = {
|
||||
'short': float(proba[0]),
|
||||
'neutral': float(proba[1]),
|
||||
'long': float(proba[2]),
|
||||
}
|
||||
confidence = float(max(probas['long'], probas['short']))
|
||||
|
||||
return {
|
||||
'signal': signal,
|
||||
'confidence': confidence,
|
||||
'probas': probas,
|
||||
'tradeable': confidence >= self.min_confidence and signal != 0,
|
||||
}
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Sauvegarde / Chargement
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def save(self) -> Path:
|
||||
"""
|
||||
Sauvegarde le modèle et ses métadonnées sur disque.
|
||||
|
||||
Format :
|
||||
{symbol}_{timeframe}.pt — state_dict PyTorch
|
||||
{symbol}_{timeframe}_meta.json — métadonnées JSON
|
||||
|
||||
Returns:
|
||||
Path vers le fichier .pt sauvegardé
|
||||
|
||||
Raises:
|
||||
RuntimeError si le modèle n'est pas entraîné ou PyTorch indisponible
|
||||
"""
|
||||
if not TORCH_AVAILABLE:
|
||||
raise RuntimeError("PyTorch non disponible")
|
||||
if not self.is_trained or self.model is None:
|
||||
raise RuntimeError("Modèle non entraîné — appeler train() avant save()")
|
||||
|
||||
model_id = f"{self.symbol}_{self.timeframe}"
|
||||
model_path = self.MODELS_DIR / f"{model_id}.pt"
|
||||
meta_path = self.MODELS_DIR / f"{model_id}_meta.json"
|
||||
|
||||
# Sauvegarde du state_dict + config pour pouvoir reconstruire le modèle
|
||||
torch.save(
|
||||
{
|
||||
'state_dict': self.model.state_dict(),
|
||||
'config': {
|
||||
'symbol': self.symbol,
|
||||
'timeframe': self.timeframe,
|
||||
'seq_len': self.seq_len,
|
||||
'tp_atr_mult': self.tp_atr_mult,
|
||||
'sl_atr_mult': self.sl_atr_mult,
|
||||
'horizon': self.horizon,
|
||||
'min_confidence': self.min_confidence,
|
||||
'epochs': self.epochs,
|
||||
'batch_size': self.batch_size,
|
||||
'lr': self.lr,
|
||||
'patience': self.patience,
|
||||
},
|
||||
'metadata': self.metadata,
|
||||
},
|
||||
model_path
|
||||
)
|
||||
|
||||
with open(meta_path, 'w') as f:
|
||||
json.dump(self.metadata, f, indent=2, default=str)
|
||||
|
||||
logger.info(f"Modèle CNN image sauvegardé : {model_path}")
|
||||
return model_path
|
||||
|
||||
@classmethod
|
||||
def load(cls, symbol: str, timeframe: str) -> 'CNNImageStrategyModel':
|
||||
"""
|
||||
Charge un modèle existant depuis le disque.
|
||||
|
||||
Args:
|
||||
symbol: Paire (ex: 'EURUSD')
|
||||
timeframe: Timeframe (ex: '1h')
|
||||
|
||||
Returns:
|
||||
Instance CNNImageStrategyModel prête à prédire
|
||||
|
||||
Raises:
|
||||
RuntimeError si PyTorch non disponible
|
||||
FileNotFoundError si le modèle n'existe pas
|
||||
"""
|
||||
if not TORCH_AVAILABLE:
|
||||
raise RuntimeError("PyTorch non disponible")
|
||||
|
||||
model_path = MODELS_DIR / f"{symbol}_{timeframe}.pt"
|
||||
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"Modèle non trouvé : {model_path}")
|
||||
|
||||
checkpoint = torch.load(model_path, map_location='cpu')
|
||||
cfg = checkpoint.get('config', {})
|
||||
|
||||
instance = cls(
|
||||
symbol = cfg.get('symbol', symbol),
|
||||
timeframe = cfg.get('timeframe', timeframe),
|
||||
seq_len = cfg.get('seq_len', 64),
|
||||
tp_atr_mult = cfg.get('tp_atr_mult', 2.0),
|
||||
sl_atr_mult = cfg.get('sl_atr_mult', 1.0),
|
||||
horizon = cfg.get('horizon', 30),
|
||||
min_confidence = cfg.get('min_confidence', 0.55),
|
||||
epochs = cfg.get('epochs', 50),
|
||||
batch_size = cfg.get('batch_size', 32),
|
||||
lr = cfg.get('lr', 1e-3),
|
||||
patience = cfg.get('patience', 7),
|
||||
)
|
||||
|
||||
# Reconstruction du modèle et chargement des poids
|
||||
instance.model = CandlestickCNN(n_classes=3)
|
||||
instance.model.load_state_dict(checkpoint['state_dict'])
|
||||
instance.model.eval()
|
||||
instance.is_trained = True
|
||||
instance.metadata = checkpoint.get('metadata', {})
|
||||
|
||||
logger.info(f"Modèle CNN image chargé depuis {model_path}")
|
||||
return instance
|
||||
|
||||
@staticmethod
|
||||
def list_trained_models() -> List[Dict]:
|
||||
"""
|
||||
Retourne la liste des modèles entraînés disponibles.
|
||||
|
||||
Returns:
|
||||
Liste de dicts avec symbol, timeframe, trained_at, n_samples, wf_accuracy
|
||||
"""
|
||||
if not MODELS_DIR.exists():
|
||||
return []
|
||||
|
||||
models = []
|
||||
for f in MODELS_DIR.glob("*_meta.json"):
|
||||
try:
|
||||
with open(f) as fp:
|
||||
meta = json.load(fp)
|
||||
# Nom du fichier : EURUSD_1h_meta.json → EURUSD, 1h
|
||||
stem = f.stem.replace('_meta', '') # ex: EURUSD_1h
|
||||
parts = stem.split('_', 1)
|
||||
models.append({
|
||||
'symbol': meta.get('symbol', parts[0] if parts else '?'),
|
||||
'timeframe': meta.get('timeframe', parts[1] if len(parts) > 1 else '?'),
|
||||
'trained_at': meta.get('trained_at', '?'),
|
||||
'n_samples': meta.get('n_samples', 0),
|
||||
'seq_len': meta.get('seq_len', 64),
|
||||
'wf_accuracy': meta.get('wf_metrics', {}).get('avg_accuracy', 0),
|
||||
'mplfinance': meta.get('mplfinance', False),
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return models
|
||||
|
||||
def get_feature_importance(self) -> List[Dict]:
|
||||
"""
|
||||
Retourne l'importance des features.
|
||||
|
||||
Note : Les CNN sont des boîtes noires — pas d'importance de feature
|
||||
interprétable comme XGBoost. Retourne une liste vide.
|
||||
|
||||
Returns:
|
||||
[] — CNN = boîte noire visuelle
|
||||
"""
|
||||
return []
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Entraînement interne
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def _train_model(
|
||||
self,
|
||||
model: 'CandlestickCNN',
|
||||
X: np.ndarray,
|
||||
y: np.ndarray,
|
||||
class_weights: 'torch.Tensor',
|
||||
) -> None:
|
||||
"""
|
||||
Entraîne le modèle CNN sur les données fournies avec early stopping.
|
||||
|
||||
Args:
|
||||
model: Instance CandlestickCNN à entraîner
|
||||
X: Images (N, 3, 128, 128) float32
|
||||
y: Labels encodés (N,) int, valeurs {0, 1, 2}
|
||||
class_weights: Poids de classes (3,) pour CrossEntropyLoss
|
||||
"""
|
||||
# Limiter les threads CPU pour ne pas saturer l'event loop FastAPI
|
||||
torch.set_num_threads(4)
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
model = model.to(device)
|
||||
class_weights = class_weights.to(device)
|
||||
|
||||
# Conversion en Tensors
|
||||
X_t = torch.from_numpy(X).float()
|
||||
y_t = torch.from_numpy(y.astype(np.int64))
|
||||
|
||||
# Split train/validation (80/20 temporel)
|
||||
n_train = int(len(X_t) * 0.8)
|
||||
X_tr, X_val = X_t[:n_train], X_t[n_train:]
|
||||
y_tr, y_val = y_t[:n_train], y_t[n_train:]
|
||||
|
||||
# DataLoaders
|
||||
train_ds = TensorDataset(X_tr, y_tr)
|
||||
val_ds = TensorDataset(X_val, y_val)
|
||||
train_dl = DataLoader(train_ds, batch_size=self.batch_size, shuffle=False)
|
||||
val_dl = DataLoader(val_ds, batch_size=self.batch_size, shuffle=False)
|
||||
|
||||
# Optimiseur et critère
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)
|
||||
criterion = nn.CrossEntropyLoss(weight=class_weights)
|
||||
|
||||
best_val_loss = float('inf')
|
||||
patience_count = 0
|
||||
best_state = None
|
||||
|
||||
for epoch in range(self.epochs):
|
||||
# --- Phase entraînement ---
|
||||
model.train()
|
||||
train_loss = 0.0
|
||||
for X_batch, y_batch in train_dl:
|
||||
X_batch = X_batch.to(device)
|
||||
y_batch = y_batch.to(device)
|
||||
optimizer.zero_grad()
|
||||
logits = model(X_batch)
|
||||
loss = criterion(logits, y_batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
train_loss += loss.item() * len(X_batch)
|
||||
train_loss /= max(len(train_ds), 1)
|
||||
|
||||
# --- Phase validation ---
|
||||
model.eval()
|
||||
val_loss = 0.0
|
||||
with torch.no_grad():
|
||||
for X_batch, y_batch in val_dl:
|
||||
X_batch = X_batch.to(device)
|
||||
y_batch = y_batch.to(device)
|
||||
logits = model(X_batch)
|
||||
loss = criterion(logits, y_batch)
|
||||
val_loss += loss.item() * len(X_batch)
|
||||
val_loss /= max(len(val_ds), 1)
|
||||
|
||||
logger.debug(
|
||||
f" Epoch {epoch+1}/{self.epochs} — "
|
||||
f"train_loss={train_loss:.4f}, val_loss={val_loss:.4f}"
|
||||
)
|
||||
|
||||
# Early stopping
|
||||
if val_loss < best_val_loss:
|
||||
best_val_loss = val_loss
|
||||
patience_count = 0
|
||||
best_state = {
|
||||
k: v.cpu().clone() for k, v in model.state_dict().items()
|
||||
}
|
||||
else:
|
||||
patience_count += 1
|
||||
if patience_count >= self.patience:
|
||||
logger.info(
|
||||
f" Early stopping à l'époque {epoch+1} "
|
||||
f"(patience={self.patience}, best_val_loss={best_val_loss:.4f})"
|
||||
)
|
||||
break
|
||||
|
||||
# Restauration des meilleurs poids
|
||||
if best_state is not None:
|
||||
model.load_state_dict(best_state)
|
||||
|
||||
model.eval()
|
||||
model.to('cpu')
|
||||
|
||||
def _walk_forward_eval(
|
||||
self,
|
||||
X: np.ndarray,
|
||||
y: np.ndarray,
|
||||
class_weights: 'torch.Tensor',
|
||||
n_splits: int = 2,
|
||||
) -> Dict:
|
||||
"""
|
||||
Évalue le modèle en cross-validation temporelle (walk-forward).
|
||||
|
||||
Args:
|
||||
X: Images (N, 3, 128, 128)
|
||||
y: Labels encodés (N,)
|
||||
class_weights: Poids de classes pour CrossEntropyLoss
|
||||
n_splits: Nombre de folds temporels
|
||||
|
||||
Returns:
|
||||
Dict : avg_accuracy, avg_precision, fold_accuracies
|
||||
"""
|
||||
tscv = TimeSeriesSplit(n_splits=n_splits)
|
||||
accuracies = []
|
||||
precisions = []
|
||||
|
||||
for fold, (train_idx, test_idx) in enumerate(tscv.split(X)):
|
||||
X_tr, X_te = X[train_idx], X[test_idx]
|
||||
y_tr, y_te = y[train_idx], y[test_idx]
|
||||
|
||||
# Modèle temporaire pour ce fold
|
||||
fold_model = CandlestickCNN(n_classes=3)
|
||||
self._train_model(fold_model, X_tr, y_tr, class_weights)
|
||||
|
||||
# Évaluation
|
||||
x_tensor = torch.from_numpy(X_te).float()
|
||||
proba_arr = fold_model.predict_proba(x_tensor) # (N_te, 3)
|
||||
y_pred = np.argmax(proba_arr, axis=1)
|
||||
|
||||
acc = float((y_pred == y_te).mean())
|
||||
accuracies.append(acc)
|
||||
|
||||
# Précision sur signaux directionnels (LONG=2 et SHORT=0, exclure NEUTRAL=1)
|
||||
mask = (y_te != 1) | (y_pred != 1)
|
||||
if mask.sum() > 0 and SKLEARN_AVAILABLE:
|
||||
prec, _, _, _ = precision_recall_fscore_support(
|
||||
y_te[mask], y_pred[mask],
|
||||
average='macro', zero_division=0
|
||||
)
|
||||
precisions.append(float(prec))
|
||||
else:
|
||||
precisions.append(0.0)
|
||||
|
||||
logger.info(
|
||||
f" Fold {fold+1}/{n_splits} : acc={acc:.2%}, "
|
||||
f"prec={precisions[-1]:.2%}"
|
||||
)
|
||||
|
||||
return {
|
||||
'avg_accuracy': float(np.mean(accuracies)) if accuracies else 0.0,
|
||||
'avg_precision': float(np.mean(precisions)) if precisions else 0.0,
|
||||
'fold_accuracies': [float(a) for a in accuracies],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _compute_class_weights(y: np.ndarray) -> 'torch.Tensor':
|
||||
"""
|
||||
Calcule les poids de classes inversement proportionnels à leur fréquence.
|
||||
|
||||
Permet de compenser les déséquilibres (ex: beaucoup de NEUTRAL, peu de trades).
|
||||
|
||||
Args:
|
||||
y: Labels encodés {0, 1, 2}
|
||||
|
||||
Returns:
|
||||
torch.Tensor de forme (3,) — poids pour CrossEntropyLoss
|
||||
"""
|
||||
weights = np.ones(3, dtype=np.float32)
|
||||
n_total = len(y)
|
||||
|
||||
if n_total == 0:
|
||||
return torch.from_numpy(weights)
|
||||
|
||||
for cls in range(3):
|
||||
count = (y == cls).sum()
|
||||
if count > 0:
|
||||
# Poids inversement proportionnel : classes rares → poids plus élevé
|
||||
weights[cls] = n_total / (3.0 * count)
|
||||
|
||||
# Normalisation pour que la somme des poids = 3
|
||||
weights = weights / weights.mean()
|
||||
|
||||
return torch.from_numpy(weights)
|
||||
3
src/ml/ensemble/__init__.py
Normal file
3
src/ml/ensemble/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .ensemble_model import EnsembleModel
|
||||
|
||||
__all__ = ['EnsembleModel']
|
||||
262
src/ml/ensemble/ensemble_model.py
Normal file
262
src/ml/ensemble/ensemble_model.py
Normal file
@@ -0,0 +1,262 @@
|
||||
"""
|
||||
Ensemble Model — Combine plusieurs modèles ML pour un signal de trading robuste.
|
||||
|
||||
L'EnsembleModel agrège les prédictions de modèles indépendants (XGBoost, CNN,
|
||||
et plus tard RL) via une moyenne pondérée. Un signal n'est émis que si les
|
||||
modèles actifs sont en accord ET que le score pondéré dépasse un seuil.
|
||||
|
||||
Duck typing : ce module n'importe PAS directement MLStrategyModel ni
|
||||
CNNStrategyModel. Tout objet exposant `.predict(df)` → dict et `.is_trained`
|
||||
→ bool est compatible.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EnsembleModel:
|
||||
"""
|
||||
Combine plusieurs modèles ML pour produire un signal de trading robuste.
|
||||
|
||||
Logique :
|
||||
- Chaque modèle prédit indépendamment (signal + confidence)
|
||||
- Score final = somme pondérée des confidences pour les modèles en accord
|
||||
- Signal validé uniquement si :
|
||||
1. Au moins 2 modèles actifs sont en accord sur la direction
|
||||
2. Score pondéré >= min_confidence
|
||||
|
||||
Poids par défaut : xgboost=0.40, cnn=0.60 (CNN légèrement favorisé car
|
||||
il voit les données brutes sans biais de feature engineering)
|
||||
"""
|
||||
|
||||
DEFAULT_WEIGHTS = {
|
||||
'xgboost': 0.30,
|
||||
'cnn': 0.30,
|
||||
'cnn_image': 0.40, # CNN Vision — favorisé (patterns visuels bruts)
|
||||
'rl': 0.00, # Réservé Phase 4d
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weights: Optional[Dict[str, float]] = None,
|
||||
min_confidence: float = 0.60,
|
||||
require_agreement: bool = True,
|
||||
):
|
||||
self.weights = dict(weights) if weights else dict(self.DEFAULT_WEIGHTS)
|
||||
self.min_confidence = min_confidence
|
||||
self.require_agreement = require_agreement
|
||||
|
||||
# Modèles attachés (duck typing : .predict(df), .is_trained)
|
||||
self._models: Dict[str, Any] = {}
|
||||
|
||||
# Référence directe au modèle CNN Image (accès rapide pour auto-attach)
|
||||
self._cnn_image_model = None
|
||||
|
||||
logger.info(
|
||||
f"EnsembleModel initialisé — poids={self.weights}, "
|
||||
f"seuil={self.min_confidence}, accord_requis={self.require_agreement}"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Attachement des modèles
|
||||
# ------------------------------------------------------------------
|
||||
def attach_xgboost(self, model) -> None:
|
||||
"""Attache un MLStrategyModel entraîné."""
|
||||
self._attach('xgboost', model)
|
||||
|
||||
def attach_cnn(self, model) -> None:
|
||||
"""Attache un CNNStrategyModel entraîné."""
|
||||
self._attach('cnn', model)
|
||||
|
||||
def attach_cnn_image(self, model) -> None:
|
||||
"""Attache un CNNImageStrategyModel entraîné (Phase 4c-bis)."""
|
||||
self._cnn_image_model = model
|
||||
self._attach('cnn_image', model)
|
||||
|
||||
def attach_rl(self, model) -> None:
|
||||
"""Attache un agent RL (Phase 4d)."""
|
||||
self._attach('rl', model)
|
||||
|
||||
def _attach(self, name: str, model) -> None:
|
||||
"""Attache un modèle générique avec vérification duck typing."""
|
||||
if not hasattr(model, 'predict') or not callable(model.predict):
|
||||
raise ValueError(f"Le modèle '{name}' doit exposer une méthode predict()")
|
||||
if not hasattr(model, 'is_trained'):
|
||||
raise ValueError(f"Le modèle '{name}' doit exposer un attribut is_trained")
|
||||
self._models[name] = model
|
||||
# Ajouter le poids par défaut s'il n'existe pas
|
||||
if name not in self.weights:
|
||||
self.weights[name] = 0.0
|
||||
logger.warning(f"Poids pour '{name}' non défini — initialisé à 0.0")
|
||||
logger.info(f"Modèle '{name}' attaché (is_trained={model.is_trained})")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Prédiction combinée
|
||||
# ------------------------------------------------------------------
|
||||
def predict(self, df: pd.DataFrame) -> Dict:
|
||||
"""
|
||||
Prédit le signal combiné à partir de tous les modèles actifs.
|
||||
|
||||
Returns:
|
||||
{
|
||||
'signal': int, # 1 LONG, -1 SHORT, 0 NEUTRAL
|
||||
'confidence': float, # score pondéré [0..1]
|
||||
'tradeable': bool,
|
||||
'agreement': bool, # True si tous les modèles actifs concordent
|
||||
'components': dict, # résultats individuels par modèle
|
||||
}
|
||||
"""
|
||||
components: Dict[str, Dict] = {}
|
||||
neutral_result = {
|
||||
'signal': 0, 'confidence': 0.0, 'tradeable': False,
|
||||
'agreement': False, 'components': components,
|
||||
}
|
||||
|
||||
# 1. Collecter les prédictions des modèles disponibles et entraînés
|
||||
for name, model in self._models.items():
|
||||
if not model.is_trained:
|
||||
logger.debug(f"Ensemble : modèle '{name}' non entraîné, ignoré")
|
||||
continue
|
||||
if self.weights.get(name, 0.0) <= 0.0:
|
||||
logger.debug(f"Ensemble : modèle '{name}' poids=0, ignoré")
|
||||
continue
|
||||
try:
|
||||
result = model.predict(df)
|
||||
components[name] = {
|
||||
'signal': result.get('signal', 0),
|
||||
'confidence': result.get('confidence', 0.0),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Ensemble : erreur predict '{name}' — {e}")
|
||||
continue
|
||||
|
||||
if not components:
|
||||
logger.debug("Ensemble : aucun modèle actif n'a produit de prédiction")
|
||||
neutral_result['components'] = components
|
||||
return neutral_result
|
||||
|
||||
# 2. Filtrer les signaux non-neutres
|
||||
directional = {
|
||||
k: v for k, v in components.items() if v['signal'] != 0
|
||||
}
|
||||
|
||||
if not directional:
|
||||
# Tous les modèles sont neutres
|
||||
return {
|
||||
'signal': 0, 'confidence': 0.0, 'tradeable': False,
|
||||
'agreement': True, 'components': components,
|
||||
}
|
||||
|
||||
# 3. Vérifier l'accord entre modèles directionnels
|
||||
directions = set(v['signal'] for v in directional.values())
|
||||
agreement = len(directions) == 1
|
||||
|
||||
if self.require_agreement and not agreement:
|
||||
logger.debug(
|
||||
f"Ensemble : désaccord entre modèles — {directional}"
|
||||
)
|
||||
return {
|
||||
'signal': 0, 'confidence': 0.0, 'tradeable': False,
|
||||
'agreement': False, 'components': components,
|
||||
}
|
||||
|
||||
# 4. Vérifier qu'au moins 2 modèles actifs sont en accord
|
||||
if len(directional) < 2:
|
||||
logger.debug("Ensemble : un seul modèle directionnel, signal insuffisant")
|
||||
return {
|
||||
'signal': 0, 'confidence': 0.0, 'tradeable': False,
|
||||
'agreement': True, 'components': components,
|
||||
}
|
||||
|
||||
# 5. Calculer le score pondéré (normalisé sur les modèles actifs)
|
||||
consensus_dir = directions.pop() # direction unique
|
||||
total_weight = sum(self.weights.get(k, 0.0) for k in directional)
|
||||
|
||||
if total_weight <= 0:
|
||||
return {
|
||||
'signal': 0, 'confidence': 0.0, 'tradeable': False,
|
||||
'agreement': agreement, 'components': components,
|
||||
}
|
||||
|
||||
weighted_score = sum(
|
||||
self.weights.get(k, 0.0) * v['confidence']
|
||||
for k, v in directional.items()
|
||||
) / total_weight
|
||||
|
||||
# 6. Signal final
|
||||
tradeable = weighted_score >= self.min_confidence
|
||||
|
||||
logger.info(
|
||||
f"Ensemble : direction={'LONG' if consensus_dir == 1 else 'SHORT'} | "
|
||||
f"score={weighted_score:.2%} | accord={agreement} | tradeable={tradeable}"
|
||||
)
|
||||
|
||||
return {
|
||||
'signal': consensus_dir,
|
||||
'confidence': weighted_score,
|
||||
'tradeable': tradeable,
|
||||
'agreement': agreement,
|
||||
'components': components,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Statut et configuration
|
||||
# ------------------------------------------------------------------
|
||||
def is_ready(self) -> bool:
|
||||
"""True si au moins 2 modèles parmi {xgboost, cnn, cnn_image} sont attachés et entraînés."""
|
||||
eligible_names = {'xgboost', 'cnn', 'cnn_image'}
|
||||
trained = sum(
|
||||
1 for name, model in self._models.items()
|
||||
if name in eligible_names
|
||||
and model.is_trained
|
||||
and self.weights.get(name, 0.0) > 0
|
||||
)
|
||||
return trained >= 2
|
||||
|
||||
def get_status(self) -> Dict:
|
||||
"""Statut de chaque composant + poids actifs."""
|
||||
status = {
|
||||
'ready': self.is_ready(),
|
||||
'min_confidence': self.min_confidence,
|
||||
'require_agreement': self.require_agreement,
|
||||
'weights': dict(self.weights),
|
||||
'models': {},
|
||||
}
|
||||
for name, model in self._models.items():
|
||||
status['models'][name] = {
|
||||
'attached': True,
|
||||
'is_trained': model.is_trained,
|
||||
'weight': self.weights.get(name, 0.0),
|
||||
}
|
||||
# Modèles non attachés mais présents dans les poids
|
||||
for name in self.weights:
|
||||
if name not in status['models']:
|
||||
status['models'][name] = {
|
||||
'attached': False,
|
||||
'is_trained': False,
|
||||
'weight': self.weights[name],
|
||||
}
|
||||
return status
|
||||
|
||||
def update_weights(self, weights: Dict[str, float]) -> None:
|
||||
"""
|
||||
Mise à jour dynamique des poids.
|
||||
|
||||
Si la somme != 1.0, normalise automatiquement et log un warning.
|
||||
"""
|
||||
total = sum(weights.values())
|
||||
if total <= 0:
|
||||
raise ValueError("La somme des poids doit être > 0")
|
||||
|
||||
if abs(total - 1.0) > 1e-6:
|
||||
logger.warning(
|
||||
f"Somme des poids = {total:.4f} != 1.0 — normalisation automatique"
|
||||
)
|
||||
weights = {k: v / total for k, v in weights.items()}
|
||||
|
||||
self.weights.update(weights)
|
||||
logger.info(f"Poids mis à jour : {self.weights}")
|
||||
@@ -140,30 +140,59 @@ class LabelGenerator:
|
||||
sl_short: float,
|
||||
) -> int:
|
||||
"""
|
||||
Parcourt les barres futures bar par bar et retourne le label.
|
||||
Vérifie HIGH pour TP LONG et LOW pour SL LONG (et inversement pour SHORT).
|
||||
Simule LONG et SHORT de façon indépendante sur les barres futures.
|
||||
|
||||
LONG et SHORT sont deux trades hypothétiques distincts : le SL du LONG
|
||||
(prix baisse) ne signifie pas que le SL du SHORT (prix monte) est touché.
|
||||
Les deux simulations sont donc parcourues séparément pour éviter de
|
||||
manquer les signaux SHORT quand le prix descend.
|
||||
|
||||
Retourne le label du trade gagnant qui se résout en premier :
|
||||
1 (LONG), -1 (SHORT) ou 0 (NEUTRAL).
|
||||
"""
|
||||
for _, bar in future.iterrows():
|
||||
# LONG : TP atteint ?
|
||||
if bar['high'] >= tp_long and bar['low'] > sl_long:
|
||||
return 1
|
||||
# LONG : SL atteint en premier ?
|
||||
if bar['low'] <= sl_long:
|
||||
# Vérifie si TP atteint le même bar (candle ambiguë)
|
||||
if bar['high'] >= tp_long:
|
||||
return 0 # Ambigu → neutre
|
||||
return 0 # SL touché → pas de LONG
|
||||
# --- Simulation LONG indépendante ---
|
||||
long_win_idx = None
|
||||
long_lose_idx = None
|
||||
for idx, (_, bar) in enumerate(future.iterrows()):
|
||||
tp_hit = bar['high'] >= tp_long
|
||||
sl_hit = bar['low'] <= sl_long
|
||||
if tp_hit and sl_hit:
|
||||
long_lose_idx = idx # Barre ambiguë → perte
|
||||
break
|
||||
if tp_hit:
|
||||
long_win_idx = idx
|
||||
break
|
||||
if sl_hit:
|
||||
long_lose_idx = idx
|
||||
break
|
||||
|
||||
# SHORT : TP atteint ?
|
||||
if bar['low'] <= tp_short and bar['high'] < sl_short:
|
||||
return -1
|
||||
# SHORT : SL atteint en premier ?
|
||||
if bar['high'] >= sl_short:
|
||||
if bar['low'] <= tp_short:
|
||||
return 0
|
||||
return 0
|
||||
# --- Simulation SHORT indépendante ---
|
||||
short_win_idx = None
|
||||
short_lose_idx = None
|
||||
for idx, (_, bar) in enumerate(future.iterrows()):
|
||||
tp_hit = bar['low'] <= tp_short
|
||||
sl_hit = bar['high'] >= sl_short
|
||||
if tp_hit and sl_hit:
|
||||
short_lose_idx = idx # Barre ambiguë → perte
|
||||
break
|
||||
if tp_hit:
|
||||
short_win_idx = idx
|
||||
break
|
||||
if sl_hit:
|
||||
short_lose_idx = idx
|
||||
break
|
||||
|
||||
return 0 # Ni TP ni SL atteint dans l'horizon
|
||||
long_won = long_win_idx is not None
|
||||
short_won = short_win_idx is not None
|
||||
|
||||
if long_won and not short_won:
|
||||
return 1
|
||||
if short_won and not long_won:
|
||||
return -1
|
||||
if long_won and short_won:
|
||||
# Les deux trades seraient gagnants : prendre celui qui se résout en premier
|
||||
return 1 if long_win_idx <= short_win_idx else -1
|
||||
return 0 # Aucun TP atteint dans l'horizon
|
||||
|
||||
@staticmethod
|
||||
def _log_distribution(labels: pd.Series) -> None:
|
||||
|
||||
@@ -156,19 +156,23 @@ class MLStrategyModel:
|
||||
logger.info(f" {len(X)} échantillons, {len(self.feature_names)} features")
|
||||
logger.info(f" Distribution : LONG={( y==1).sum()}, SHORT={(y==-1).sum()}, NEUTRAL={(y==0).sum()}")
|
||||
|
||||
# Encodage labels [-1,0,1] → [0,1,2] (requis par XGBoost ≥ 2.x)
|
||||
y_enc = y + 1
|
||||
|
||||
# 3. Walk-forward cross-validation (3 folds temporels)
|
||||
wf_metrics = self._walk_forward_eval(X, y, n_splits=3)
|
||||
wf_metrics = self._walk_forward_eval(X, y_enc, n_splits=3)
|
||||
|
||||
# 4. Entraînement sur la totalité des données
|
||||
self.scaler = StandardScaler()
|
||||
X_scaled = self.scaler.fit_transform(X)
|
||||
|
||||
self.model = self._build_model()
|
||||
self.model.fit(X_scaled, y)
|
||||
self.model.fit(X_scaled, y_enc)
|
||||
self.is_trained = True
|
||||
|
||||
# 5. Évaluation finale (in-sample — indicative)
|
||||
y_pred = self.model.predict(X_scaled)
|
||||
y_pred_enc = self.model.predict(X_scaled)
|
||||
y_pred = y_pred_enc - 1 # décodage [0,1,2] → [-1,0,1]
|
||||
report = classification_report(y, y_pred, labels=[-1, 0, 1],
|
||||
target_names=['SHORT', 'NEUTRAL', 'LONG'],
|
||||
output_dict=True, zero_division=0)
|
||||
@@ -235,24 +239,25 @@ class MLStrategyModel:
|
||||
last = last[self.feature_names].fillna(0)
|
||||
|
||||
X_scaled = self.scaler.transform(last)
|
||||
pred = self.model.predict(X_scaled)[0]
|
||||
pred_enc = self.model.predict(X_scaled)[0]
|
||||
pred = int(pred_enc) - 1 # décodage [0,1,2] → [-1,0,1]
|
||||
|
||||
# Probabilités si disponibles
|
||||
probas = {'long': 0.0, 'short': 0.0, 'neutral': 1.0}
|
||||
confidence = 0.0
|
||||
if hasattr(self.model, 'predict_proba'):
|
||||
proba_arr = self.model.predict_proba(X_scaled)[0]
|
||||
classes = list(self.model.classes_)
|
||||
classes = list(self.model.classes_) # [0, 1, 2] encodés
|
||||
prob_map = {c: p for c, p in zip(classes, proba_arr)}
|
||||
probas = {
|
||||
'long': float(prob_map.get(1, 0.0)),
|
||||
'short': float(prob_map.get(-1, 0.0)),
|
||||
'neutral': float(prob_map.get(0, 1.0)),
|
||||
'long': float(prob_map.get(2, 0.0)), # encodé 2 = LONG (1)
|
||||
'short': float(prob_map.get(0, 0.0)), # encodé 0 = SHORT (-1)
|
||||
'neutral': float(prob_map.get(1, 1.0)), # encodé 1 = NEUTRAL (0)
|
||||
}
|
||||
confidence = float(max(probas['long'], probas['short']))
|
||||
|
||||
return {
|
||||
'signal': int(pred),
|
||||
'signal': pred,
|
||||
'confidence': confidence,
|
||||
'probas': probas,
|
||||
'tradeable': confidence >= self.min_confidence and pred != 0,
|
||||
@@ -383,7 +388,8 @@ class MLStrategyModel:
|
||||
|
||||
acc = (y_pred == y_te.values).mean()
|
||||
# Précision/Recall sur les signaux directionnels uniquement
|
||||
mask = (y_te != 0) | (y_pred != 0)
|
||||
# y encodé : 0=SHORT, 1=NEUTRAL, 2=LONG → NEUTRAL=1
|
||||
mask = (y_te != 1) | (y_pred != 1)
|
||||
prec, rec, _, _ = precision_recall_fscore_support(
|
||||
y_te[mask], y_pred[mask], average='macro', zero_division=0
|
||||
) if mask.sum() > 0 else (0, 0, 0, 0)
|
||||
|
||||
@@ -11,10 +11,12 @@ les différents régimes de marché:
|
||||
Permet d'adapter les stratégies selon le régime actuel.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
|
||||
try:
|
||||
@@ -24,8 +26,18 @@ except ImportError:
|
||||
HMMLEARN_AVAILABLE = False
|
||||
logging.warning("hmmlearn not installed. Install with: pip install hmmlearn")
|
||||
|
||||
try:
|
||||
import joblib
|
||||
JOBLIB_AVAILABLE = True
|
||||
except ImportError:
|
||||
JOBLIB_AVAILABLE = False
|
||||
logging.warning("joblib not installed. La persistance HMM sera désactivée.")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Répertoire de persistance des modèles HMM
|
||||
MODELS_DIR = Path(__file__).parent.parent.parent / "models" / "hmm"
|
||||
|
||||
|
||||
class RegimeDetector:
|
||||
"""
|
||||
@@ -79,6 +91,9 @@ class RegimeDetector:
|
||||
|
||||
self.is_fitted = False
|
||||
self.feature_names = []
|
||||
# Métadonnées d'entraînement pour la persistance
|
||||
self._trained_at: Optional[datetime] = None
|
||||
self._n_samples: int = 0
|
||||
|
||||
logger.info(f"RegimeDetector initialized with {n_regimes} regimes")
|
||||
|
||||
@@ -115,11 +130,147 @@ class RegimeDetector:
|
||||
try:
|
||||
self.model.fit(X)
|
||||
self.is_fitted = True
|
||||
logger.info("✅ HMM model fitted successfully")
|
||||
self._trained_at = datetime.now()
|
||||
self._n_samples = len(X)
|
||||
logger.info("Modèle HMM entraîné avec succès")
|
||||
except Exception as e:
|
||||
logger.error(f"Error fitting HMM: {e}")
|
||||
logger.error(f"Erreur lors de l'entraînement HMM : {e}")
|
||||
raise
|
||||
|
||||
@property
|
||||
def is_trained(self) -> bool:
|
||||
"""True si le modèle HMM a été ajusté (fit)."""
|
||||
return self.is_fitted
|
||||
|
||||
def needs_retrain(self, max_age_hours: int = 24) -> bool:
|
||||
"""
|
||||
Indique si le modèle doit être ré-entraîné.
|
||||
|
||||
Un ré-entraînement est nécessaire si :
|
||||
- Le modèle n'a jamais été entraîné
|
||||
- La date d'entraînement est inconnue
|
||||
- Le modèle est plus vieux que max_age_hours
|
||||
|
||||
Args:
|
||||
max_age_hours: Âge maximum du modèle en heures (défaut 24h)
|
||||
|
||||
Returns:
|
||||
True si un ré-entraînement est nécessaire
|
||||
"""
|
||||
if not self.is_fitted or self._trained_at is None:
|
||||
return True
|
||||
age = datetime.now() - self._trained_at
|
||||
return age > timedelta(hours=max_age_hours)
|
||||
|
||||
def save(self, symbol: str, timeframe: str) -> bool:
|
||||
"""
|
||||
Sauvegarde le modèle HMM entraîné sur disque avec joblib.
|
||||
|
||||
Le modèle est sauvegardé dans :
|
||||
models/hmm/{symbol}_{timeframe}.joblib
|
||||
|
||||
Les métadonnées (date, n_samples, n_components, labels) sont
|
||||
stockées dans un fichier JSON compagnon :
|
||||
models/hmm/{symbol}_{timeframe}_meta.json
|
||||
|
||||
Args:
|
||||
symbol: Symbole de l'instrument (ex : "EURUSD")
|
||||
timeframe: Unité de temps (ex : "1h")
|
||||
|
||||
Returns:
|
||||
True si la sauvegarde a réussi, False sinon
|
||||
"""
|
||||
if not self.is_fitted:
|
||||
logger.warning("Impossible de sauvegarder : modèle non entraîné")
|
||||
return False
|
||||
|
||||
if not JOBLIB_AVAILABLE:
|
||||
logger.warning("joblib indisponible — sauvegarde HMM ignorée")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Créer le répertoire si nécessaire
|
||||
MODELS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
base = f"{symbol}_{timeframe}"
|
||||
model_path = MODELS_DIR / f"{base}.joblib"
|
||||
meta_path = MODELS_DIR / f"{base}_meta.json"
|
||||
|
||||
# Sauvegarder le modèle HMM + feature_names
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"feature_names": self.feature_names,
|
||||
"n_regimes": self.n_regimes,
|
||||
"random_state": self.random_state,
|
||||
}
|
||||
joblib.dump(payload, model_path)
|
||||
|
||||
# Sauvegarder les métadonnées en JSON
|
||||
meta = {
|
||||
"trained_at": self._trained_at.isoformat() if self._trained_at else None,
|
||||
"n_samples": self._n_samples,
|
||||
"n_components": self.n_regimes,
|
||||
"regime_labels": self.REGIME_NAMES,
|
||||
"symbol": symbol,
|
||||
"timeframe": timeframe,
|
||||
}
|
||||
meta_path.write_text(json.dumps(meta, indent=2, default=str))
|
||||
|
||||
logger.info(f"Modèle HMM sauvegardé : {model_path}")
|
||||
return True
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Erreur lors de la sauvegarde du modèle HMM : {exc}")
|
||||
return False
|
||||
|
||||
def load(self, symbol: str, timeframe: str) -> bool:
|
||||
"""
|
||||
Charge un modèle HMM depuis le disque.
|
||||
|
||||
Args:
|
||||
symbol: Symbole de l'instrument (ex : "EURUSD")
|
||||
timeframe: Unité de temps (ex : "1h")
|
||||
|
||||
Returns:
|
||||
True si le chargement a réussi, False sinon
|
||||
"""
|
||||
if not JOBLIB_AVAILABLE:
|
||||
logger.warning("joblib indisponible — chargement HMM impossible")
|
||||
return False
|
||||
|
||||
base = f"{symbol}_{timeframe}"
|
||||
model_path = MODELS_DIR / f"{base}.joblib"
|
||||
meta_path = MODELS_DIR / f"{base}_meta.json"
|
||||
|
||||
if not model_path.exists():
|
||||
logger.debug(f"Aucun modèle HMM trouvé pour {symbol}/{timeframe}")
|
||||
return False
|
||||
|
||||
try:
|
||||
payload = joblib.load(model_path)
|
||||
self.model = payload["model"]
|
||||
self.feature_names = payload["feature_names"]
|
||||
self.n_regimes = payload["n_regimes"]
|
||||
self.random_state = payload["random_state"]
|
||||
self.is_fitted = True
|
||||
|
||||
# Charger les métadonnées si disponibles
|
||||
if meta_path.exists():
|
||||
meta = json.loads(meta_path.read_text())
|
||||
trained_at_raw = meta.get("trained_at")
|
||||
self._trained_at = (
|
||||
datetime.fromisoformat(trained_at_raw) if trained_at_raw else None
|
||||
)
|
||||
self._n_samples = meta.get("n_samples", 0)
|
||||
|
||||
logger.info(f"Modèle HMM chargé depuis {model_path} (entraîné le {self._trained_at})")
|
||||
return True
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Erreur lors du chargement du modèle HMM : {exc}")
|
||||
self.is_fitted = False
|
||||
return False
|
||||
|
||||
def predict_regime(self, data: pd.DataFrame) -> np.ndarray:
|
||||
"""
|
||||
Prédit les régimes pour toutes les barres.
|
||||
|
||||
24
src/ml/rl/__init__.py
Normal file
24
src/ml/rl/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
Module RL (Reinforcement Learning) — Agent PPO pour le trading algorithmique.
|
||||
|
||||
Ce module implémente un agent PPO (Proximal Policy Optimization) entraîné
|
||||
par renforcement sur un environnement de trading simulé.
|
||||
|
||||
Composants :
|
||||
TradingEnv — Environnement gymnasium conforme (observation 20 features)
|
||||
PPOModel — Réseau Actor-Critic MLP 256→128→64 + entraînement PPO
|
||||
RLStrategyModel — Interface unifiée (identique à MLStrategyModel)
|
||||
"""
|
||||
|
||||
from src.ml.rl.trading_env import TradingEnv, GYM_AVAILABLE
|
||||
from src.ml.rl.ppo_model import PPOModel, TORCH_AVAILABLE
|
||||
from src.ml.rl.rl_strategy_model import RLStrategyModel, RL_AVAILABLE
|
||||
|
||||
__all__ = [
|
||||
'TradingEnv',
|
||||
'PPOModel',
|
||||
'RLStrategyModel',
|
||||
'RL_AVAILABLE',
|
||||
'TORCH_AVAILABLE',
|
||||
'GYM_AVAILABLE',
|
||||
]
|
||||
681
src/ml/rl/ppo_model.py
Normal file
681
src/ml/rl/ppo_model.py
Normal file
@@ -0,0 +1,681 @@
|
||||
"""
|
||||
PPOModel — Implémentation manuelle de l'algorithme PPO (Proximal Policy Optimization).
|
||||
|
||||
Architecture Actor-Critic :
|
||||
- Réseau partagé : MLP 3 couches (256 → 128 → 64) avec BatchNorm + ReLU
|
||||
- Tête Actor : couche linéaire → logits (n_actions=3)
|
||||
- Tête Critic : couche linéaire → valeur scalaire V(s)
|
||||
|
||||
Hyperparamètres PPO :
|
||||
- clip ε=0.2 — clip ratio de probabilité pour éviter les grandes mises à jour
|
||||
- entropy_coef=0.01 — bonus d'entropie pour exploration
|
||||
- value_coef=0.5 — coefficient de la loss valeur
|
||||
- n_steps=2048 — nombre de transitions collectées avant mise à jour
|
||||
- n_epochs=10 — passes sur les données collectées
|
||||
- batch_size=64 — taille des mini-batchs
|
||||
- gamma=0.99 — facteur d'actualisation
|
||||
- gae_lambda=0.95 — λ pour Generalized Advantage Estimation
|
||||
|
||||
Sauvegarde :
|
||||
- models/rl_strategy/{symbol}_{timeframe}.pt (state_dict + config)
|
||||
- models/rl_strategy/{symbol}_{timeframe}_meta.json
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Répertoire de sauvegarde des modèles RL
|
||||
MODELS_DIR = Path(__file__).parent.parent.parent.parent / "models" / "rl_strategy"
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Import conditionnel PyTorch
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
try:
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.distributions import Categorical
|
||||
TORCH_AVAILABLE = True
|
||||
except ImportError:
|
||||
torch = None
|
||||
nn = None
|
||||
optim = None
|
||||
Categorical = None
|
||||
TORCH_AVAILABLE = False
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Réseau Actor-Critic
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
if TORCH_AVAILABLE:
|
||||
class ActorCriticNetwork(nn.Module):
|
||||
"""
|
||||
Réseau Actor-Critic partagé pour PPO.
|
||||
|
||||
Architecture :
|
||||
- Tronc commun : Linear(obs_dim → 256) → BN → ReLU
|
||||
Linear(256 → 128) → BN → ReLU
|
||||
Linear(128 → 64) → ReLU
|
||||
- Tête Actor : Linear(64 → n_actions)
|
||||
- Tête Critic : Linear(64 → 1)
|
||||
|
||||
Args:
|
||||
obs_dim: Dimension de l'espace d'observation
|
||||
n_actions: Nombre d'actions discrètes
|
||||
"""
|
||||
|
||||
def __init__(self, obs_dim: int = 20, n_actions: int = 3):
|
||||
super().__init__()
|
||||
|
||||
# Tronc commun
|
||||
self.trunk = nn.Sequential(
|
||||
nn.Linear(obs_dim, 256),
|
||||
nn.BatchNorm1d(256),
|
||||
nn.ReLU(),
|
||||
nn.Linear(256, 128),
|
||||
nn.BatchNorm1d(128),
|
||||
nn.ReLU(),
|
||||
nn.Linear(128, 64),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
# Tête Actor : logits pour chaque action
|
||||
self.actor_head = nn.Linear(64, n_actions)
|
||||
|
||||
# Tête Critic : estimation de la valeur d'état V(s)
|
||||
self.critic_head = nn.Linear(64, 1)
|
||||
|
||||
# Initialisation des poids (orthogonale pour stabilité PPO)
|
||||
self._init_weights()
|
||||
|
||||
def _init_weights(self):
|
||||
"""Initialisation orthogonale recommandée pour les réseaux PPO."""
|
||||
for module in self.modules():
|
||||
if isinstance(module, nn.Linear):
|
||||
nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
|
||||
nn.init.constant_(module.bias, 0.0)
|
||||
# Gain plus faible pour les têtes finales
|
||||
nn.init.orthogonal_(self.actor_head.weight, gain=0.01)
|
||||
nn.init.orthogonal_(self.critic_head.weight, gain=1.0)
|
||||
|
||||
def forward(
|
||||
self, x: 'torch.Tensor'
|
||||
) -> Tuple['torch.Tensor', 'torch.Tensor']:
|
||||
"""
|
||||
Calcule les logits actor et la valeur critic.
|
||||
|
||||
Args:
|
||||
x: Tensor d'observations (batch_size, obs_dim)
|
||||
|
||||
Returns:
|
||||
Tuple (logits_actor, value_critic)
|
||||
logits_actor : (batch_size, n_actions)
|
||||
value_critic : (batch_size, 1)
|
||||
"""
|
||||
features = self.trunk(x)
|
||||
logits = self.actor_head(features)
|
||||
value = self.critic_head(features)
|
||||
return logits, value
|
||||
|
||||
def get_action_and_value(
|
||||
self, x: 'torch.Tensor', action: Optional['torch.Tensor'] = None
|
||||
) -> Tuple['torch.Tensor', 'torch.Tensor', 'torch.Tensor', 'torch.Tensor']:
|
||||
"""
|
||||
Retourne l'action, son log-prob, l'entropie et la valeur.
|
||||
|
||||
Args:
|
||||
x: Observations (batch_size, obs_dim)
|
||||
action: Si fourni, calcule le log-prob de cet action existante
|
||||
(pour la phase de mise à jour PPO)
|
||||
|
||||
Returns:
|
||||
Tuple (action, log_prob, entropy, value)
|
||||
"""
|
||||
logits, value = self.forward(x)
|
||||
dist = Categorical(logits=logits)
|
||||
|
||||
if action is None:
|
||||
action = dist.sample()
|
||||
|
||||
log_prob = dist.log_prob(action)
|
||||
entropy = dist.entropy()
|
||||
|
||||
return action, log_prob, entropy, value.squeeze(-1)
|
||||
|
||||
def get_value(self, x: 'torch.Tensor') -> 'torch.Tensor':
|
||||
"""Retourne uniquement la valeur critic V(s)."""
|
||||
_, value = self.forward(x)
|
||||
return value.squeeze(-1)
|
||||
|
||||
|
||||
class PPOModel:
|
||||
"""
|
||||
Agent PPO (Proximal Policy Optimization) pour le trading.
|
||||
|
||||
Implémentation manuelle sans stable-baselines3, utilisant PyTorch pur.
|
||||
Suit le pseudo-code de Schulman et al. (2017) avec GAE.
|
||||
|
||||
Méthodes publiques :
|
||||
train(env, total_timesteps) — Lance l'entraînement
|
||||
predict(obs) — Retourne l'action et le log-prob
|
||||
save(symbol, timeframe) — Sauvegarde le modèle
|
||||
load(symbol, timeframe) — Charge un modèle existant (classmethod)
|
||||
|
||||
Args:
|
||||
obs_dim: Dimension de l'observation (défaut: 20)
|
||||
n_actions: Nombre d'actions discrètes (défaut: 3)
|
||||
lr: Taux d'apprentissage Adam (défaut: 3e-4)
|
||||
n_steps: Transitions par rollout avant mise à jour (défaut: 2048)
|
||||
n_epochs: Passes d'optimisation par rollout (défaut: 10)
|
||||
batch_size: Taille des mini-batchs (défaut: 64)
|
||||
gamma: Facteur d'actualisation (défaut: 0.99)
|
||||
gae_lambda: Lambda GAE (défaut: 0.95)
|
||||
clip_eps: Epsilon de clip PPO (défaut: 0.2)
|
||||
entropy_coef: Coefficient bonus entropie (défaut: 0.01)
|
||||
value_coef: Coefficient loss valeur (défaut: 0.5)
|
||||
max_grad_norm: Norme maximale du gradient (défaut: 0.5)
|
||||
"""
|
||||
|
||||
MODELS_DIR = MODELS_DIR
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
obs_dim: int = 20,
|
||||
n_actions: int = 3,
|
||||
lr: float = 3e-4,
|
||||
n_steps: int = 2048,
|
||||
n_epochs: int = 10,
|
||||
batch_size: int = 64,
|
||||
gamma: float = 0.99,
|
||||
gae_lambda: float = 0.95,
|
||||
clip_eps: float = 0.2,
|
||||
entropy_coef: float = 0.01,
|
||||
value_coef: float = 0.5,
|
||||
max_grad_norm: float = 0.5,
|
||||
):
|
||||
self.obs_dim = obs_dim
|
||||
self.n_actions = n_actions
|
||||
self.lr = lr
|
||||
self.n_steps = n_steps
|
||||
self.n_epochs = n_epochs
|
||||
self.batch_size = batch_size
|
||||
self.gamma = gamma
|
||||
self.gae_lambda = gae_lambda
|
||||
self.clip_eps = clip_eps
|
||||
self.entropy_coef = entropy_coef
|
||||
self.value_coef = value_coef
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
self.network: Optional['ActorCriticNetwork'] = None
|
||||
self.optimizer: Optional['optim.Adam'] = None
|
||||
self.is_trained: bool = False
|
||||
self.metadata: Dict = {}
|
||||
|
||||
self.MODELS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if TORCH_AVAILABLE:
|
||||
self._init_network()
|
||||
|
||||
def _init_network(self) -> None:
|
||||
"""Initialise le réseau Actor-Critic et l'optimiseur Adam."""
|
||||
self.network = ActorCriticNetwork(self.obs_dim, self.n_actions)
|
||||
self.optimizer = torch.optim.Adam(self.network.parameters(), lr=self.lr, eps=1e-5)
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Entraînement
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def train(
|
||||
self,
|
||||
env,
|
||||
total_timesteps: int = 100_000,
|
||||
symbol: str = 'EURUSD',
|
||||
timeframe: str = '1h',
|
||||
) -> Dict:
|
||||
"""
|
||||
Entraîne l'agent PPO sur l'environnement de trading fourni.
|
||||
|
||||
Algorithme :
|
||||
Boucle principale :
|
||||
1. Collecte de n_steps transitions (rollout)
|
||||
2. Calcul des avantages GAE
|
||||
3. n_epochs passes d'optimisation sur mini-batchs mélangés
|
||||
Fin :
|
||||
Sauvegarde du modèle et retour des métriques
|
||||
|
||||
Args:
|
||||
env: Environnement TradingEnv
|
||||
total_timesteps: Nombre total de transitions à collecter
|
||||
symbol: Symbole pour la sauvegarde (ex: 'EURUSD')
|
||||
timeframe: Timeframe pour la sauvegarde (ex: '1h')
|
||||
|
||||
Returns:
|
||||
Dict avec métriques : total_timesteps, n_updates, mean_reward,
|
||||
mean_ep_return, policy_loss, value_loss, entropy_loss
|
||||
"""
|
||||
if not TORCH_AVAILABLE:
|
||||
return {'error': 'PyTorch non disponible — installer torch>=2.0.0'}
|
||||
|
||||
# Limiter les threads CPU pour ne pas saturer l'event loop FastAPI
|
||||
torch.set_num_threads(4)
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.network.to(device)
|
||||
|
||||
logger.info(
|
||||
f"Début entraînement PPO {symbol}/{timeframe} — "
|
||||
f"total_timesteps={total_timesteps}, n_steps={self.n_steps}, device={device}"
|
||||
)
|
||||
|
||||
# ── Buffers de rollout ──────────────────────────────────────────────
|
||||
obs_buf = np.zeros((self.n_steps, self.obs_dim), dtype=np.float32)
|
||||
actions_buf = np.zeros((self.n_steps,), dtype=np.int64)
|
||||
rewards_buf = np.zeros((self.n_steps,), dtype=np.float32)
|
||||
dones_buf = np.zeros((self.n_steps,), dtype=np.float32)
|
||||
values_buf = np.zeros((self.n_steps,), dtype=np.float32)
|
||||
logprobs_buf = np.zeros((self.n_steps,), dtype=np.float32)
|
||||
|
||||
# ── Statistiques d'entraînement ─────────────────────────────────────
|
||||
all_episode_returns = []
|
||||
all_policy_losses = []
|
||||
all_value_losses = []
|
||||
all_entropy_losses = []
|
||||
ep_return = 0.0
|
||||
n_updates = 0
|
||||
|
||||
obs, _ = env.reset()
|
||||
done = False
|
||||
|
||||
timestep = 0
|
||||
rollout_idx = 0
|
||||
|
||||
while timestep < total_timesteps:
|
||||
# ── Phase de collecte (rollout) ───────────────────────────────────
|
||||
self.network.eval()
|
||||
with torch.no_grad():
|
||||
obs_t = torch.from_numpy(obs).float().unsqueeze(0).to(device)
|
||||
action, log_prob, _, value = self.network.get_action_and_value(obs_t)
|
||||
|
||||
action_np = int(action.cpu().item())
|
||||
log_prob_np = float(log_prob.cpu().item())
|
||||
value_np = float(value.cpu().item())
|
||||
|
||||
next_obs, reward, terminated, truncated, _ = env.step(action_np)
|
||||
done = terminated or truncated
|
||||
|
||||
obs_buf[rollout_idx] = obs
|
||||
actions_buf[rollout_idx] = action_np
|
||||
rewards_buf[rollout_idx] = reward
|
||||
dones_buf[rollout_idx] = float(done)
|
||||
values_buf[rollout_idx] = value_np
|
||||
logprobs_buf[rollout_idx] = log_prob_np
|
||||
|
||||
ep_return += reward
|
||||
obs = next_obs
|
||||
timestep += 1
|
||||
rollout_idx += 1
|
||||
|
||||
if done:
|
||||
all_episode_returns.append(ep_return)
|
||||
ep_return = 0.0
|
||||
obs, _ = env.reset()
|
||||
|
||||
# ── Phase de mise à jour PPO ──────────────────────────────────────
|
||||
if rollout_idx == self.n_steps:
|
||||
# Calcul de la valeur du dernier état (bootstrap)
|
||||
with torch.no_grad():
|
||||
last_obs_t = torch.from_numpy(obs).float().unsqueeze(0).to(device)
|
||||
last_value = self.network.get_value(last_obs_t).cpu().numpy()
|
||||
|
||||
# Calcul des avantages GAE et des retours
|
||||
advantages, returns = self._compute_gae(
|
||||
rewards_buf, values_buf, dones_buf, float(last_value)
|
||||
)
|
||||
|
||||
# Conversion en tensors
|
||||
obs_t = torch.from_numpy(obs_buf).float().to(device)
|
||||
actions_t = torch.from_numpy(actions_buf).long().to(device)
|
||||
logprobs_t = torch.from_numpy(logprobs_buf).float().to(device)
|
||||
advs_t = torch.from_numpy(advantages).float().to(device)
|
||||
returns_t = torch.from_numpy(returns).float().to(device)
|
||||
|
||||
# Normalisation des avantages
|
||||
advs_t = (advs_t - advs_t.mean()) / (advs_t.std() + 1e-8)
|
||||
|
||||
# n_epochs passes d'optimisation
|
||||
n_samples = self.n_steps
|
||||
policy_losses_ep = []
|
||||
value_losses_ep = []
|
||||
entropy_losses_ep = []
|
||||
|
||||
self.network.train()
|
||||
for _ in range(self.n_epochs):
|
||||
indices = np.random.permutation(n_samples)
|
||||
|
||||
for start in range(0, n_samples, self.batch_size):
|
||||
end = start + self.batch_size
|
||||
batch = indices[start:end]
|
||||
if len(batch) < 2:
|
||||
continue
|
||||
|
||||
# Forward pass
|
||||
_, new_logprob, entropy, new_value = (
|
||||
self.network.get_action_and_value(
|
||||
obs_t[batch], actions_t[batch]
|
||||
)
|
||||
)
|
||||
|
||||
# Ratio de probabilité
|
||||
log_ratio = new_logprob - logprobs_t[batch]
|
||||
ratio = torch.exp(log_ratio)
|
||||
|
||||
# Clip PPO
|
||||
adv_batch = advs_t[batch]
|
||||
pg_loss1 = -adv_batch * ratio
|
||||
pg_loss2 = -adv_batch * torch.clamp(
|
||||
ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps
|
||||
)
|
||||
policy_loss = torch.max(pg_loss1, pg_loss2).mean()
|
||||
|
||||
# Loss valeur (avec clip optionnel)
|
||||
value_loss = nn.functional.mse_loss(new_value, returns_t[batch])
|
||||
|
||||
# Bonus d'entropie (encourage l'exploration)
|
||||
entropy_loss = -entropy.mean()
|
||||
|
||||
# Loss totale
|
||||
total_loss = (
|
||||
policy_loss
|
||||
+ self.value_coef * value_loss
|
||||
+ self.entropy_coef * entropy_loss
|
||||
)
|
||||
|
||||
# Optimisation
|
||||
self.optimizer.zero_grad()
|
||||
total_loss.backward()
|
||||
nn.utils.clip_grad_norm_(
|
||||
self.network.parameters(), self.max_grad_norm
|
||||
)
|
||||
self.optimizer.step()
|
||||
|
||||
policy_losses_ep.append(float(policy_loss.item()))
|
||||
value_losses_ep.append(float(value_loss.item()))
|
||||
entropy_losses_ep.append(float(entropy_loss.item()))
|
||||
|
||||
all_policy_losses.extend(policy_losses_ep)
|
||||
all_value_losses.extend(value_losses_ep)
|
||||
all_entropy_losses.extend(entropy_losses_ep)
|
||||
n_updates += 1
|
||||
rollout_idx = 0
|
||||
|
||||
if n_updates % 10 == 0:
|
||||
mean_ret = float(np.mean(all_episode_returns[-20:])) if all_episode_returns else 0.0
|
||||
logger.info(
|
||||
f" Timestep {timestep}/{total_timesteps} | "
|
||||
f"Updates={n_updates} | MeanReturn(20ep)={mean_ret:.4f}"
|
||||
)
|
||||
|
||||
self.network.eval()
|
||||
self.network.to('cpu')
|
||||
self.is_trained = True
|
||||
|
||||
metrics = {
|
||||
'symbol': symbol,
|
||||
'timeframe': timeframe,
|
||||
'total_timesteps': total_timesteps,
|
||||
'n_updates': n_updates,
|
||||
'mean_reward': float(np.mean(all_episode_returns)) if all_episode_returns else 0.0,
|
||||
'mean_ep_return': float(np.mean(all_episode_returns[-20:])) if all_episode_returns else 0.0,
|
||||
'n_episodes': len(all_episode_returns),
|
||||
'policy_loss': float(np.mean(all_policy_losses[-100:])) if all_policy_losses else 0.0,
|
||||
'value_loss': float(np.mean(all_value_losses[-100:])) if all_value_losses else 0.0,
|
||||
'entropy_loss': float(np.mean(all_entropy_losses[-100:])) if all_entropy_losses else 0.0,
|
||||
'trained_at': datetime.utcnow().isoformat(),
|
||||
'hyperparams': {
|
||||
'lr': self.lr,
|
||||
'n_steps': self.n_steps,
|
||||
'n_epochs': self.n_epochs,
|
||||
'batch_size': self.batch_size,
|
||||
'gamma': self.gamma,
|
||||
'gae_lambda': self.gae_lambda,
|
||||
'clip_eps': self.clip_eps,
|
||||
'entropy_coef': self.entropy_coef,
|
||||
'value_coef': self.value_coef,
|
||||
},
|
||||
}
|
||||
self.metadata = metrics
|
||||
|
||||
logger.info(
|
||||
f"Entraînement PPO terminé. "
|
||||
f"MeanReturn={metrics['mean_ep_return']:.4f} | "
|
||||
f"Episodes={metrics['n_episodes']}"
|
||||
)
|
||||
return metrics
|
||||
|
||||
def predict(self, obs: np.ndarray) -> Tuple[int, float]:
|
||||
"""
|
||||
Prédit l'action optimale pour une observation donnée.
|
||||
|
||||
En mode inférence (deterministic=True) : argmax des logits.
|
||||
En mode exploration : sampling de la distribution.
|
||||
|
||||
Args:
|
||||
obs: Vecteur d'observation (obs_dim,) en float32
|
||||
|
||||
Returns:
|
||||
Tuple (action, log_prob) où action ∈ {0, 1, 2}
|
||||
"""
|
||||
if not TORCH_AVAILABLE or not self.is_trained or self.network is None:
|
||||
return 0, 0.0
|
||||
|
||||
self.network.eval()
|
||||
with torch.no_grad():
|
||||
obs_t = torch.from_numpy(obs.astype(np.float32)).unsqueeze(0)
|
||||
action, log_prob, _, _ = self.network.get_action_and_value(obs_t)
|
||||
|
||||
return int(action.item()), float(log_prob.item())
|
||||
|
||||
def predict_deterministic(self, obs: np.ndarray) -> int:
|
||||
"""
|
||||
Prédit l'action déterministe (argmax) sans exploration.
|
||||
|
||||
Args:
|
||||
obs: Vecteur d'observation (obs_dim,) en float32
|
||||
|
||||
Returns:
|
||||
Action déterministe (argmax des logits) ∈ {0, 1, 2}
|
||||
"""
|
||||
if not TORCH_AVAILABLE or not self.is_trained or self.network is None:
|
||||
return 0
|
||||
|
||||
self.network.eval()
|
||||
with torch.no_grad():
|
||||
obs_t = torch.from_numpy(obs.astype(np.float32)).unsqueeze(0)
|
||||
logits, _ = self.network(obs_t)
|
||||
action = int(logits.argmax(dim=-1).item())
|
||||
|
||||
return action
|
||||
|
||||
def get_action_probas(self, obs: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Retourne les probabilités de chaque action via softmax des logits.
|
||||
|
||||
Args:
|
||||
obs: Vecteur d'observation (obs_dim,) en float32
|
||||
|
||||
Returns:
|
||||
np.ndarray de forme (3,) : [P(HOLD), P(LONG), P(SHORT)]
|
||||
"""
|
||||
if not TORCH_AVAILABLE or not self.is_trained or self.network is None:
|
||||
return np.array([1.0, 0.0, 0.0], dtype=np.float32)
|
||||
|
||||
self.network.eval()
|
||||
with torch.no_grad():
|
||||
obs_t = torch.from_numpy(obs.astype(np.float32)).unsqueeze(0)
|
||||
logits, _ = self.network(obs_t)
|
||||
probas = torch.softmax(logits, dim=-1).squeeze(0).numpy()
|
||||
|
||||
return probas.astype(np.float32)
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Sauvegarde / Chargement
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def save(self, symbol: str = 'EURUSD', timeframe: str = '1h') -> Path:
|
||||
"""
|
||||
Sauvegarde le modèle et ses métadonnées sur disque.
|
||||
|
||||
Format :
|
||||
{symbol}_{timeframe}.pt — state_dict + config PyTorch
|
||||
{symbol}_{timeframe}_meta.json — métadonnées JSON
|
||||
|
||||
Args:
|
||||
symbol: Paire tradée (ex: 'EURUSD')
|
||||
timeframe: Timeframe (ex: '1h')
|
||||
|
||||
Returns:
|
||||
Path vers le fichier .pt sauvegardé
|
||||
|
||||
Raises:
|
||||
RuntimeError si PyTorch non disponible ou modèle non entraîné
|
||||
"""
|
||||
if not TORCH_AVAILABLE:
|
||||
raise RuntimeError("PyTorch non disponible")
|
||||
if not self.is_trained or self.network is None:
|
||||
raise RuntimeError("Modèle non entraîné — appeler train() avant save()")
|
||||
|
||||
model_id = f"{symbol}_{timeframe}"
|
||||
model_path = self.MODELS_DIR / f"{model_id}.pt"
|
||||
meta_path = self.MODELS_DIR / f"{model_id}_meta.json"
|
||||
|
||||
torch.save(
|
||||
{
|
||||
'state_dict': self.network.state_dict(),
|
||||
'config': {
|
||||
'obs_dim': self.obs_dim,
|
||||
'n_actions': self.n_actions,
|
||||
'lr': self.lr,
|
||||
'n_steps': self.n_steps,
|
||||
'n_epochs': self.n_epochs,
|
||||
'batch_size': self.batch_size,
|
||||
'gamma': self.gamma,
|
||||
'gae_lambda': self.gae_lambda,
|
||||
'clip_eps': self.clip_eps,
|
||||
'entropy_coef': self.entropy_coef,
|
||||
'value_coef': self.value_coef,
|
||||
'max_grad_norm': self.max_grad_norm,
|
||||
},
|
||||
'metadata': self.metadata,
|
||||
},
|
||||
model_path
|
||||
)
|
||||
|
||||
with open(meta_path, 'w') as f:
|
||||
json.dump(self.metadata, f, indent=2, default=str)
|
||||
|
||||
logger.info(f"Modèle PPO sauvegardé : {model_path}")
|
||||
return model_path
|
||||
|
||||
@classmethod
|
||||
def load(cls, symbol: str, timeframe: str) -> 'PPOModel':
|
||||
"""
|
||||
Charge un modèle PPO existant depuis le disque.
|
||||
|
||||
Args:
|
||||
symbol: Paire (ex: 'EURUSD')
|
||||
timeframe: Timeframe (ex: '1h')
|
||||
|
||||
Returns:
|
||||
Instance PPOModel prête à prédire
|
||||
|
||||
Raises:
|
||||
RuntimeError si PyTorch non disponible
|
||||
FileNotFoundError si le modèle n'existe pas
|
||||
"""
|
||||
if not TORCH_AVAILABLE:
|
||||
raise RuntimeError("PyTorch non disponible")
|
||||
|
||||
model_path = MODELS_DIR / f"{symbol}_{timeframe}.pt"
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"Modèle PPO non trouvé : {model_path}")
|
||||
|
||||
checkpoint = torch.load(model_path, map_location='cpu')
|
||||
cfg = checkpoint.get('config', {})
|
||||
|
||||
instance = cls(
|
||||
obs_dim = cfg.get('obs_dim', 20),
|
||||
n_actions = cfg.get('n_actions', 3),
|
||||
lr = cfg.get('lr', 3e-4),
|
||||
n_steps = cfg.get('n_steps', 2048),
|
||||
n_epochs = cfg.get('n_epochs', 10),
|
||||
batch_size = cfg.get('batch_size', 64),
|
||||
gamma = cfg.get('gamma', 0.99),
|
||||
gae_lambda = cfg.get('gae_lambda', 0.95),
|
||||
clip_eps = cfg.get('clip_eps', 0.2),
|
||||
entropy_coef = cfg.get('entropy_coef', 0.01),
|
||||
value_coef = cfg.get('value_coef', 0.5),
|
||||
max_grad_norm = cfg.get('max_grad_norm', 0.5),
|
||||
)
|
||||
|
||||
instance.network.load_state_dict(checkpoint['state_dict'])
|
||||
instance.network.eval()
|
||||
instance.is_trained = True
|
||||
instance.metadata = checkpoint.get('metadata', {})
|
||||
|
||||
logger.info(f"Modèle PPO chargé depuis {model_path}")
|
||||
return instance
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Calcul des avantages (GAE)
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _compute_gae(
|
||||
self,
|
||||
rewards: np.ndarray,
|
||||
values: np.ndarray,
|
||||
dones: np.ndarray,
|
||||
last_value: float,
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Calcule les avantages généralisés (GAE) et les retours cibles.
|
||||
|
||||
Formule GAE :
|
||||
δt = rt + γ·V(st+1)·(1-dt) - V(st)
|
||||
Ât = δt + γλ·Ât+1·(1-dt)
|
||||
Rt = Ât + V(st)
|
||||
|
||||
Args:
|
||||
rewards: Récompenses du rollout (n_steps,)
|
||||
values: Valeurs estimées par le critic (n_steps,)
|
||||
dones: Indicateurs de fin d'épisode (n_steps,)
|
||||
last_value: Valeur bootstrap du dernier état (scalaire)
|
||||
|
||||
Returns:
|
||||
Tuple (advantages, returns) de forme (n_steps,)
|
||||
"""
|
||||
n = len(rewards)
|
||||
advantages = np.zeros(n, dtype=np.float32)
|
||||
last_gae = 0.0
|
||||
|
||||
for t in reversed(range(n)):
|
||||
if t == n - 1:
|
||||
next_non_terminal = 1.0 - dones[t]
|
||||
next_value = last_value
|
||||
else:
|
||||
next_non_terminal = 1.0 - dones[t]
|
||||
next_value = values[t + 1]
|
||||
|
||||
delta = rewards[t] + self.gamma * next_value * next_non_terminal - values[t]
|
||||
last_gae = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae
|
||||
advantages[t] = last_gae
|
||||
|
||||
returns = advantages + values
|
||||
return advantages, returns
|
||||
557
src/ml/rl/rl_strategy_model.py
Normal file
557
src/ml/rl/rl_strategy_model.py
Normal file
@@ -0,0 +1,557 @@
|
||||
"""
|
||||
RLStrategyModel — Interface unifiée pour l'agent PPO de trading.
|
||||
|
||||
Interface identique à MLStrategyModel et CNNImageStrategyModel :
|
||||
model = RLStrategyModel(symbol='EURUSD', timeframe='1h')
|
||||
result = model.train(df_ohlcv)
|
||||
signal = model.predict(df) # {signal, confidence, probas, tradeable}
|
||||
model.save()
|
||||
model.load(symbol, timeframe)
|
||||
model.list_trained_models()
|
||||
model.get_feature_importance() # retourne [] (RL = boîte noire)
|
||||
|
||||
Pipeline d'entraînement :
|
||||
1. Validation des données OHLCV (minimum 500 barres recommandées)
|
||||
2. Création de TradingEnv sur les données d'entraînement
|
||||
3. Entraînement PPO (total_timesteps configurable)
|
||||
4. Évaluation sur holdout (20% temporel final)
|
||||
5. Sauvegarde du modèle PPO + métadonnées JSON
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from src.ml.rl.trading_env import TradingEnv, GYM_AVAILABLE, N_FEATURES
|
||||
from src.ml.rl.ppo_model import PPOModel, TORCH_AVAILABLE, MODELS_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Décodage des actions RL en signaux de trading
|
||||
# Action 0 (HOLD) → signal 0 (NEUTRAL)
|
||||
# Action 1 (LONG) → signal 1 (LONG)
|
||||
# Action 2 (SHORT) → signal -1 (SHORT)
|
||||
ACTION_TO_SIGNAL = {0: 0, 1: 1, 2: -1}
|
||||
|
||||
# Disponibilité globale du module RL
|
||||
RL_AVAILABLE = TORCH_AVAILABLE and GYM_AVAILABLE
|
||||
|
||||
|
||||
class RLStrategyModel:
|
||||
"""
|
||||
Modèle de trading basé sur un agent PPO (Reinforcement Learning).
|
||||
|
||||
L'agent apprend directement à maximiser le PnL via interaction avec
|
||||
un environnement de trading simulé (TradingEnv), sans supervision
|
||||
explicite de labels.
|
||||
|
||||
Avantages par rapport aux modèles supervisés :
|
||||
- Pas besoin de labels manuels (LONG/SHORT/NEUTRAL)
|
||||
- Optimise directement l'objectif de trading (PnL)
|
||||
- Apprend à gérer les positions (durée, timing de sortie)
|
||||
|
||||
Args:
|
||||
symbol: Paire tradée (ex: 'EURUSD')
|
||||
timeframe: Timeframe (ex: '1h', '15m')
|
||||
total_timesteps: Nombre de timesteps d'entraînement PPO (défaut: 100_000)
|
||||
sl_atr_mult: Multiplicateur ATR pour le stop-loss (défaut: 1.0)
|
||||
tp_atr_mult: Multiplicateur ATR pour le take-profit (défaut: 2.0)
|
||||
min_confidence: Seuil de confiance minimum pour signaux (défaut: 0.50)
|
||||
train_ratio: Fraction des données pour l'entraînement (défaut: 0.80)
|
||||
initial_capital: Capital initial pour la simulation (défaut: 10_000)
|
||||
lr: Taux d'apprentissage Adam (défaut: 3e-4)
|
||||
n_steps: Transitions par rollout (défaut: 2048)
|
||||
n_epochs: Passes d'optimisation par rollout (défaut: 10)
|
||||
batch_size: Taille des mini-batchs (défaut: 64)
|
||||
"""
|
||||
|
||||
MODELS_DIR = MODELS_DIR
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
symbol: str = 'EURUSD',
|
||||
timeframe: str = '1h',
|
||||
total_timesteps: int = 100_000,
|
||||
sl_atr_mult: float = 1.0,
|
||||
tp_atr_mult: float = 2.0,
|
||||
min_confidence: float = 0.50,
|
||||
train_ratio: float = 0.80,
|
||||
initial_capital: float = 10_000.0,
|
||||
lr: float = 3e-4,
|
||||
n_steps: int = 2048,
|
||||
n_epochs: int = 10,
|
||||
batch_size: int = 64,
|
||||
):
|
||||
self.symbol = symbol
|
||||
self.timeframe = timeframe
|
||||
self.total_timesteps = total_timesteps
|
||||
self.sl_atr_mult = sl_atr_mult
|
||||
self.tp_atr_mult = tp_atr_mult
|
||||
self.min_confidence = min_confidence
|
||||
self.train_ratio = train_ratio
|
||||
self.initial_capital = initial_capital
|
||||
self.lr = lr
|
||||
self.n_steps = n_steps
|
||||
self.n_epochs = n_epochs
|
||||
self.batch_size = batch_size
|
||||
|
||||
self.ppo_model: Optional[PPOModel] = None
|
||||
self.is_trained = False
|
||||
self.metadata: Dict = {}
|
||||
|
||||
self.MODELS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Entraînement
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def train(self, data: pd.DataFrame) -> Dict:
|
||||
"""
|
||||
Entraîne l'agent PPO sur les données OHLCV fournies.
|
||||
|
||||
Étapes :
|
||||
1. Validation des données (minimum 200 barres)
|
||||
2. Split temporel train/holdout (80/20 par défaut)
|
||||
3. Création du TradingEnv sur le jeu d'entraînement
|
||||
4. Entraînement PPO (total_timesteps)
|
||||
5. Évaluation sur le jeu de holdout
|
||||
6. Sauvegarde automatique du modèle
|
||||
|
||||
Args:
|
||||
data: DataFrame OHLCV (colonnes : open/high/low/close/volume)
|
||||
Minimum 200 barres, idéalement 2000+ pour un bon apprentissage.
|
||||
|
||||
Returns:
|
||||
Dict avec métriques :
|
||||
- total_timesteps, n_episodes, mean_ep_return (entraînement)
|
||||
- eval_total_pnl, eval_return_pct, eval_sharpe (évaluation holdout)
|
||||
- n_samples, trained_at, error (si échec)
|
||||
"""
|
||||
if not TORCH_AVAILABLE:
|
||||
return {'error': 'PyTorch non disponible — installer torch>=2.0.0'}
|
||||
if not GYM_AVAILABLE:
|
||||
return {'error': 'gymnasium non disponible — installer gymnasium>=0.26'}
|
||||
|
||||
logger.info(
|
||||
f"Début entraînement RLStrategyModel pour "
|
||||
f"{self.symbol}/{self.timeframe} — total_timesteps={self.total_timesteps}"
|
||||
)
|
||||
|
||||
# 1. Validation et normalisation des données
|
||||
df = data.copy()
|
||||
df.columns = [c.lower() for c in df.columns]
|
||||
|
||||
required_cols = {'open', 'high', 'low', 'close', 'volume'}
|
||||
missing = required_cols - set(df.columns)
|
||||
if missing:
|
||||
return {'error': f'Colonnes manquantes : {missing}'}
|
||||
|
||||
df = df.dropna(subset=list(required_cols))
|
||||
|
||||
if len(df) < 200:
|
||||
return {
|
||||
'error': f'Données insuffisantes : {len(df)} barres (minimum 200)'
|
||||
}
|
||||
|
||||
# 2. Split temporel train / holdout (respecte l'ordre chronologique)
|
||||
n_train = int(len(df) * self.train_ratio)
|
||||
df_train = df.iloc[:n_train].reset_index(drop=True)
|
||||
df_holdout = df.iloc[n_train:].reset_index(drop=True)
|
||||
|
||||
logger.info(
|
||||
f" Split : train={len(df_train)} barres, holdout={len(df_holdout)} barres"
|
||||
)
|
||||
|
||||
if len(df_train) < 100:
|
||||
return {'error': f'Données d\'entraînement insuffisantes : {len(df_train)} barres'}
|
||||
|
||||
# 3. Création de l'environnement d'entraînement
|
||||
train_env = TradingEnv(
|
||||
df = df_train,
|
||||
sl_atr_mult = self.sl_atr_mult,
|
||||
tp_atr_mult = self.tp_atr_mult,
|
||||
initial_capital = self.initial_capital,
|
||||
)
|
||||
|
||||
# 4. Initialisation et entraînement PPO
|
||||
self.ppo_model = PPOModel(
|
||||
obs_dim = N_FEATURES,
|
||||
n_actions = 3,
|
||||
lr = self.lr,
|
||||
n_steps = min(self.n_steps, len(df_train) - 30), # Sécurité
|
||||
n_epochs = self.n_epochs,
|
||||
batch_size = self.batch_size,
|
||||
)
|
||||
|
||||
train_metrics = self.ppo_model.train(
|
||||
env = train_env,
|
||||
total_timesteps = self.total_timesteps,
|
||||
symbol = self.symbol,
|
||||
timeframe = self.timeframe,
|
||||
)
|
||||
|
||||
if 'error' in train_metrics:
|
||||
return train_metrics
|
||||
|
||||
self.is_trained = True
|
||||
|
||||
# 5. Évaluation sur holdout
|
||||
eval_metrics = {}
|
||||
if len(df_holdout) >= 50:
|
||||
eval_metrics = self._evaluate_on_holdout(df_holdout)
|
||||
logger.info(
|
||||
f" Évaluation holdout : "
|
||||
f"PnL={eval_metrics.get('eval_total_pnl', 0):.4f}, "
|
||||
f"Return={eval_metrics.get('eval_return_pct', 0):.2%}, "
|
||||
f"Sharpe≈{eval_metrics.get('eval_sharpe', 0):.2f}"
|
||||
)
|
||||
|
||||
# 6. Assemblage des métadonnées
|
||||
self.metadata = {
|
||||
'symbol': self.symbol,
|
||||
'timeframe': self.timeframe,
|
||||
'trained_at': datetime.utcnow().isoformat(),
|
||||
'n_samples': len(df),
|
||||
'n_train': len(df_train),
|
||||
'n_holdout': len(df_holdout),
|
||||
'total_timesteps': train_metrics.get('total_timesteps', self.total_timesteps),
|
||||
'n_episodes': train_metrics.get('n_episodes', 0),
|
||||
'n_updates': train_metrics.get('n_updates', 0),
|
||||
'mean_ep_return': train_metrics.get('mean_ep_return', 0.0),
|
||||
'policy_loss': train_metrics.get('policy_loss', 0.0),
|
||||
'value_loss': train_metrics.get('value_loss', 0.0),
|
||||
'entropy_loss': train_metrics.get('entropy_loss', 0.0),
|
||||
'sl_atr_mult': self.sl_atr_mult,
|
||||
'tp_atr_mult': self.tp_atr_mult,
|
||||
'min_confidence': self.min_confidence,
|
||||
'hyperparams': train_metrics.get('hyperparams', {}),
|
||||
**eval_metrics,
|
||||
}
|
||||
|
||||
# 7. Sauvegarde automatique
|
||||
self.save()
|
||||
|
||||
logger.info(
|
||||
f"Entraînement RL terminé pour {self.symbol}/{self.timeframe} — "
|
||||
f"N_episodes={self.metadata['n_episodes']}, "
|
||||
f"MeanReturn={self.metadata['mean_ep_return']:.4f}"
|
||||
)
|
||||
return self.metadata
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Prédiction
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def predict(self, data: pd.DataFrame) -> Dict:
|
||||
"""
|
||||
Prédit le signal de trading pour la dernière barre disponible.
|
||||
|
||||
L'agent évalue les 20 dernières barres (fenêtre lookback) et retourne
|
||||
son action déterministe (HOLD/LONG/SHORT) avec les probabilités associées.
|
||||
|
||||
Args:
|
||||
data: DataFrame OHLCV récent (minimum 25 barres pour les indicateurs)
|
||||
|
||||
Returns:
|
||||
Dict : {
|
||||
'signal': 1 (LONG) / -1 (SHORT) / 0 (NEUTRAL/HOLD),
|
||||
'confidence': float [0..1] — max(P_LONG, P_SHORT),
|
||||
'probas': {'hold': float, 'long': float, 'short': float},
|
||||
'tradeable': bool — confidence >= min_confidence et signal != 0,
|
||||
}
|
||||
"""
|
||||
if not TORCH_AVAILABLE:
|
||||
return {
|
||||
'signal': 0, 'confidence': 0.0, 'tradeable': False,
|
||||
'error': 'PyTorch non disponible'
|
||||
}
|
||||
if not self.is_trained or self.ppo_model is None:
|
||||
return {
|
||||
'signal': 0, 'confidence': 0.0, 'tradeable': False,
|
||||
'error': 'Modèle non entraîné — appeler train() d\'abord'
|
||||
}
|
||||
|
||||
df = data.copy()
|
||||
df.columns = [c.lower() for c in df.columns]
|
||||
|
||||
if len(df) < 25:
|
||||
return {
|
||||
'signal': 0, 'confidence': 0.0, 'tradeable': False,
|
||||
'error': f'Pas assez de données : {len(df)} barres (minimum 25)'
|
||||
}
|
||||
|
||||
try:
|
||||
# Créer un environnement temporaire pour obtenir l'observation
|
||||
temp_env = TradingEnv(
|
||||
df = df,
|
||||
sl_atr_mult = self.sl_atr_mult,
|
||||
tp_atr_mult = self.tp_atr_mult,
|
||||
initial_capital = self.initial_capital,
|
||||
)
|
||||
|
||||
# Positionner l'environnement à la dernière barre
|
||||
temp_env._current_step = len(df) - 1
|
||||
obs = temp_env._get_observation()
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
'signal': 0, 'confidence': 0.0, 'tradeable': False,
|
||||
'error': f'Erreur construction observation : {e}'
|
||||
}
|
||||
|
||||
# Prédiction déterministe (argmax des logits)
|
||||
action = self.ppo_model.predict_deterministic(obs)
|
||||
probas = self.ppo_model.get_action_probas(obs) # [P(HOLD), P(LONG), P(SHORT)]
|
||||
|
||||
signal = ACTION_TO_SIGNAL.get(action, 0)
|
||||
confidence = float(max(probas[1], probas[2])) # max(P_LONG, P_SHORT)
|
||||
|
||||
return {
|
||||
'signal': signal,
|
||||
'confidence': confidence,
|
||||
'probas': {
|
||||
'hold': float(probas[0]),
|
||||
'long': float(probas[1]),
|
||||
'short': float(probas[2]),
|
||||
},
|
||||
'tradeable': confidence >= self.min_confidence and signal != 0,
|
||||
}
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Sauvegarde / Chargement
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def save(self) -> Path:
|
||||
"""
|
||||
Sauvegarde le modèle PPO et les métadonnées sur disque.
|
||||
|
||||
Returns:
|
||||
Path vers le fichier .pt sauvegardé
|
||||
|
||||
Raises:
|
||||
RuntimeError si le modèle n'est pas entraîné
|
||||
"""
|
||||
if not self.is_trained or self.ppo_model is None:
|
||||
raise RuntimeError("Modèle non entraîné — appeler train() avant save()")
|
||||
|
||||
# Sauvegarde du réseau PPO
|
||||
model_path = self.ppo_model.save(self.symbol, self.timeframe)
|
||||
|
||||
# Sauvegarde des métadonnées complètes (incluant les params RLStrategyModel)
|
||||
meta = self.metadata.copy()
|
||||
meta.update({
|
||||
'rl_config': {
|
||||
'sl_atr_mult': self.sl_atr_mult,
|
||||
'tp_atr_mult': self.tp_atr_mult,
|
||||
'min_confidence': self.min_confidence,
|
||||
'train_ratio': self.train_ratio,
|
||||
'initial_capital': self.initial_capital,
|
||||
'total_timesteps': self.total_timesteps,
|
||||
}
|
||||
})
|
||||
|
||||
meta_path = self.MODELS_DIR / f"{self.symbol}_{self.timeframe}_meta.json"
|
||||
with open(meta_path, 'w') as f:
|
||||
json.dump(meta, f, indent=2, default=str)
|
||||
|
||||
return model_path
|
||||
|
||||
@classmethod
|
||||
def load(cls, symbol: str, timeframe: str) -> 'RLStrategyModel':
|
||||
"""
|
||||
Charge un modèle RL existant depuis le disque.
|
||||
|
||||
Args:
|
||||
symbol: Paire (ex: 'EURUSD')
|
||||
timeframe: Timeframe (ex: '1h')
|
||||
|
||||
Returns:
|
||||
Instance RLStrategyModel prête à prédire
|
||||
|
||||
Raises:
|
||||
RuntimeError si PyTorch ou gymnasium non disponible
|
||||
FileNotFoundError si le modèle n'existe pas
|
||||
"""
|
||||
if not TORCH_AVAILABLE:
|
||||
raise RuntimeError("PyTorch non disponible")
|
||||
|
||||
meta_path = MODELS_DIR / f"{symbol}_{timeframe}_meta.json"
|
||||
rl_config = {}
|
||||
if meta_path.exists():
|
||||
try:
|
||||
with open(meta_path) as f:
|
||||
saved_meta = json.load(f)
|
||||
rl_config = saved_meta.get('rl_config', {})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
instance = cls(
|
||||
symbol = symbol,
|
||||
timeframe = timeframe,
|
||||
sl_atr_mult = rl_config.get('sl_atr_mult', 1.0),
|
||||
tp_atr_mult = rl_config.get('tp_atr_mult', 2.0),
|
||||
min_confidence = rl_config.get('min_confidence', 0.50),
|
||||
train_ratio = rl_config.get('train_ratio', 0.80),
|
||||
initial_capital = rl_config.get('initial_capital', 10_000.0),
|
||||
total_timesteps = rl_config.get('total_timesteps', 100_000),
|
||||
)
|
||||
|
||||
# Chargement du modèle PPO
|
||||
instance.ppo_model = PPOModel.load(symbol, timeframe)
|
||||
instance.is_trained = True
|
||||
|
||||
if meta_path.exists():
|
||||
try:
|
||||
with open(meta_path) as f:
|
||||
instance.metadata = json.load(f)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info(f"Modèle RL chargé depuis {MODELS_DIR}/{symbol}_{timeframe}.pt")
|
||||
return instance
|
||||
|
||||
@staticmethod
|
||||
def list_trained_models() -> List[Dict]:
|
||||
"""
|
||||
Retourne la liste des modèles RL entraînés disponibles.
|
||||
|
||||
Returns:
|
||||
Liste de dicts avec symbol, timeframe, trained_at, n_samples,
|
||||
mean_ep_return, eval_return_pct
|
||||
"""
|
||||
if not MODELS_DIR.exists():
|
||||
return []
|
||||
|
||||
models = []
|
||||
for f in MODELS_DIR.glob("*_meta.json"):
|
||||
try:
|
||||
with open(f) as fp:
|
||||
meta = json.load(fp)
|
||||
stem = f.stem.replace('_meta', '') # ex: EURUSD_1h
|
||||
parts = stem.split('_', 1)
|
||||
models.append({
|
||||
'symbol': meta.get('symbol', parts[0] if parts else '?'),
|
||||
'timeframe': meta.get('timeframe', parts[1] if len(parts) > 1 else '?'),
|
||||
'trained_at': meta.get('trained_at', '?'),
|
||||
'n_samples': meta.get('n_samples', 0),
|
||||
'total_timesteps': meta.get('total_timesteps', 0),
|
||||
'n_episodes': meta.get('n_episodes', 0),
|
||||
'mean_ep_return': meta.get('mean_ep_return', 0.0),
|
||||
'eval_return_pct': meta.get('eval_return_pct', 0.0),
|
||||
'eval_sharpe': meta.get('eval_sharpe', 0.0),
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return models
|
||||
|
||||
def get_feature_importance(self) -> List[Dict]:
|
||||
"""
|
||||
Retourne l'importance des features.
|
||||
|
||||
Note : Les agents RL (PPO) sont des boîtes noires — pas d'importance
|
||||
de feature interprétable. Retourne les noms des 20 features pour
|
||||
information, sans score numérique.
|
||||
|
||||
Returns:
|
||||
Liste de dicts avec feature_name et description
|
||||
"""
|
||||
feature_names = [
|
||||
{'rank': 1, 'feature': 'body_ratio', 'description': '(close-open)/ATR — corps de la bougie'},
|
||||
{'rank': 2, 'feature': 'amplitude', 'description': '(high-low)/ATR — amplitude de la bougie'},
|
||||
{'rank': 3, 'feature': 'momentum_1', 'description': '(close-close[-1])/ATR — momentum 1 barre'},
|
||||
{'rank': 4, 'feature': 'momentum_5', 'description': '(close-close[-5])/ATR — momentum 5 barres'},
|
||||
{'rank': 5, 'feature': 'rsi_norm', 'description': 'RSI(14)/100 — RSI normalisé'},
|
||||
{'rank': 6, 'feature': 'bb_position', 'description': '(close-BB_upper)/ATR — position vs Bollinger'},
|
||||
{'rank': 7, 'feature': 'bb_width', 'description': '(BB_upper-BB_lower)/ATR — largeur bandes'},
|
||||
{'rank': 8, 'feature': 'macd_norm', 'description': 'MACD/ATR — MACD normalisé'},
|
||||
{'rank': 9, 'feature': 'macd_signal_norm', 'description': 'Signal MACD/ATR — signal MACD normalisé'},
|
||||
{'rank': 10, 'feature': 'vol_ratio', 'description': 'Volume/MA20(Volume) — ratio volume'},
|
||||
{'rank': 11, 'feature': 'position', 'description': 'Position courante (-1=short, 0=flat, 1=long)'},
|
||||
{'rank': 12, 'feature': 'unrealized_pnl', 'description': 'PnL_non_réalisé/ATR — gain/perte courant'},
|
||||
{'rank': 13, 'feature': 'bars_in_trade', 'description': 'Durée du trade normalisée'},
|
||||
{'rank': 14, 'feature': 'hour', 'description': 'Heure de la journée normalisée'},
|
||||
{'rank': 15, 'feature': 'day_of_week', 'description': 'Jour de semaine normalisé'},
|
||||
{'rank': 16, 'feature': 'ema9_dist', 'description': '(close-EMA9)/ATR — distance EMA9'},
|
||||
{'rank': 17, 'feature': 'ema21_dist', 'description': '(close-EMA21)/ATR — distance EMA21'},
|
||||
{'rank': 18, 'feature': 'ema9_ema21_cross', 'description': '(EMA9-EMA21)/ATR — croisement EMA court'},
|
||||
{'rank': 19, 'feature': 'ema50_dist', 'description': '(close-EMA50)/ATR — distance EMA50'},
|
||||
{'rank': 20, 'feature': 'ema200_dist', 'description': '(close-EMA200)/ATR — distance EMA200'},
|
||||
]
|
||||
return feature_names
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Évaluation interne
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _evaluate_on_holdout(self, df_holdout: pd.DataFrame) -> Dict:
|
||||
"""
|
||||
Évalue l'agent entraîné sur le jeu de holdout.
|
||||
|
||||
Lance un épisode complet sur les données de holdout en mode déterministe
|
||||
(argmax des logits, pas de sampling). Retourne les métriques de performance.
|
||||
|
||||
Args:
|
||||
df_holdout: DataFrame OHLCV du jeu de holdout
|
||||
|
||||
Returns:
|
||||
Dict avec eval_total_pnl, eval_return_pct, eval_sharpe,
|
||||
eval_n_trades, eval_win_rate
|
||||
"""
|
||||
try:
|
||||
eval_env = TradingEnv(
|
||||
df = df_holdout,
|
||||
sl_atr_mult = self.sl_atr_mult,
|
||||
tp_atr_mult = self.tp_atr_mult,
|
||||
initial_capital = self.initial_capital,
|
||||
)
|
||||
|
||||
obs, _ = eval_env.reset()
|
||||
done = False
|
||||
rewards_hist = []
|
||||
n_trades = 0
|
||||
win_trades = 0
|
||||
|
||||
while not done:
|
||||
action = self.ppo_model.predict_deterministic(obs)
|
||||
obs, reward, terminated, truncated, info = eval_env.step(action)
|
||||
done = terminated or truncated
|
||||
|
||||
if abs(reward) > 0.001: # Fermeture de position
|
||||
n_trades += 1
|
||||
if reward > 0:
|
||||
win_trades += 1
|
||||
|
||||
rewards_hist.append(reward)
|
||||
|
||||
stats = eval_env.get_episode_stats()
|
||||
win_rate = win_trades / max(n_trades, 1)
|
||||
rewards_arr = np.array(rewards_hist)
|
||||
sharpe = (
|
||||
float(rewards_arr.mean() / (rewards_arr.std() + 1e-8) * np.sqrt(252))
|
||||
if len(rewards_arr) > 1 else 0.0
|
||||
)
|
||||
|
||||
return {
|
||||
'eval_total_pnl': float(stats.get('total_pnl', 0.0)),
|
||||
'eval_return_pct': float(stats.get('return_pct', 0.0)),
|
||||
'eval_sharpe': float(sharpe),
|
||||
'eval_n_trades': n_trades,
|
||||
'eval_win_rate': float(win_rate),
|
||||
'eval_n_steps': int(stats.get('n_steps', 0)),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Évaluation holdout échouée : {e}")
|
||||
return {
|
||||
'eval_total_pnl': 0.0,
|
||||
'eval_return_pct': 0.0,
|
||||
'eval_sharpe': 0.0,
|
||||
'eval_n_trades': 0,
|
||||
'eval_win_rate': 0.0,
|
||||
}
|
||||
524
src/ml/rl/trading_env.py
Normal file
524
src/ml/rl/trading_env.py
Normal file
@@ -0,0 +1,524 @@
|
||||
"""
|
||||
TradingEnv — Environnement Gymnasium pour l'agent RL de trading.
|
||||
|
||||
Conforme à l'API gymnasium (reset/step) et adapté aux marchés forex.
|
||||
|
||||
Espace d'observation : 20 features normalisées (fenêtre glissante de 20 barres)
|
||||
Espace d'action : Discrete(3) → {0=HOLD, 1=LONG, 2=SHORT}
|
||||
Récompense : PnL réalisé + pénalité drawdown + bonus Sharpe
|
||||
|
||||
L'environnement gère :
|
||||
- Les transitions de position (flat → long, flat → short, inversions)
|
||||
- Le stop-loss / take-profit automatiques basés sur l'ATR
|
||||
- La normalisation des observations par l'ATR
|
||||
- La fenêtre glissante de 20 barres (lookback)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Imports conditionnels gymnasium / gym
|
||||
try:
|
||||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
GYM_AVAILABLE = True
|
||||
GYM_MODULE = 'gymnasium'
|
||||
except ImportError:
|
||||
try:
|
||||
import gym
|
||||
from gym import spaces
|
||||
GYM_AVAILABLE = True
|
||||
GYM_MODULE = 'gym'
|
||||
except ImportError:
|
||||
gym = None
|
||||
spaces = None
|
||||
GYM_AVAILABLE = False
|
||||
GYM_MODULE = None
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Actions
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
ACTION_HOLD = 0
|
||||
ACTION_LONG = 1
|
||||
ACTION_SHORT = 2
|
||||
|
||||
# Nombre de features dans le vecteur d'observation
|
||||
N_FEATURES = 20
|
||||
|
||||
# Fenêtre lookback (nombre de barres passées incluses dans l'observation)
|
||||
LOOKBACK = 20
|
||||
|
||||
|
||||
class TradingEnv:
|
||||
"""
|
||||
Environnement de trading conforme à l'API gymnasium pour l'agent PPO.
|
||||
|
||||
Chaque step() représente la décision de trading à la fermeture d'une bougie.
|
||||
L'agent reçoit 20 features normalisées décrivant le contexte de marché et
|
||||
sa position courante, puis choisit HOLD / LONG / SHORT.
|
||||
|
||||
Gestion des positions :
|
||||
- Une seule position à la fois (pas de pyramiding)
|
||||
- Inversion directe possible (SHORT → LONG sans passer par FLAT)
|
||||
- SL/TP ATR-based (sl_atr_mult×ATR et tp_atr_mult×ATR)
|
||||
|
||||
Récompense :
|
||||
- PnL réalisé à la clôture de position (en multiple d'ATR)
|
||||
- Pénalité proportionnelle au drawdown courant
|
||||
- Pas de bonus pour maintenir une position (évite le surtrading)
|
||||
|
||||
Args:
|
||||
df: DataFrame OHLCV (colonnes minuscules : open/high/low/close/volume)
|
||||
sl_atr_mult: Multiplicateur ATR pour le stop-loss (défaut: 1.0)
|
||||
tp_atr_mult: Multiplicateur ATR pour le take-profit (défaut: 2.0)
|
||||
atr_period: Période de calcul de l'ATR (défaut: 14)
|
||||
initial_capital: Capital initial (pour calcul drawdown) (défaut: 10000)
|
||||
drawdown_penalty: Coefficient de pénalité drawdown (défaut: 0.1)
|
||||
"""
|
||||
|
||||
metadata = {'render_modes': []}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
df: pd.DataFrame,
|
||||
sl_atr_mult: float = 1.0,
|
||||
tp_atr_mult: float = 2.0,
|
||||
atr_period: int = 14,
|
||||
initial_capital: float = 10_000.0,
|
||||
drawdown_penalty: float = 0.1,
|
||||
):
|
||||
if not GYM_AVAILABLE:
|
||||
raise RuntimeError(
|
||||
"gymnasium ou gym requis — installer gymnasium>=0.26 ou gym>=0.21"
|
||||
)
|
||||
|
||||
self.df = df.copy().reset_index(drop=True)
|
||||
self.df.columns = [c.lower() for c in self.df.columns]
|
||||
self.sl_atr_mult = sl_atr_mult
|
||||
self.tp_atr_mult = tp_atr_mult
|
||||
self.atr_period = atr_period
|
||||
self.initial_capital = initial_capital
|
||||
self.drawdown_penalty = drawdown_penalty
|
||||
|
||||
# Calcul de l'ATR sur toute la série (optimisation : évite le recalcul dans step)
|
||||
self._atr_series = self._compute_atr_series()
|
||||
|
||||
# Calcul des EMAs sur toute la série
|
||||
self._ema9 = self.df['close'].ewm(span=9, adjust=False).mean()
|
||||
self._ema21 = self.df['close'].ewm(span=21, adjust=False).mean()
|
||||
self._ema50 = self.df['close'].ewm(span=50, adjust=False).mean()
|
||||
self._ema200 = self.df['close'].ewm(span=200, adjust=False).mean()
|
||||
|
||||
# Calcul des Bandes de Bollinger (20, 2σ)
|
||||
bb_ma = self.df['close'].rolling(20).mean()
|
||||
bb_std = self.df['close'].rolling(20).std()
|
||||
self._bb_upper = bb_ma + 2 * bb_std
|
||||
self._bb_lower = bb_ma - 2 * bb_std
|
||||
|
||||
# Calcul du RSI(14)
|
||||
self._rsi = self._compute_rsi(period=14)
|
||||
|
||||
# Calcul du MACD (12, 26, 9)
|
||||
ema12 = self.df['close'].ewm(span=12, adjust=False).mean()
|
||||
ema26 = self.df['close'].ewm(span=26, adjust=False).mean()
|
||||
self._macd = ema12 - ema26
|
||||
self._macd_sig = self._macd.ewm(span=9, adjust=False).mean()
|
||||
|
||||
# Volume MA(20) pour normalisation
|
||||
self._vol_ma20 = self.df['volume'].rolling(20).mean().fillna(
|
||||
self.df['volume'].mean()
|
||||
)
|
||||
|
||||
# Espaces gymnasium
|
||||
self.observation_space = spaces.Box(
|
||||
low = -10.0,
|
||||
high = 10.0,
|
||||
shape = (N_FEATURES,),
|
||||
dtype = np.float32,
|
||||
)
|
||||
self.action_space = spaces.Discrete(3)
|
||||
|
||||
# État interne (initialisé dans reset())
|
||||
self._current_step = LOOKBACK
|
||||
self._position = 0 # -1=short, 0=flat, 1=long
|
||||
self._entry_price = 0.0
|
||||
self._entry_atr = 0.0
|
||||
self._bars_in_trade = 0
|
||||
self._capital = initial_capital
|
||||
self._peak_capital = initial_capital
|
||||
self._total_pnl = 0.0
|
||||
self._pnl_history = [] # Pour calcul Sharpe en fin d'épisode
|
||||
self._done = False
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Interface gymnasium
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
"""
|
||||
Réinitialise l'environnement au début d'un épisode.
|
||||
|
||||
Args:
|
||||
seed: Graine aléatoire (ignorée ici, données déterministes)
|
||||
options: Options additionnelles (non utilisées)
|
||||
|
||||
Returns:
|
||||
Tuple (observation, info) conforme gymnasium
|
||||
"""
|
||||
if seed is not None:
|
||||
np.random.seed(seed)
|
||||
|
||||
self._current_step = LOOKBACK
|
||||
self._position = 0
|
||||
self._entry_price = 0.0
|
||||
self._entry_atr = 0.0
|
||||
self._bars_in_trade = 0
|
||||
self._capital = self.initial_capital
|
||||
self._peak_capital = self.initial_capital
|
||||
self._total_pnl = 0.0
|
||||
self._pnl_history = []
|
||||
self._done = False
|
||||
|
||||
obs = self._get_observation()
|
||||
info = {}
|
||||
return obs, info
|
||||
|
||||
def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, dict]:
|
||||
"""
|
||||
Exécute une action et retourne la transition.
|
||||
|
||||
Args:
|
||||
action: 0=HOLD, 1=LONG, 2=SHORT
|
||||
|
||||
Returns:
|
||||
Tuple (observation, reward, terminated, truncated, info) conforme gymnasium
|
||||
"""
|
||||
if self._done:
|
||||
obs, info = self.reset()
|
||||
return obs, 0.0, True, False, info
|
||||
|
||||
reward = 0.0
|
||||
info = {}
|
||||
|
||||
current_close = float(self.df['close'].iloc[self._current_step])
|
||||
current_atr = float(self._atr_series.iloc[self._current_step])
|
||||
if current_atr <= 0:
|
||||
current_atr = float(self.df['close'].iloc[self._current_step]) * 0.001
|
||||
|
||||
# ── Vérification SL/TP avant d'appliquer la nouvelle action ──────────
|
||||
if self._position != 0:
|
||||
reward += self._check_sl_tp(current_close, current_atr)
|
||||
|
||||
# ── Application de la nouvelle action ─────────────────────────────────
|
||||
desired_position = self._action_to_position(action)
|
||||
|
||||
if desired_position != self._position:
|
||||
# Fermeture de la position courante (si ouverte)
|
||||
if self._position != 0:
|
||||
close_reward = self._close_position(current_close, current_atr)
|
||||
reward += close_reward
|
||||
|
||||
# Ouverture d'une nouvelle position (si pas HOLD)
|
||||
if desired_position != 0:
|
||||
self._open_position(desired_position, current_close, current_atr)
|
||||
else:
|
||||
# Même position : incrémenter le compteur de barres
|
||||
if self._position != 0:
|
||||
self._bars_in_trade += 1
|
||||
|
||||
# ── Pénalité drawdown ─────────────────────────────────────────────────
|
||||
if self._capital > self._peak_capital:
|
||||
self._peak_capital = self._capital
|
||||
drawdown = (self._peak_capital - self._capital) / max(self._peak_capital, 1.0)
|
||||
if drawdown > 0:
|
||||
reward -= self.drawdown_penalty * drawdown
|
||||
|
||||
# ── Avance d'une barre ────────────────────────────────────────────────
|
||||
self._current_step += 1
|
||||
terminated = self._current_step >= len(self.df) - 1
|
||||
truncated = False
|
||||
|
||||
if terminated:
|
||||
# Fermeture forcée de la position à la fin de l'épisode
|
||||
if self._position != 0:
|
||||
final_close = float(self.df['close'].iloc[-1])
|
||||
final_atr = float(self._atr_series.iloc[-1])
|
||||
reward += self._close_position(final_close, final_atr)
|
||||
self._done = True
|
||||
|
||||
obs = self._get_observation()
|
||||
|
||||
# Sauvegarde de la récompense pour calcul Sharpe
|
||||
self._pnl_history.append(reward)
|
||||
|
||||
info = {
|
||||
'position': self._position,
|
||||
'capital': self._capital,
|
||||
'drawdown': drawdown,
|
||||
'bars_in_trade': self._bars_in_trade,
|
||||
'step': self._current_step,
|
||||
}
|
||||
|
||||
return obs, float(reward), terminated, truncated, info
|
||||
|
||||
def render(self):
|
||||
"""Affichage minimal de l'état courant."""
|
||||
pos_str = {0: 'FLAT', 1: 'LONG', -1: 'SHORT'}.get(self._position, '?')
|
||||
logger.debug(
|
||||
f"Step={self._current_step} | Pos={pos_str} | "
|
||||
f"Capital={self._capital:.2f} | PnL={self._total_pnl:.4f}"
|
||||
)
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Observation
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _get_observation(self) -> np.ndarray:
|
||||
"""
|
||||
Construit le vecteur d'observation de 20 features normalisées.
|
||||
|
||||
Toutes les features de prix sont normalisées par l'ATR courant pour
|
||||
rendre l'observation invariante à l'échelle du prix.
|
||||
|
||||
Returns:
|
||||
np.ndarray de forme (20,) avec dtype float32
|
||||
"""
|
||||
i = self._current_step
|
||||
# Sécurité : ne pas dépasser les bornes
|
||||
i = max(LOOKBACK, min(i, len(self.df) - 1))
|
||||
|
||||
close = float(self.df['close'].iloc[i])
|
||||
open_ = float(self.df['open'].iloc[i])
|
||||
high = float(self.df['high'].iloc[i])
|
||||
low = float(self.df['low'].iloc[i])
|
||||
vol = float(self.df['volume'].iloc[i])
|
||||
|
||||
atr = float(self._atr_series.iloc[i])
|
||||
if atr <= 0:
|
||||
atr = close * 0.001
|
||||
|
||||
# Close i-1 et i-5 pour le momentum
|
||||
close_1 = float(self.df['close'].iloc[max(0, i - 1)])
|
||||
close_5 = float(self.df['close'].iloc[max(0, i - 5)])
|
||||
|
||||
rsi = float(self._rsi.iloc[i]) if not np.isnan(self._rsi.iloc[i]) else 50.0
|
||||
bb_upper = float(self._bb_upper.iloc[i]) if not np.isnan(self._bb_upper.iloc[i]) else close
|
||||
bb_lower = float(self._bb_lower.iloc[i]) if not np.isnan(self._bb_lower.iloc[i]) else close
|
||||
macd = float(self._macd.iloc[i]) if not np.isnan(self._macd.iloc[i]) else 0.0
|
||||
macd_signal = float(self._macd_sig.iloc[i]) if not np.isnan(self._macd_sig.iloc[i]) else 0.0
|
||||
vol_ma20 = float(self._vol_ma20.iloc[i])
|
||||
ema9 = float(self._ema9.iloc[i])
|
||||
ema21 = float(self._ema21.iloc[i])
|
||||
ema50 = float(self._ema50.iloc[i])
|
||||
ema200 = float(self._ema200.iloc[i])
|
||||
|
||||
vol_ratio = vol / max(vol_ma20, 1.0)
|
||||
|
||||
# PnL non-réalisé courant
|
||||
if self._position != 0 and self._entry_price > 0:
|
||||
unrealized = self._position * (close - self._entry_price)
|
||||
else:
|
||||
unrealized = 0.0
|
||||
|
||||
features = np.array([
|
||||
# Prix relatifs normalisés par ATR
|
||||
(close - open_) / atr, # 0 : corps de la bougie
|
||||
(high - low) / atr, # 1 : amplitude bougie (volatilité relative)
|
||||
(close - close_1) / atr, # 2 : momentum 1 barre
|
||||
(close - close_5) / atr, # 3 : momentum 5 barres
|
||||
# Indicateurs techniques normalisés
|
||||
rsi / 100.0, # 4 : RSI [0..1]
|
||||
(close - bb_upper) / atr, # 5 : position vs BB haute
|
||||
(bb_upper - bb_lower) / atr, # 6 : largeur des bandes de Bollinger
|
||||
macd / atr, # 7 : MACD normalisé
|
||||
macd_signal / atr, # 8 : signal MACD normalisé
|
||||
# Volume
|
||||
np.clip(vol_ratio, 0.0, 5.0), # 9 : ratio volume / MA20 (plafonné à 5)
|
||||
# Position et état du trade
|
||||
float(self._position), # 10: position courante (-1, 0, 1)
|
||||
unrealized / atr, # 11: PnL non-réalisé en ATR
|
||||
min(self._bars_in_trade / 50.0, 1.0), # 12: durée du trade (normalisée)
|
||||
# Temporel (si index datetime)
|
||||
self._get_hour(i), # 13: heure normalisée [0..1]
|
||||
self._get_dow(i), # 14: jour de semaine [0..1]
|
||||
# Moyennes mobiles (distance close - EMA, normalisée par ATR)
|
||||
(close - ema9) / atr, # 15: distance EMA9
|
||||
(close - ema21) / atr, # 16: distance EMA21
|
||||
(ema9 - ema21) / atr, # 17: croisement EMA9/21
|
||||
(close - ema50) / atr, # 18: distance EMA50
|
||||
(close - ema200) / atr, # 19: distance EMA200
|
||||
], dtype=np.float32)
|
||||
|
||||
# Clip pour éviter les valeurs extrêmes (protection robustesse)
|
||||
features = np.clip(features, -10.0, 10.0)
|
||||
# Remplacement des NaN résiduels
|
||||
features = np.nan_to_num(features, nan=0.0, posinf=10.0, neginf=-10.0)
|
||||
|
||||
return features
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Gestion des positions
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _open_position(self, direction: int, price: float, atr: float) -> None:
|
||||
"""
|
||||
Ouvre une position dans la direction donnée.
|
||||
|
||||
Args:
|
||||
direction: 1=LONG, -1=SHORT
|
||||
price: Prix d'entrée (close de la barre courante)
|
||||
atr: ATR courant pour le calcul SL/TP
|
||||
"""
|
||||
self._position = direction
|
||||
self._entry_price = price
|
||||
self._entry_atr = atr
|
||||
self._bars_in_trade = 0
|
||||
|
||||
def _close_position(self, price: float, atr: float) -> float:
|
||||
"""
|
||||
Ferme la position courante et calcule la récompense.
|
||||
|
||||
La récompense est le PnL en multiple d'ATR (normalisé et sans unité).
|
||||
Cette normalisation rend la récompense comparable entre différents
|
||||
symboles et périodes de volatilité.
|
||||
|
||||
Args:
|
||||
price: Prix de sortie
|
||||
atr: ATR courant (pour normalisation)
|
||||
|
||||
Returns:
|
||||
Récompense (PnL normalisé par ATR)
|
||||
"""
|
||||
if self._position == 0 or self._entry_price <= 0:
|
||||
self._position = 0
|
||||
self._bars_in_trade = 0
|
||||
return 0.0
|
||||
|
||||
raw_pnl = self._position * (price - self._entry_price)
|
||||
|
||||
# Normalisation par l'ATR d'entrée pour une récompense sans unité
|
||||
entry_atr = max(self._entry_atr, price * 0.0001)
|
||||
reward = raw_pnl / entry_atr
|
||||
|
||||
# Mise à jour du capital (simulation simplifiée avec 1 lot fixe)
|
||||
self._capital += raw_pnl
|
||||
self._total_pnl += raw_pnl
|
||||
|
||||
self._position = 0
|
||||
self._entry_price = 0.0
|
||||
self._bars_in_trade = 0
|
||||
|
||||
return float(reward)
|
||||
|
||||
def _check_sl_tp(self, current_close: float, current_atr: float) -> float:
|
||||
"""
|
||||
Vérifie si le SL ou TP est atteint pour la position courante.
|
||||
|
||||
Utilise l'ATR d'entrée pour les niveaux SL/TP (stabilité).
|
||||
Si le SL ou TP est touché, la position est fermée.
|
||||
|
||||
Args:
|
||||
current_close: Prix de clôture de la barre courante
|
||||
current_atr: ATR courant
|
||||
|
||||
Returns:
|
||||
Récompense si fermeture forcée, 0.0 sinon
|
||||
"""
|
||||
if self._position == 0 or self._entry_price <= 0:
|
||||
return 0.0
|
||||
|
||||
entry_atr = max(self._entry_atr, self._entry_price * 0.0001)
|
||||
sl_dist = self.sl_atr_mult * entry_atr
|
||||
tp_dist = self.tp_atr_mult * entry_atr
|
||||
|
||||
if self._position == 1: # LONG
|
||||
sl_level = self._entry_price - sl_dist
|
||||
tp_level = self._entry_price + tp_dist
|
||||
if current_close <= sl_level or current_close >= tp_level:
|
||||
return self._close_position(current_close, current_atr)
|
||||
|
||||
elif self._position == -1: # SHORT
|
||||
sl_level = self._entry_price + sl_dist
|
||||
tp_level = self._entry_price - tp_dist
|
||||
if current_close >= sl_level or current_close <= tp_level:
|
||||
return self._close_position(current_close, current_atr)
|
||||
|
||||
return 0.0
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Helpers
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _action_to_position(action: int) -> int:
|
||||
"""Convertit l'action discrète en direction de position."""
|
||||
return {ACTION_HOLD: 0, ACTION_LONG: 1, ACTION_SHORT: -1}.get(action, 0)
|
||||
|
||||
def _get_hour(self, i: int) -> float:
|
||||
"""Retourne l'heure normalisée [0..1] depuis l'index, ou 0.5 si non datetime."""
|
||||
try:
|
||||
idx = self.df.index[i]
|
||||
if hasattr(idx, 'hour'):
|
||||
return float(idx.hour) / 24.0
|
||||
except Exception:
|
||||
pass
|
||||
return 0.5
|
||||
|
||||
def _get_dow(self, i: int) -> float:
|
||||
"""Retourne le jour de semaine normalisé [0..1] depuis l'index, ou 0.5 si non datetime."""
|
||||
try:
|
||||
idx = self.df.index[i]
|
||||
if hasattr(idx, 'dayofweek'):
|
||||
return float(idx.dayofweek) / 5.0
|
||||
except Exception:
|
||||
pass
|
||||
return 0.5
|
||||
|
||||
def _compute_atr_series(self) -> pd.Series:
|
||||
"""Calcule l'ATR(14) sur toute la série OHLCV."""
|
||||
h = self.df['high']
|
||||
l = self.df['low']
|
||||
pc = self.df['close'].shift(1)
|
||||
tr = pd.concat([h - l, (h - pc).abs(), (l - pc).abs()], axis=1).max(axis=1)
|
||||
atr = tr.ewm(span=self.atr_period, adjust=False).mean()
|
||||
# Remplir les premières valeurs NaN par la plage high-low
|
||||
atr = atr.fillna(h - l)
|
||||
return atr
|
||||
|
||||
def _compute_rsi(self, period: int = 14) -> pd.Series:
|
||||
"""Calcule le RSI sur toute la série de closes."""
|
||||
delta = self.df['close'].diff()
|
||||
gain = delta.clip(lower=0).ewm(span=period, adjust=False).mean()
|
||||
loss = (-delta.clip(upper=0)).ewm(span=period, adjust=False).mean()
|
||||
rs = gain / loss.replace(0, np.nan)
|
||||
rsi = 100.0 - (100.0 / (1.0 + rs))
|
||||
return rsi.fillna(50.0)
|
||||
|
||||
def get_episode_stats(self) -> dict:
|
||||
"""
|
||||
Retourne les statistiques de l'épisode courant.
|
||||
|
||||
Utile pour le logging et l'évaluation du modèle PPO.
|
||||
|
||||
Returns:
|
||||
Dict avec total_pnl, n_steps, capital, Sharpe approximatif
|
||||
"""
|
||||
rewards = np.array(self._pnl_history)
|
||||
if len(rewards) > 1 and rewards.std() > 0:
|
||||
sharpe = float(rewards.mean() / rewards.std() * np.sqrt(252))
|
||||
else:
|
||||
sharpe = 0.0
|
||||
|
||||
return {
|
||||
'total_pnl': self._total_pnl,
|
||||
'capital': self._capital,
|
||||
'return_pct': (self._capital - self.initial_capital) / self.initial_capital,
|
||||
'n_steps': self._current_step,
|
||||
'sharpe_approx': sharpe,
|
||||
}
|
||||
3
src/strategies/cnn_driven/__init__.py
Normal file
3
src/strategies/cnn_driven/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .cnn_strategy import CNNDrivenStrategy
|
||||
|
||||
__all__ = ['CNNDrivenStrategy']
|
||||
249
src/strategies/cnn_driven/cnn_strategy.py
Normal file
249
src/strategies/cnn_driven/cnn_strategy.py
Normal file
@@ -0,0 +1,249 @@
|
||||
"""
|
||||
CNN-Driven Strategy — Stratégie pilotée par un réseau convolutif 1D.
|
||||
|
||||
Cette stratégie utilise un CNN 1D entraîné sur des séquences OHLCV brutes
|
||||
pour détecter des patterns visuels dans les bougies (double bottom, squeeze
|
||||
Bollinger, alignements, etc.) sans features pré-calculées.
|
||||
|
||||
Fonctionnement :
|
||||
1. Le modèle CNN est chargé depuis le disque (entraîné via POST /trading/train-cnn)
|
||||
2. À chaque barre, la séquence OHLCV récente est passée au CNN
|
||||
3. Le modèle prédit LONG / SHORT / NEUTRAL avec un score de confiance
|
||||
4. Si confidence >= min_confidence, un signal est émis avec SL/TP basés sur ATR
|
||||
|
||||
Intégration :
|
||||
- Compatible avec StrategyEngine (même interface que ScalpingStrategy / MLDrivenStrategy)
|
||||
- Chargé automatiquement si un modèle entraîné existe pour le symbole/timeframe
|
||||
- Le RiskManager applique les mêmes contrôles que pour les stratégies classiques
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from src.strategies.base_strategy import BaseStrategy, Signal, StrategyConfig
|
||||
|
||||
try:
|
||||
from src.ml.cnn import CNNStrategyModel
|
||||
CNN_AVAILABLE = True
|
||||
except ImportError:
|
||||
CNN_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CNNDrivenStrategy(BaseStrategy):
|
||||
"""
|
||||
Stratégie de trading pilotée par un CNN 1D pré-entraîné.
|
||||
|
||||
Le modèle apprend directement les patterns visuels des bougies :
|
||||
- Double bottom / double top
|
||||
- Squeeze Bollinger + expansion
|
||||
- Alignements de moyennes mobiles
|
||||
- Patterns chandeliers complexes
|
||||
- Structures de prix multi-barres
|
||||
|
||||
Args:
|
||||
config: Dict de configuration (timeframe, risk_per_trade, symbol, etc.)
|
||||
|
||||
Config keys supplémentaires (optionnelles) :
|
||||
min_confidence: Seuil de confiance minimum [0..1] (défaut: 0.55)
|
||||
tp_atr_mult: Multiplicateur ATR pour TP (défaut: 2.0)
|
||||
sl_atr_mult: Multiplicateur ATR pour SL (défaut: 1.0)
|
||||
seq_len: Longueur de séquence d'entrée (défaut: 64)
|
||||
auto_load: Charger automatiquement le modèle existant (défaut: True)
|
||||
"""
|
||||
|
||||
STRATEGY_NAME = 'cnn_driven'
|
||||
|
||||
def __init__(self, config: Dict):
|
||||
super().__init__(config)
|
||||
|
||||
self.symbol = config.get('symbol', 'EURUSD')
|
||||
self.min_confidence = config.get('min_confidence', 0.55)
|
||||
self.tp_atr_mult = config.get('tp_atr_mult', 2.0)
|
||||
self.sl_atr_mult = config.get('sl_atr_mult', 1.0)
|
||||
self.seq_len = config.get('seq_len', 64)
|
||||
|
||||
self.cnn_model: Optional['CNNStrategyModel'] = None
|
||||
|
||||
if not CNN_AVAILABLE:
|
||||
logger.warning("CNN non disponible (PyTorch requis)")
|
||||
return
|
||||
|
||||
# Tentative de chargement automatique du modèle existant
|
||||
if config.get('auto_load', True):
|
||||
self._try_load_model()
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Interface BaseStrategy
|
||||
# -------------------------------------------------------------------------
|
||||
def analyze(self, market_data: pd.DataFrame) -> Optional[Signal]:
|
||||
"""
|
||||
Génère un signal de trading via le modèle CNN.
|
||||
|
||||
Args:
|
||||
market_data: DataFrame OHLCV (minimum seq_len barres)
|
||||
|
||||
Returns:
|
||||
Signal si le modèle est confiant, None sinon
|
||||
"""
|
||||
if not CNN_AVAILABLE:
|
||||
logger.debug("CNN Strategy : PyTorch non disponible, aucun signal")
|
||||
return None
|
||||
|
||||
if self.cnn_model is None or not self.cnn_model.is_trained:
|
||||
logger.debug("CNN Strategy : modèle non chargé, aucun signal")
|
||||
return None
|
||||
|
||||
if len(market_data) < self.seq_len:
|
||||
return None
|
||||
|
||||
try:
|
||||
result = self.cnn_model.predict(market_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"CNN Strategy predict error : {e}")
|
||||
return None
|
||||
|
||||
if not result.get('tradeable', False):
|
||||
return None
|
||||
|
||||
signal_dir = result['signal'] # 1 = LONG, -1 = SHORT
|
||||
confidence = result['confidence']
|
||||
|
||||
# Prix et ATR pour SL/TP
|
||||
last_close = float(market_data['close'].iloc[-1])
|
||||
atr = self._compute_atr(market_data)
|
||||
if atr <= 0:
|
||||
return None
|
||||
|
||||
if signal_dir == 1:
|
||||
direction = 'LONG'
|
||||
stop_loss = last_close - self.sl_atr_mult * atr
|
||||
take_profit = last_close + self.tp_atr_mult * atr
|
||||
elif signal_dir == -1:
|
||||
direction = 'SHORT'
|
||||
stop_loss = last_close + self.sl_atr_mult * atr
|
||||
take_profit = last_close - self.tp_atr_mult * atr
|
||||
else:
|
||||
return None
|
||||
|
||||
signal = Signal(
|
||||
symbol = self.symbol,
|
||||
direction = direction,
|
||||
entry_price = last_close,
|
||||
stop_loss = stop_loss,
|
||||
take_profit = take_profit,
|
||||
confidence = confidence,
|
||||
timestamp = datetime.now(timezone.utc),
|
||||
strategy = self.STRATEGY_NAME,
|
||||
metadata = {
|
||||
'probas': result.get('probas', {}),
|
||||
'seq_len': self.seq_len,
|
||||
'atr': atr,
|
||||
'tp_atr_mult': self.tp_atr_mult,
|
||||
'sl_atr_mult': self.sl_atr_mult,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"CNN Signal : {direction} {self.symbol} | "
|
||||
f"entry={last_close:.5f} SL={stop_loss:.5f} TP={take_profit:.5f} | "
|
||||
f"confidence={confidence:.2%}"
|
||||
)
|
||||
return signal
|
||||
|
||||
def calculate_indicators(self, data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Retourne les données telles quelles — le CNN travaille sur les séquences brutes."""
|
||||
return data
|
||||
|
||||
def update_params(self, params: Dict) -> None:
|
||||
"""Mise à jour dynamique des paramètres (depuis API ou Optuna)."""
|
||||
if 'min_confidence' in params:
|
||||
self.min_confidence = params['min_confidence']
|
||||
if self.cnn_model:
|
||||
self.cnn_model.min_confidence = params['min_confidence']
|
||||
if 'tp_atr_mult' in params:
|
||||
self.tp_atr_mult = params['tp_atr_mult']
|
||||
if 'sl_atr_mult' in params:
|
||||
self.sl_atr_mult = params['sl_atr_mult']
|
||||
if 'seq_len' in params:
|
||||
self.seq_len = params['seq_len']
|
||||
logger.info(f"CNN Strategy params mis à jour : {params}")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Gestion du modèle
|
||||
# -------------------------------------------------------------------------
|
||||
def load_model(self, symbol: Optional[str] = None, timeframe: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Charge un modèle CNN depuis le disque.
|
||||
|
||||
Args:
|
||||
symbol: Paire (défaut: self.symbol)
|
||||
timeframe: Timeframe (défaut: self.config.timeframe)
|
||||
|
||||
Returns:
|
||||
True si chargement réussi
|
||||
"""
|
||||
if not CNN_AVAILABLE:
|
||||
logger.warning("CNN non disponible (PyTorch requis)")
|
||||
return False
|
||||
|
||||
sym = symbol or self.symbol
|
||||
tf = timeframe or self.config.timeframe
|
||||
try:
|
||||
self.cnn_model = CNNStrategyModel.load(sym, tf)
|
||||
logger.info(f"Modèle CNN chargé : {sym}/{tf}")
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
logger.info(f"Aucun modèle CNN trouvé pour {sym}/{tf}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur chargement modèle CNN : {e}")
|
||||
return False
|
||||
|
||||
def attach_model(self, model: 'CNNStrategyModel') -> None:
|
||||
"""Attache directement un modèle CNN (après entraînement via API)."""
|
||||
self.cnn_model = model
|
||||
self.symbol = model.symbol
|
||||
logger.info(f"Modèle CNN attaché : {model.symbol}/{model.timeframe}")
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""Retourne True si le modèle CNN est chargé et entraîné."""
|
||||
if not CNN_AVAILABLE:
|
||||
return False
|
||||
return self.cnn_model is not None and self.cnn_model.is_trained
|
||||
|
||||
def get_model_info(self) -> Dict:
|
||||
"""Retourne les métadonnées du modèle CNN actif."""
|
||||
if not CNN_AVAILABLE:
|
||||
return {'status': 'PyTorch non disponible'}
|
||||
if not self.is_ready():
|
||||
return {'status': 'non entraîné'}
|
||||
meta = self.cnn_model.metadata.copy()
|
||||
meta['is_ready'] = True
|
||||
meta['seq_len'] = self.seq_len
|
||||
return meta
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# -------------------------------------------------------------------------
|
||||
def _try_load_model(self) -> None:
|
||||
"""Tente un chargement silencieux du modèle au démarrage."""
|
||||
try:
|
||||
self.load_model()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _compute_atr(df: pd.DataFrame, period: int = 14) -> float:
|
||||
"""Calcule l'ATR moyen sur les dernières barres."""
|
||||
if len(df) < period + 1:
|
||||
return float(df['high'].iloc[-1] - df['low'].iloc[-1])
|
||||
h, l, pc = df['high'], df['low'], df['close'].shift(1)
|
||||
tr = pd.concat([h - l, (h - pc).abs(), (l - pc).abs()], axis=1).max(axis=1)
|
||||
atr = tr.rolling(period).mean().iloc[-1]
|
||||
return float(atr) if not np.isnan(atr) else float(df['high'].iloc[-1] - df['low'].iloc[-1])
|
||||
3
src/strategies/cnn_image_driven/__init__.py
Normal file
3
src/strategies/cnn_image_driven/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.strategies.cnn_image_driven.cnn_image_strategy import CNNImageDrivenStrategy
|
||||
|
||||
__all__ = ['CNNImageDrivenStrategy']
|
||||
258
src/strategies/cnn_image_driven/cnn_image_strategy.py
Normal file
258
src/strategies/cnn_image_driven/cnn_image_strategy.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""
|
||||
CNN Image-Driven Strategy — Stratégie pilotée par un CNN à vision par ordinateur.
|
||||
|
||||
Contrairement au CNN 1D qui analyse des séquences numériques brutes, cette
|
||||
stratégie convertit les bougies OHLCV en images de graphiques (rendu mplfinance
|
||||
128×128 RGB), puis utilise un réseau Conv2D pour reconnaître les patterns
|
||||
visuels — exactement comme un trader devant TradingView.
|
||||
|
||||
Patterns détectés sans feature engineering explicite :
|
||||
- Bougies marteau, étoile filante, doji
|
||||
- Double top, double bottom
|
||||
- Rebonds sur supports/résistances visuels
|
||||
- Momentum : grande bougie pleine après consolidation
|
||||
- Divergences prix/volume visibles
|
||||
|
||||
Fonctionnement :
|
||||
1. Le modèle CNN-Image est chargé depuis le disque (entraîné via POST /trading/train-cnn-image)
|
||||
2. À chaque barre, les seq_len dernières bougies sont rendues en image 128×128
|
||||
3. Le modèle prédit LONG / SHORT / NEUTRAL avec un score de confiance
|
||||
4. Si confidence >= min_confidence, un signal est émis avec SL/TP basés sur ATR
|
||||
|
||||
Intégration :
|
||||
- Compatible avec StrategyEngine (même interface que CNNDrivenStrategy)
|
||||
- Chargé automatiquement si un modèle entraîné existe pour le symbole/timeframe
|
||||
- Requiert PyTorch + mplfinance + Pillow
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from src.strategies.base_strategy import BaseStrategy, Signal, StrategyConfig
|
||||
|
||||
try:
|
||||
from src.ml.cnn_image import CNNImageStrategyModel
|
||||
CNN_IMAGE_AVAILABLE = True
|
||||
except ImportError:
|
||||
CNN_IMAGE_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CNNImageDrivenStrategy(BaseStrategy):
|
||||
"""
|
||||
Stratégie de trading pilotée par un CNN Conv2D pré-entraîné sur images de graphiques.
|
||||
|
||||
Le modèle apprend les patterns visuels des chandeliers directement depuis
|
||||
des images 128×128 RGB sans features pré-calculées :
|
||||
- Double bottom / double top
|
||||
- Configurations chandeliers (marteau, doji, étoile filante)
|
||||
- Rebonds sur zones de support/résistance visibles
|
||||
- Structures de momentum multi-barres
|
||||
- Divergences volume/prix visuelles
|
||||
|
||||
Args:
|
||||
config: Dict de configuration (timeframe, risk_per_trade, symbol, etc.)
|
||||
|
||||
Config keys supplémentaires (optionnelles) :
|
||||
min_confidence: Seuil de confiance minimum [0..1] (défaut: 0.55)
|
||||
tp_atr_mult: Multiplicateur ATR pour TP (défaut: 2.0)
|
||||
sl_atr_mult: Multiplicateur ATR pour SL (défaut: 1.0)
|
||||
seq_len: Nombre de bougies par image (défaut: 64)
|
||||
auto_load: Charger automatiquement le modèle existant (défaut: True)
|
||||
"""
|
||||
|
||||
STRATEGY_NAME = 'cnn_image_driven'
|
||||
|
||||
def __init__(self, config: Dict):
|
||||
super().__init__(config)
|
||||
|
||||
self.symbol = config.get('symbol', 'EURUSD')
|
||||
self.min_confidence = config.get('min_confidence', 0.55)
|
||||
self.tp_atr_mult = config.get('tp_atr_mult', 2.0)
|
||||
self.sl_atr_mult = config.get('sl_atr_mult', 1.0)
|
||||
self.seq_len = config.get('seq_len', 64)
|
||||
|
||||
self.cnn_image_model: Optional['CNNImageStrategyModel'] = None
|
||||
|
||||
if not CNN_IMAGE_AVAILABLE:
|
||||
logger.warning("CNN Image non disponible (PyTorch/mplfinance requis)")
|
||||
return
|
||||
|
||||
# Tentative de chargement automatique du modèle existant
|
||||
if config.get('auto_load', True):
|
||||
self._try_load_model()
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Interface BaseStrategy
|
||||
# -------------------------------------------------------------------------
|
||||
def analyze(self, market_data: pd.DataFrame) -> Optional[Signal]:
|
||||
"""
|
||||
Génère un signal de trading via le modèle CNN Image.
|
||||
|
||||
Args:
|
||||
market_data: DataFrame OHLCV (minimum seq_len barres)
|
||||
|
||||
Returns:
|
||||
Signal si le modèle est confiant, None sinon
|
||||
"""
|
||||
if not CNN_IMAGE_AVAILABLE:
|
||||
logger.debug("CNN Image Strategy : PyTorch/mplfinance non disponible, aucun signal")
|
||||
return None
|
||||
|
||||
if self.cnn_image_model is None or not self.cnn_image_model.is_trained:
|
||||
logger.debug("CNN Image Strategy : modèle non chargé, aucun signal")
|
||||
return None
|
||||
|
||||
if len(market_data) < self.seq_len:
|
||||
return None
|
||||
|
||||
try:
|
||||
result = self.cnn_image_model.predict(market_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"CNN Image Strategy predict error : {e}")
|
||||
return None
|
||||
|
||||
if not result.get('tradeable', False):
|
||||
return None
|
||||
|
||||
signal_dir = result['signal'] # 1 = LONG, -1 = SHORT
|
||||
confidence = result['confidence']
|
||||
|
||||
# Prix et ATR pour SL/TP
|
||||
last_close = float(market_data['close'].iloc[-1])
|
||||
atr = self._compute_atr(market_data)
|
||||
if atr <= 0:
|
||||
return None
|
||||
|
||||
if signal_dir == 1:
|
||||
direction = 'LONG'
|
||||
stop_loss = last_close - self.sl_atr_mult * atr
|
||||
take_profit = last_close + self.tp_atr_mult * atr
|
||||
elif signal_dir == -1:
|
||||
direction = 'SHORT'
|
||||
stop_loss = last_close + self.sl_atr_mult * atr
|
||||
take_profit = last_close - self.tp_atr_mult * atr
|
||||
else:
|
||||
return None
|
||||
|
||||
signal = Signal(
|
||||
symbol = self.symbol,
|
||||
direction = direction,
|
||||
entry_price = last_close,
|
||||
stop_loss = stop_loss,
|
||||
take_profit = take_profit,
|
||||
confidence = confidence,
|
||||
timestamp = datetime.now(timezone.utc),
|
||||
strategy = self.STRATEGY_NAME,
|
||||
metadata = {
|
||||
'probas': result.get('probas', {}),
|
||||
'seq_len': self.seq_len,
|
||||
'atr': atr,
|
||||
'tp_atr_mult': self.tp_atr_mult,
|
||||
'sl_atr_mult': self.sl_atr_mult,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"CNN Image Signal : {direction} {self.symbol} | "
|
||||
f"entry={last_close:.5f} SL={stop_loss:.5f} TP={take_profit:.5f} | "
|
||||
f"confidence={confidence:.2%}"
|
||||
)
|
||||
return signal
|
||||
|
||||
def calculate_indicators(self, data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Retourne les données telles quelles — le CNN Image travaille sur des rendus d'images."""
|
||||
return data
|
||||
|
||||
def update_params(self, params: Dict) -> None:
|
||||
"""Mise à jour dynamique des paramètres (depuis API ou Optuna)."""
|
||||
if 'min_confidence' in params:
|
||||
self.min_confidence = params['min_confidence']
|
||||
if self.cnn_image_model:
|
||||
self.cnn_image_model.min_confidence = params['min_confidence']
|
||||
if 'tp_atr_mult' in params:
|
||||
self.tp_atr_mult = params['tp_atr_mult']
|
||||
if 'sl_atr_mult' in params:
|
||||
self.sl_atr_mult = params['sl_atr_mult']
|
||||
if 'seq_len' in params:
|
||||
self.seq_len = params['seq_len']
|
||||
logger.info(f"CNN Image Strategy params mis à jour : {params}")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Gestion du modèle
|
||||
# -------------------------------------------------------------------------
|
||||
def load_model(self, symbol: Optional[str] = None, timeframe: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Charge un modèle CNN Image depuis le disque.
|
||||
|
||||
Args:
|
||||
symbol: Paire (défaut: self.symbol)
|
||||
timeframe: Timeframe (défaut: self.config.timeframe)
|
||||
|
||||
Returns:
|
||||
True si chargement réussi
|
||||
"""
|
||||
if not CNN_IMAGE_AVAILABLE:
|
||||
logger.warning("CNN Image non disponible (PyTorch/mplfinance requis)")
|
||||
return False
|
||||
|
||||
sym = symbol or self.symbol
|
||||
tf = timeframe or self.config.timeframe
|
||||
try:
|
||||
self.cnn_image_model = CNNImageStrategyModel.load(sym, tf)
|
||||
logger.info(f"Modèle CNN Image chargé : {sym}/{tf}")
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
logger.info(f"Aucun modèle CNN Image trouvé pour {sym}/{tf}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur chargement modèle CNN Image : {e}")
|
||||
return False
|
||||
|
||||
def attach_model(self, model: 'CNNImageStrategyModel') -> None:
|
||||
"""Attache directement un modèle CNN Image (après entraînement via API)."""
|
||||
self.cnn_image_model = model
|
||||
self.symbol = model.symbol
|
||||
logger.info(f"Modèle CNN Image attaché : {model.symbol}/{model.timeframe}")
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""Retourne True si le modèle CNN Image est chargé et entraîné."""
|
||||
if not CNN_IMAGE_AVAILABLE:
|
||||
return False
|
||||
return self.cnn_image_model is not None and self.cnn_image_model.is_trained
|
||||
|
||||
def get_model_info(self) -> Dict:
|
||||
"""Retourne les métadonnées du modèle CNN Image actif."""
|
||||
if not CNN_IMAGE_AVAILABLE:
|
||||
return {'status': 'PyTorch/mplfinance non disponible'}
|
||||
if not self.is_ready():
|
||||
return {'status': 'non entraîné'}
|
||||
meta = self.cnn_image_model.metadata.copy()
|
||||
meta['is_ready'] = True
|
||||
meta['seq_len'] = self.seq_len
|
||||
return meta
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# -------------------------------------------------------------------------
|
||||
def _try_load_model(self) -> None:
|
||||
"""Tente un chargement silencieux du modèle au démarrage."""
|
||||
try:
|
||||
self.load_model()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _compute_atr(df: pd.DataFrame, period: int = 14) -> float:
|
||||
"""Calcule l'ATR moyen sur les dernières barres."""
|
||||
if len(df) < period + 1:
|
||||
return float(df['high'].iloc[-1] - df['low'].iloc[-1])
|
||||
h, l, pc = df['high'], df['low'], df['close'].shift(1)
|
||||
tr = pd.concat([h - l, (h - pc).abs(), (l - pc).abs()], axis=1).max(axis=1)
|
||||
atr = tr.rolling(period).mean().iloc[-1]
|
||||
return float(atr) if not np.isnan(atr) else float(df['high'].iloc[-1] - df['low'].iloc[-1])
|
||||
3
src/strategies/ensemble/__init__.py
Normal file
3
src/strategies/ensemble/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .ensemble_strategy import EnsembleStrategy
|
||||
|
||||
__all__ = ['EnsembleStrategy']
|
||||
190
src/strategies/ensemble/ensemble_strategy.py
Normal file
190
src/strategies/ensemble/ensemble_strategy.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""
|
||||
Ensemble Strategy — Stratégie combinant XGBoost + CNN (+ RL futur).
|
||||
|
||||
Cette stratégie utilise l'EnsembleModel pour agréger les signaux de plusieurs
|
||||
modèles ML. Un signal n'est émis que si les modèles sont en accord et que
|
||||
le score pondéré dépasse le seuil configuré.
|
||||
|
||||
Config keys :
|
||||
weights: dict poids par modèle (défaut: XGB=0.40, CNN=0.60)
|
||||
min_confidence: seuil score pondéré (défaut: 0.60)
|
||||
require_agreement: exiger accord entre modèles (défaut: True)
|
||||
tp_atr_mult: TP en multiples d'ATR (défaut: 2.0)
|
||||
sl_atr_mult: SL en multiples d'ATR (défaut: 1.0)
|
||||
auto_load: charger modèles existants au démarrage (défaut: True)
|
||||
symbol: paire tradée (défaut: 'EURUSD')
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from src.strategies.base_strategy import BaseStrategy, Signal
|
||||
from src.ml.ensemble.ensemble_model import EnsembleModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EnsembleStrategy(BaseStrategy):
|
||||
"""
|
||||
Stratégie de trading combinant plusieurs modèles ML via EnsembleModel.
|
||||
|
||||
Nécessite au minimum 2 modèles entraînés et attachés pour émettre
|
||||
des signaux. Les SL/TP sont calculés à partir de l'ATR.
|
||||
"""
|
||||
|
||||
STRATEGY_NAME = 'ensemble'
|
||||
|
||||
def __init__(self, config: Dict):
|
||||
# Forcer le nom de la stratégie
|
||||
config.setdefault('name', self.STRATEGY_NAME)
|
||||
super().__init__(config)
|
||||
|
||||
self.symbol = config.get('symbol', 'EURUSD')
|
||||
self.tp_atr_mult = config.get('tp_atr_mult', 2.0)
|
||||
self.sl_atr_mult = config.get('sl_atr_mult', 1.0)
|
||||
|
||||
self.ensemble = EnsembleModel(
|
||||
weights=config.get('weights'),
|
||||
min_confidence=config.get('min_confidence', 0.60),
|
||||
require_agreement=config.get('require_agreement', True),
|
||||
)
|
||||
|
||||
if config.get('auto_load', True):
|
||||
self._try_load_models()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Interface BaseStrategy
|
||||
# ------------------------------------------------------------------
|
||||
def analyze(self, market_data: pd.DataFrame) -> Optional[Signal]:
|
||||
"""
|
||||
Génère un signal via l'ensemble de modèles ML.
|
||||
|
||||
Args:
|
||||
market_data: DataFrame OHLCV (minimum 50 barres)
|
||||
|
||||
Returns:
|
||||
Signal si l'ensemble est confiant et en accord, None sinon
|
||||
"""
|
||||
if not self.ensemble.is_ready():
|
||||
logger.debug("Ensemble Strategy : ensemble non prêt (< 2 modèles)")
|
||||
return None
|
||||
|
||||
if len(market_data) < 50:
|
||||
return None
|
||||
|
||||
try:
|
||||
result = self.ensemble.predict(market_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Ensemble Strategy predict error : {e}")
|
||||
return None
|
||||
|
||||
if not result.get('tradeable', False):
|
||||
return None
|
||||
|
||||
signal_dir = result['signal'] # 1 = LONG, -1 = SHORT
|
||||
confidence = result['confidence']
|
||||
|
||||
# Prix et ATR pour SL/TP
|
||||
last_close = float(market_data['close'].iloc[-1])
|
||||
atr = self._compute_atr(market_data)
|
||||
if atr <= 0:
|
||||
return None
|
||||
|
||||
if signal_dir == 1:
|
||||
direction = 'LONG'
|
||||
stop_loss = last_close - self.sl_atr_mult * atr
|
||||
take_profit = last_close + self.tp_atr_mult * atr
|
||||
elif signal_dir == -1:
|
||||
direction = 'SHORT'
|
||||
stop_loss = last_close + self.sl_atr_mult * atr
|
||||
take_profit = last_close - self.tp_atr_mult * atr
|
||||
else:
|
||||
return None
|
||||
|
||||
signal = Signal(
|
||||
symbol=self.symbol,
|
||||
direction=direction,
|
||||
entry_price=last_close,
|
||||
stop_loss=stop_loss,
|
||||
take_profit=take_profit,
|
||||
confidence=confidence,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
strategy=self.STRATEGY_NAME,
|
||||
metadata={
|
||||
'ensemble_agreement': result.get('agreement', False),
|
||||
'components': result.get('components', {}),
|
||||
'atr': atr,
|
||||
'tp_atr_mult': self.tp_atr_mult,
|
||||
'sl_atr_mult': self.sl_atr_mult,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Ensemble Signal : {direction} {self.symbol} | "
|
||||
f"entry={last_close:.5f} SL={stop_loss:.5f} TP={take_profit:.5f} | "
|
||||
f"confidence={confidence:.2%} | accord={result.get('agreement')}"
|
||||
)
|
||||
return signal
|
||||
|
||||
def calculate_indicators(self, data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Retourne les données telles quelles — les features sont dans predict()."""
|
||||
return data
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Gestion des modèles
|
||||
# ------------------------------------------------------------------
|
||||
def attach_xgboost(self, model) -> None:
|
||||
"""Attache un modèle XGBoost à l'ensemble."""
|
||||
self.ensemble.attach_xgboost(model)
|
||||
|
||||
def attach_cnn(self, model) -> None:
|
||||
"""Attache un modèle CNN à l'ensemble."""
|
||||
self.ensemble.attach_cnn(model)
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""True si l'ensemble a au moins 2 modèles entraînés."""
|
||||
return self.ensemble.is_ready()
|
||||
|
||||
def get_status(self) -> Dict:
|
||||
"""Retourne le statut de l'ensemble."""
|
||||
return self.ensemble.get_status()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Chargement automatique
|
||||
# ------------------------------------------------------------------
|
||||
def _try_load_models(self) -> None:
|
||||
"""Tente de charger les modèles existants au démarrage."""
|
||||
# XGBoost via MLStrategyModel
|
||||
try:
|
||||
from src.ml.ml_strategy_model import MLStrategyModel
|
||||
xgb = MLStrategyModel.load(self.symbol, self.config.timeframe, 'xgboost')
|
||||
self.ensemble.attach_xgboost(xgb)
|
||||
logger.info(f"Ensemble : modèle XGBoost chargé pour {self.symbol}")
|
||||
except Exception:
|
||||
logger.debug(f"Ensemble : pas de modèle XGBoost pour {self.symbol}")
|
||||
|
||||
# CNN via CNNStrategyModel (peut ne pas encore exister)
|
||||
try:
|
||||
from src.ml.cnn import CNNStrategyModel
|
||||
cnn = CNNStrategyModel.load(self.symbol, self.config.timeframe)
|
||||
self.ensemble.attach_cnn(cnn)
|
||||
logger.info(f"Ensemble : modèle CNN chargé pour {self.symbol}")
|
||||
except Exception:
|
||||
logger.debug(f"Ensemble : pas de modèle CNN pour {self.symbol}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
@staticmethod
|
||||
def _compute_atr(df: pd.DataFrame, period: int = 14) -> float:
|
||||
"""Calcule l'ATR moyen sur les dernières barres."""
|
||||
if len(df) < period + 1:
|
||||
return float(df['high'].iloc[-1] - df['low'].iloc[-1])
|
||||
h, l, pc = df['high'], df['low'], df['close'].shift(1)
|
||||
tr = pd.concat([h - l, (h - pc).abs(), (l - pc).abs()], axis=1).max(axis=1)
|
||||
atr = tr.rolling(period).mean().iloc[-1]
|
||||
return float(atr) if not np.isnan(atr) else float(df['high'].iloc[-1] - df['low'].iloc[-1])
|
||||
10
src/strategies/rl_driven/__init__.py
Normal file
10
src/strategies/rl_driven/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
Module RL-Driven Strategy — Stratégie de trading pilotée par un agent PPO.
|
||||
|
||||
L'agent PPO apprend à trader par renforcement : il maximise directement
|
||||
le PnL sans supervision explicite (pas de labels LONG/SHORT/NEUTRAL).
|
||||
"""
|
||||
|
||||
from src.strategies.rl_driven.rl_strategy import RLDrivenStrategy, RL_AVAILABLE
|
||||
|
||||
__all__ = ['RLDrivenStrategy', 'RL_AVAILABLE']
|
||||
259
src/strategies/rl_driven/rl_strategy.py
Normal file
259
src/strategies/rl_driven/rl_strategy.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""
|
||||
RL-Driven Strategy — Stratégie pilotée par un agent de Reinforcement Learning (PPO).
|
||||
|
||||
Contrairement aux stratégies supervisées (XGBoost, CNN) qui apprennent depuis des
|
||||
labels pré-calculés, cette stratégie utilise un agent PPO (Proximal Policy
|
||||
Optimization) entraîné via interaction directe avec un environnement de trading
|
||||
(TradingEnv). L'agent apprend à maximiser la récompense cumulative (profit ajusté
|
||||
par le risque) sans avoir besoin de labels explicites.
|
||||
|
||||
Fonctionnement :
|
||||
1. Le modèle RLStrategyModel est chargé depuis le disque
|
||||
(entraîné via POST /trading/train-rl)
|
||||
2. À chaque barre, les seq_len dernières bougies sont fournies à l'agent
|
||||
3. L'agent PPO retourne une action : LONG (1) / SHORT (-1) / NEUTRAL (0)
|
||||
avec un score de confiance basé sur les probabilités de la politique
|
||||
4. Si confidence >= min_confidence, un signal est émis avec SL/TP basés sur ATR
|
||||
|
||||
Intégration :
|
||||
- Compatible avec StrategyEngine (même interface que CNNImageDrivenStrategy)
|
||||
- Chargé automatiquement si un modèle entraîné existe pour le symbole/timeframe
|
||||
- Requiert PyTorch (CPU ou CUDA)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from src.strategies.base_strategy import BaseStrategy, Signal, StrategyConfig
|
||||
|
||||
# Import conditionnel — RL_AVAILABLE reste False si PyTorch ou stable-baselines3
|
||||
# ne sont pas installés dans le container
|
||||
try:
|
||||
from src.ml.rl.rl_strategy_model import RLStrategyModel
|
||||
RL_AVAILABLE = True
|
||||
except ImportError:
|
||||
RLStrategyModel = None
|
||||
RL_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RLDrivenStrategy(BaseStrategy):
|
||||
"""
|
||||
Stratégie de trading pilotée par un agent PPO (Reinforcement Learning).
|
||||
|
||||
L'agent apprend à maximiser le profit cumulatif en interagissant avec un
|
||||
environnement de trading simulé (TradingEnv). Aucun label supervisé n'est
|
||||
nécessaire : la récompense est définie par le P&L réalisé et les pénalités
|
||||
de risque (drawdown, over-trading).
|
||||
|
||||
Args:
|
||||
config: Dict de configuration (timeframe, risk_per_trade, symbol, etc.)
|
||||
|
||||
Config keys supplémentaires (optionnelles) :
|
||||
min_confidence: Seuil de confiance minimum [0..1] (défaut: 0.55)
|
||||
tp_atr_mult: Multiplicateur ATR pour TP (défaut: 2.0)
|
||||
sl_atr_mult: Multiplicateur ATR pour SL (défaut: 1.0)
|
||||
seq_len: Fenêtre d'observation RL (barres) (défaut: 20)
|
||||
auto_load: Charger automatiquement le modèle existant (défaut: True)
|
||||
"""
|
||||
|
||||
STRATEGY_NAME = 'rl_driven'
|
||||
|
||||
def __init__(self, config: Dict):
|
||||
super().__init__(config)
|
||||
|
||||
self.symbol = config.get('symbol', 'EURUSD')
|
||||
self.min_confidence = config.get('min_confidence', 0.55)
|
||||
self.tp_atr_mult = config.get('tp_atr_mult', 2.0)
|
||||
self.sl_atr_mult = config.get('sl_atr_mult', 1.0)
|
||||
self.seq_len = config.get('seq_len', 20)
|
||||
|
||||
self.rl_model: Optional['RLStrategyModel'] = None
|
||||
|
||||
if not RL_AVAILABLE:
|
||||
logger.warning("RL Strategy non disponible (PyTorch / stable-baselines3 requis)")
|
||||
return
|
||||
|
||||
# Tentative de chargement automatique du modèle existant
|
||||
if config.get('auto_load', True):
|
||||
self._try_load_model()
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Interface BaseStrategy
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def analyze(self, market_data: pd.DataFrame) -> Optional[Signal]:
|
||||
"""
|
||||
Génère un signal de trading via l'agent PPO.
|
||||
|
||||
Args:
|
||||
market_data: DataFrame OHLCV (minimum seq_len barres)
|
||||
|
||||
Returns:
|
||||
Signal si l'agent est confiant, None sinon
|
||||
"""
|
||||
if not RL_AVAILABLE:
|
||||
logger.debug("RL Strategy : PyTorch / stable-baselines3 non disponible, aucun signal")
|
||||
return None
|
||||
|
||||
if self.rl_model is None or not self.rl_model.is_trained:
|
||||
logger.debug("RL Strategy : modèle non chargé, aucun signal")
|
||||
return None
|
||||
|
||||
if len(market_data) < self.seq_len:
|
||||
return None
|
||||
|
||||
try:
|
||||
result = self.rl_model.predict(market_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"RL Strategy predict error : {e}")
|
||||
return None
|
||||
|
||||
if not result.get('tradeable', False):
|
||||
return None
|
||||
|
||||
signal_dir = result['signal'] # 1 = LONG, -1 = SHORT
|
||||
confidence = result['confidence']
|
||||
|
||||
# Prix et ATR pour le calcul des niveaux SL/TP
|
||||
last_close = float(market_data['close'].iloc[-1])
|
||||
atr = self._compute_atr(market_data)
|
||||
if atr <= 0:
|
||||
return None
|
||||
|
||||
if signal_dir == 1:
|
||||
direction = 'LONG'
|
||||
stop_loss = last_close - self.sl_atr_mult * atr
|
||||
take_profit = last_close + self.tp_atr_mult * atr
|
||||
elif signal_dir == -1:
|
||||
direction = 'SHORT'
|
||||
stop_loss = last_close + self.sl_atr_mult * atr
|
||||
take_profit = last_close - self.tp_atr_mult * atr
|
||||
else:
|
||||
return None
|
||||
|
||||
signal = Signal(
|
||||
symbol = self.symbol,
|
||||
direction = direction,
|
||||
entry_price = last_close,
|
||||
stop_loss = stop_loss,
|
||||
take_profit = take_profit,
|
||||
confidence = confidence,
|
||||
timestamp = datetime.now(timezone.utc),
|
||||
strategy = self.STRATEGY_NAME,
|
||||
metadata = {
|
||||
'probas': result.get('probas', {}),
|
||||
'seq_len': self.seq_len,
|
||||
'atr': atr,
|
||||
'tp_atr_mult': self.tp_atr_mult,
|
||||
'sl_atr_mult': self.sl_atr_mult,
|
||||
'avg_reward': result.get('avg_reward'),
|
||||
'total_timesteps': result.get('total_timesteps'),
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"RL Signal : {direction} {self.symbol} | "
|
||||
f"entry={last_close:.5f} SL={stop_loss:.5f} TP={take_profit:.5f} | "
|
||||
f"confidence={confidence:.2%}"
|
||||
)
|
||||
return signal
|
||||
|
||||
def calculate_indicators(self, data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Retourne les données telles quelles — l'agent RL travaille sur les OHLCV bruts."""
|
||||
return data
|
||||
|
||||
def update_params(self, params: Dict) -> None:
|
||||
"""Mise à jour dynamique des paramètres (depuis API ou Optuna)."""
|
||||
if 'min_confidence' in params:
|
||||
self.min_confidence = params['min_confidence']
|
||||
if self.rl_model:
|
||||
self.rl_model.min_confidence = params['min_confidence']
|
||||
if 'tp_atr_mult' in params:
|
||||
self.tp_atr_mult = params['tp_atr_mult']
|
||||
if 'sl_atr_mult' in params:
|
||||
self.sl_atr_mult = params['sl_atr_mult']
|
||||
if 'seq_len' in params:
|
||||
self.seq_len = params['seq_len']
|
||||
logger.info(f"RL Strategy params mis à jour : {params}")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Gestion du modèle
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def load_model(self, symbol: Optional[str] = None, timeframe: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Charge un modèle RL depuis le disque.
|
||||
|
||||
Args:
|
||||
symbol: Paire (défaut: self.symbol)
|
||||
timeframe: Timeframe (défaut: self.config.timeframe)
|
||||
|
||||
Returns:
|
||||
True si chargement réussi
|
||||
"""
|
||||
if not RL_AVAILABLE:
|
||||
logger.warning("RL Strategy non disponible (PyTorch / stable-baselines3 requis)")
|
||||
return False
|
||||
|
||||
sym = symbol or self.symbol
|
||||
tf = timeframe or self.config.timeframe
|
||||
try:
|
||||
self.rl_model = RLStrategyModel.load(sym, tf)
|
||||
logger.info(f"Modèle RL chargé : {sym}/{tf}")
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
logger.info(f"Aucun modèle RL trouvé pour {sym}/{tf}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur chargement modèle RL : {e}")
|
||||
return False
|
||||
|
||||
def attach_model(self, model: 'RLStrategyModel') -> None:
|
||||
"""Attache directement un modèle RL (après entraînement via API)."""
|
||||
self.rl_model = model
|
||||
self.symbol = model.symbol
|
||||
logger.info(f"Modèle RL attaché : {model.symbol}/{model.timeframe}")
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""Retourne True si le modèle RL est chargé et entraîné."""
|
||||
if not RL_AVAILABLE:
|
||||
return False
|
||||
return self.rl_model is not None and self.rl_model.is_trained
|
||||
|
||||
def get_model_info(self) -> Dict:
|
||||
"""Retourne les métadonnées du modèle RL actif."""
|
||||
if not RL_AVAILABLE:
|
||||
return {'status': 'PyTorch / stable-baselines3 non disponible'}
|
||||
if not self.is_ready():
|
||||
return {'status': 'non entraîné'}
|
||||
meta = self.rl_model.metadata.copy()
|
||||
meta['is_ready'] = True
|
||||
meta['seq_len'] = self.seq_len
|
||||
return meta
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def _try_load_model(self) -> None:
|
||||
"""Tente un chargement silencieux du modèle au démarrage."""
|
||||
try:
|
||||
self.load_model()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _compute_atr(df: pd.DataFrame, period: int = 14) -> float:
|
||||
"""Calcule l'ATR moyen sur les dernières barres."""
|
||||
if len(df) < period + 1:
|
||||
return float(df['high'].iloc[-1] - df['low'].iloc[-1])
|
||||
h, l, pc = df['high'], df['low'], df['close'].shift(1)
|
||||
tr = pd.concat([h - l, (h - pc).abs(), (l - pc).abs()], axis=1).max(axis=1)
|
||||
atr = tr.rolling(period).mean().iloc[-1]
|
||||
return float(atr) if not np.isnan(atr) else float(df['high'].iloc[-1] - df['low'].iloc[-1])
|
||||
Reference in New Issue
Block a user