feat: ML-Driven Strategy — learning human TA patterns
New complete module to train an XGBoost/LightGBM model
that learns to detect opportunities from classic indicators:
RSI (divergences), MACD (crossovers), Bollinger (squeeze/bounce),
supports/resistances (local pivots), pivot points (classic + Fibonacci),
candlestick patterns (hammer, engulfing), EMA alignment, volume.
Files created:
- src/ml/features/technical_features.py (~50 TA features)
- src/ml/features/label_generator.py (LONG/SHORT/NEUTRAL labels via ATR forward simulation)
- src/ml/ml_strategy_model.py (training + walk-forward + joblib persistence)
- src/strategies/ml_driven/ml_strategy.py (StrategyEngine-compatible strategy)
API routes added:
- POST /trading/train (async training)
- GET /trading/train/{job_id} (job status)
- GET /trading/ml-models (list of available models)
- GET /trading/ml-models/{symbol}/{tf}/importance (feature importance)
Documentation: docs/ML_STRATEGY_GUIDE.md
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
306 docs/ML_STRATEGY_GUIDE.md Normal file
@@ -0,0 +1,306 @@
# ML-Driven Strategy — Complete Guide

## Concept

The ML-Driven strategy replaces hard-coded rules with a supervised
learning model (XGBoost/LightGBM) that learns to recognize the
patterns used by human traders:

- Bounces off supports / resistances
- RSI / MACD divergences
- Bollinger squeeze + expansion
- EMA alignment (trend)
- Candlestick patterns (hammer, engulfing, shooting star...)
- Proximity to classic and Fibonacci pivots

The model learns **which combinations** of these signals are actually
predictive on historical data, much as an experienced trader would do
intuitively after years of practice.

---

## Architecture

```
src/
├── ml/
│   ├── features/
│   │   ├── technical_features.py   # Computes ~50 TA features
│   │   └── label_generator.py      # LONG/SHORT/NEUTRAL labels (forward simulation)
│   └── ml_strategy_model.py        # XGBoost training + persistence
└── strategies/
    └── ml_driven/
        └── ml_strategy.py          # StrategyEngine-compatible strategy
```
### Data pipeline

```
OHLCV data (2 years, 1h)
        ↓
TechnicalFeatureBuilder (~50 features per bar)
        ↓
LabelGenerator (forward simulation: TP/SL hit within N bars?)
        ↓
XGBoost (supervised training + walk-forward validation)
        ↓
MLStrategyModel saved to disk (models/ml_strategy/)
        ↓
MLDrivenStrategy.analyze() → Signal(LONG/SHORT) if confidence >= threshold
```
---

## Computed Features

### RSI
| Feature | Description |
|---|---|
| `rsi` | Raw RSI(14) value |
| `rsi_oversold` | RSI < 30 (oversold zone) |
| `rsi_overbought` | RSI > 70 (overbought zone) |
| `rsi_slope` | RSI slope over 3 bars |
| `rsi_bullish_div` | Bullish divergence (price LL, RSI HL) |
| `rsi_bearish_div` | Bearish divergence (price HH, RSI LH) |
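The divergence flags above are coarse proxies (price direction vs. RSI direction over a 5-bar window) rather than true swing-point divergence detection. A minimal pandas sketch of the idea; the function name and the simplified rolling-mean RSI are illustrative, not the project's exact implementation:

```python
import pandas as pd
import numpy as np

def rsi_features(close: pd.Series, period: int = 14, div_lookback: int = 5) -> pd.DataFrame:
    # Simplified rolling-mean RSI (the project may use Wilder smoothing instead)
    delta = close.diff()
    gain = delta.clip(lower=0).rolling(period).mean()
    loss = (-delta.clip(upper=0)).rolling(period).mean()
    rsi = 100 - 100 / (1 + gain / (loss + 1e-10))

    out = pd.DataFrame(index=close.index)
    out['rsi'] = rsi
    out['rsi_oversold'] = (rsi < 30).astype(int)
    out['rsi_overbought'] = (rsi > 70).astype(int)
    # Divergence proxies: price falling over N bars while RSI rises (bullish), and vice versa
    out['rsi_bullish_div'] = ((close.diff(div_lookback) < 0) & (rsi.diff(div_lookback) > 0)).astype(int)
    out['rsi_bearish_div'] = ((close.diff(div_lookback) > 0) & (rsi.diff(div_lookback) < 0)).astype(int)
    return out

close = pd.Series(np.linspace(100, 110, 60))  # steadily rising toy series
feats = rsi_features(close)
```

On a monotonically rising series the loss term stays at zero, so RSI saturates near 100 and the overbought flag fires while no bullish divergence can appear.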
### MACD
| Feature | Description |
|---|---|
| `macd` | MACD line (EMA12 - EMA26) |
| `macd_signal` | Signal line (EMA9 of MACD) |
| `macd_hist` | Histogram |
| `macd_hist_slope` | Histogram slope (momentum) |
| `macd_cross_up` | Bullish crossover |
| `macd_cross_down` | Bearish crossover |
### Bollinger Bands
| Feature | Description |
|---|---|
| `bb_position` | Relative position within the bands [0..1] |
| `bb_bandwidth` | Normalized width (volatility gauge) |
| `bb_squeeze` | Compression < 20th percentile (breakout setup) |
| `bb_break_up/down` | Band breakout |
| `bb_bounce_low/high` | Bounce off the lower/upper band |
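A sketch of how `bb_position`, `bb_bandwidth` and `bb_squeeze` can be derived; the helper name and the synthetic random-walk series are illustrative only:

```python
import pandas as pd
import numpy as np

def bollinger_features(close: pd.Series, period: int = 20, n_std: float = 2.0) -> pd.DataFrame:
    sma = close.rolling(period).mean()
    std = close.rolling(period).std()
    upper, lower = sma + n_std * std, sma - n_std * std
    bw = (upper - lower) / sma  # normalized bandwidth

    out = pd.DataFrame(index=close.index)
    # 0 = at the lower band, 1 = at the upper band (can exceed [0,1] on breakouts)
    out['bb_position'] = (close - lower) / (upper - lower + 1e-10)
    out['bb_bandwidth'] = bw
    # squeeze: bandwidth below its own rolling 20th percentile
    out['bb_squeeze'] = (bw < bw.rolling(50).quantile(0.2)).astype(int)
    return out

rng = np.random.default_rng(0)
close = pd.Series(100 + rng.normal(0, 1, 200).cumsum())  # synthetic random walk
feats = bollinger_features(close)
```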
### Supports / Resistances
| Feature | Description |
|---|---|
| `dist_to_resistance` | Distance in ATR to the nearest resistance |
| `dist_to_support` | Distance in ATR to the nearest support |
| `bounce_from_support` | Price within 1 ATR of a support |
| `rejection_at_resistance` | Price within 1 ATR of a resistance |
### Pivot Points
| Feature | Description |
|---|---|
| `dist_pivot` | Distance to the central pivot |
| `dist_r1/s1/r2/s2` | Distance to the classic levels |
| `dist_r1f/s1f` | Distance to the Fibonacci levels |
| `near_pivot/r1/s1` | Price within a 0.5 ATR zone |
### Candlesticks
| Feature | Description |
|---|---|
| `hammer` | Hammer (potential bounce) |
| `shooting_star` | Shooting star (potential rejection) |
| `bullish_engulfing` | Bullish engulfing |
| `bearish_engulfing` | Bearish engulfing |
| `doji` | Doji (indecision) |
| `body_ratio` | Body / total wick |
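Pattern flags like these reduce to wick/body geometry. A hedged sketch of hammer and doji detection; the exact thresholds (the 2× lower-wick rule, the 5% doji body) are assumptions, not the project's values:

```python
import pandas as pd

def candle_flags(df: pd.DataFrame) -> pd.DataFrame:
    body = (df['close'] - df['open']).abs()
    rng = (df['high'] - df['low']).replace(0, 1e-10)  # avoid division by zero
    lower_wick = df[['open', 'close']].min(axis=1) - df['low']
    upper_wick = df['high'] - df[['open', 'close']].max(axis=1)

    out = pd.DataFrame(index=df.index)
    # hammer: long lower wick (> 2x body) with a small upper wick
    out['hammer'] = ((lower_wick > 2 * body) & (upper_wick < body)).astype(int)
    # doji: body is a tiny fraction of the full range (5% threshold assumed)
    out['doji'] = (body < 0.05 * rng).astype(int)
    out['body_ratio'] = body / rng
    return out

# one bar with a long lower wick: opens 10.0, dips to 9.0, closes 10.06
df = pd.DataFrame({'open': [10.0], 'high': [10.1], 'low': [9.0], 'close': [10.06]})
flags = candle_flags(df)
```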
### Trend / EMA
| Feature | Description |
|---|---|
| `trend_bull` | EMA8 > EMA21 > EMA50 |
| `trend_bear` | EMA8 < EMA21 < EMA50 |
| `ema21_slope` | EMA21 slope over 5 bars |
| `above_ema200` | Price above the EMA200 |
| `ema_X_dist` | % distance from price to each EMA |
### Temporal
| Feature | Description |
|---|---|
| `is_london` | London session (08:00-16:00 UTC) |
| `is_ny` | New York session (13:00-21:00 UTC) |
| `is_overlap` | London+NY overlap (13:00-16:00 UTC) |
| `day_of_week` | Day of the week |
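Assuming a UTC DatetimeIndex, the session flags reduce to hour-of-day masks; `session_flags` is an illustrative helper, not part of the project:

```python
import pandas as pd

def session_flags(index: pd.DatetimeIndex) -> pd.DataFrame:
    hour = index.hour  # assumes the index is already in UTC
    out = pd.DataFrame(index=index)
    out['is_london'] = ((hour >= 8) & (hour < 16)).astype(int)
    out['is_ny'] = ((hour >= 13) & (hour < 21)).astype(int)
    out['is_overlap'] = ((hour >= 13) & (hour < 16)).astype(int)
    out['day_of_week'] = index.dayofweek  # Monday = 0
    return out

# one full day of hourly bars starting on a Monday
idx = pd.date_range('2026-03-02 00:00', periods=24, freq='h', tz='UTC')
flags = session_flags(idx)
```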
---

## Label Generation

### ATR-based method (recommended)
For each bar i, a trade is simulated over the next `horizon` bars:

- **LONG (1)**: the HIGH reaches `entry + tp_atr_mult × ATR` before the LOW drops below `entry - sl_atr_mult × ATR`
- **SHORT (-1)**: the LOW reaches `entry - tp_atr_mult × ATR` before the HIGH rises above `entry + sl_atr_mult × ATR`
- **NEUTRAL (0)**: neither TP nor SL is reached within the horizon

Default parameters: `tp_atr_mult=2.0`, `sl_atr_mult=1.0` → R:R = 2:1

### Fixed-percentage method
Same logic with % thresholds: `tp_pct=0.003` (0.3%), `sl_pct=0.002` (0.2%).
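The percentage-based rule can be shown on a toy LONG-only case; `long_label` is a stand-alone illustrative helper (it resolves an ambiguous bar conservatively by checking the stop first), not the project's LabelGenerator:

```python
def long_label(future_high, future_low, entry, tp_pct=0.003, sl_pct=0.002):
    """1 if TP (+tp_pct) is hit before SL (-sl_pct) within the window, else 0."""
    tp, sl = entry * (1 + tp_pct), entry * (1 - sl_pct)
    for hi, lo in zip(future_high, future_low):
        if lo <= sl:
            return 0  # stop hit first (ambiguous same-bar case counted as neutral)
        if hi >= tp:
            return 1
    return 0  # TP never reached within the horizon

# price grinds up through the TP (100.3) without touching the SL (99.8)
highs = [100.1, 100.2, 100.4]
lows = [99.9, 100.0, 100.1]
label = long_label(highs, lows, entry=100.0)  # -> 1
```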
---

## Validation

Walk-forward cross-validation (3 temporal folds):

- Fold 1: trained on 0..33%, tested on 33..67%
- Fold 2: trained on 0..67%, tested on 67..100%
- etc.

Returned metrics:

- `wf_accuracy`: mean accuracy across folds
- `wf_precision`: precision on directional signals
- `wf_recall`: recall on directional signals
- `label_dist`: LONG/SHORT/NEUTRAL distribution
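An expanding-window splitter can be sketched as follows; the exact fold boundaries used by the project are an assumption (quartile boundaries shown here for 3 folds):

```python
def walk_forward_splits(n: int, n_folds: int = 3):
    """Expanding-window splits: fold k trains on [0, b_k) and tests on [b_k, b_k+1)."""
    bounds = [int(n * (k + 1) / (n_folds + 1)) for k in range(n_folds + 1)]
    return [(range(0, bounds[k]), range(bounds[k], bounds[k + 1])) for k in range(n_folds)]

splits = walk_forward_splits(100, n_folds=3)
# 3 folds over 100 bars: train [0,25) test [25,50); train [0,50) test [50,75); train [0,75) test [75,100)
```

Training always precedes testing in time, which is what prevents look-ahead leakage compared with a shuffled k-fold split.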
---

## API Routes

### `POST /trading/train`
Starts training in the background.

**Body:**
```json
{
  "symbol": "EURUSD",
  "timeframe": "1h",
  "period": "2y",
  "model_type": "xgboost",
  "tp_atr_mult": 2.0,
  "sl_atr_mult": 1.0,
  "horizon": 30,
  "min_confidence": 0.55
}
```

**Response:**
```json
{
  "job_id": "uuid",
  "status": "pending",
  "symbol": "EURUSD",
  "timeframe": "1h"
}
```

### `GET /trading/train/{job_id}`
Checks the status of a training job.

**Response (completed):**
```json
{
  "job_id": "...",
  "status": "completed",
  "n_samples": 8760,
  "n_features": 52,
  "wf_accuracy": 0.58,
  "wf_precision": 0.61,
  "label_dist": {"long": 1200, "short": 1150, "neutral": 6410},
  "trained_at": "2026-03-08T12:00:00"
}
```

### `GET /trading/ml-models`
Lists all trained models.

### `GET /trading/ml-models/{symbol}/{timeframe}/importance`
Top most important features of the model.
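A typical client polls the job endpoint until a terminal status. The helper below separates the polling loop from the HTTP call so it can be exercised without a server; in practice `fetch` would wrap a GET on `/trading/train/{job_id}` (base URL assumed):

```python
import time
from typing import Callable

def wait_for_job(fetch: Callable[[], dict], timeout_s: float = 600, poll_s: float = 2.0) -> dict:
    """Poll a training job until it reaches a terminal status ('completed' or 'failed')."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        job = fetch()
        if job.get('status') in ('completed', 'failed'):
            return job
        time.sleep(poll_s)
    raise TimeoutError('training job did not finish in time')

# real usage would be e.g.: fetch = lambda: requests.get(f"{base}/trading/train/{job_id}").json()
states = iter([{'status': 'pending'}, {'status': 'running'},
               {'status': 'completed', 'wf_accuracy': 0.58}])
result = wait_for_job(lambda: next(states), poll_s=0)
```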
---

## Model Persistence

Models are saved under `models/ml_strategy/`:
```
models/ml_strategy/
├── EURUSD_1h_xgboost.joblib       # Model + scaler + feature names
└── EURUSD_1h_xgboost_meta.json    # Metrics and configuration
```

On restart, `MLDrivenStrategy` automatically tries to load the existing
model for the configured symbol/timeframe (`auto_load=True`).
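The save layout above can be sketched as follows; the project uses `joblib.dump`, while this stand-alone sketch substitutes stdlib `pickle` and a temporary directory to stay self-contained:

```python
import json
import pickle
import tempfile
from pathlib import Path

def save_model(bundle: dict, meta: dict, root: Path, symbol: str, tf: str, mtype: str) -> Path:
    """Persist a model bundle + metadata using the <SYMBOL>_<TF>_<TYPE> naming scheme."""
    root.mkdir(parents=True, exist_ok=True)
    stem = f"{symbol}_{tf}_{mtype}"
    path = root / f"{stem}.joblib"
    with open(path, 'wb') as f:  # the project uses joblib.dump here
        pickle.dump(bundle, f)
    (root / f"{stem}_meta.json").write_text(json.dumps(meta))
    return path

root = Path(tempfile.mkdtemp()) / "ml_strategy"
path = save_model({'features': ['rsi', 'macd']}, {'wf_accuracy': 0.58},
                  root, 'EURUSD', '1h', 'xgboost')
with open(path, 'rb') as f:
    loaded = pickle.load(f)
```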
---

## Manual Usage (Python)

```python
from src.ml.ml_strategy_model import MLStrategyModel

# Training
model = MLStrategyModel(symbol='EURUSD', timeframe='1h', model_type='xgboost')
metrics = model.train(df_ohlcv)
print(f"WF Accuracy: {metrics['wf_metrics']['avg_accuracy']:.2%}")

# Prediction
result = model.predict(df_recent)
print(f"Signal: {result['signal']}, Confidence: {result['confidence']:.2%}")
# → {'signal': 1, 'confidence': 0.72, 'tradeable': True, 'probas': {...}}

# Loading from disk
model = MLStrategyModel.load('EURUSD', '1h', 'xgboost')

# Feature importance
for f in model.get_feature_importance(top_n=10):
    print(f"  {f['feature']}: {f['importance']:.4f}")
```
---

## Strategy Integration

```python
from src.strategies.ml_driven import MLDrivenStrategy

config = {
    'name': 'ml_driven',
    'symbol': 'EURUSD',
    'timeframe': '1h',
    'risk_per_trade': 0.01,
    'model_type': 'xgboost',
    'min_confidence': 0.55,
    'tp_atr_mult': 2.0,
    'sl_atr_mult': 1.0,
    'auto_load': True,  # Load automatically if a model already exists
}
strategy = MLDrivenStrategy(config)

# Signal generation
signal = strategy.analyze(df_ohlcv)
if signal:
    print(f"{signal.direction} @ {signal.entry_price}, conf={signal.confidence:.2%}")
```
---

## Considerations

### Overfitting
- Walk-forward validation (3 temporal folds) guards against in-sample overfitting
- Use at least 1 year of data (≥ 8,000 bars at 1h)
- Avoid overly short periods (`period < 6m`)

### Class imbalance
- Typically: ~15% LONG, ~15% SHORT, ~70% NEUTRAL
- XGBoost handles this degree of imbalance natively
- The `min_confidence` threshold filters out low-certainty signals
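The `min_confidence` filter can be expressed directly on the model's class probabilities; this sketch assumes a column ordering of (SHORT, NEUTRAL, LONG) in the probability matrix:

```python
import numpy as np

def filter_signals(probas: np.ndarray, classes=(-1, 0, 1), min_confidence: float = 0.55):
    """Keep a directional signal only when its class probability clears the threshold."""
    best = probas.argmax(axis=1)
    conf = probas.max(axis=1)
    signals = np.array([classes[i] for i in best])
    # low-confidence or NEUTRAL predictions become "no trade"
    signals[(conf < min_confidence) | (signals == 0)] = 0
    return signals, conf

probas = np.array([
    [0.70, 0.20, 0.10],  # confident SHORT -> kept
    [0.30, 0.40, 0.30],  # NEUTRAL wins -> no trade
    [0.10, 0.38, 0.52],  # LONG but below 0.55 -> filtered out
])
signals, conf = filter_signals(probas)
```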
### Re-training
- Recommended every 3-6 months to capture regime changes
- Or after a significant volatility shift (crisis, FOMC...)

### Target metrics
- `wf_accuracy > 0.55` (better than chance)
- `wf_precision > 0.50` on directional signals
- Reasonably balanced LONG/SHORT distribution

---

## History

| Date | Version | Description |
|---|---|---|
| 2026-03-08 | v1.0 | Initial release — XGBoost/LightGBM + classic TA features |
@@ -733,3 +733,187 @@ def get_optimize_result(job_id: str):
    if job is None:
        raise HTTPException(404, detail=f"Job {job_id} not found")
    return OptimizeResponse(job_id=job_id, **job)


# =============================================================================
# ML STRATEGY — Training and management of ML-Driven models
# =============================================================================

# In-memory storage of training jobs (same pattern as backtest/optimize)
_train_jobs: Dict[str, dict] = {}


class TrainRequest(BaseModel):
    """Training request for the ML-Driven model."""
    symbol: str = "EURUSD"
    timeframe: str = "1h"
    period: str = "2y"              # Historical data period
    model_type: str = "xgboost"     # xgboost | lightgbm | random_forest
    tp_atr_mult: float = 2.0        # ATR multiplier for TP (labels)
    sl_atr_mult: float = 1.0        # ATR multiplier for SL (labels)
    horizon: int = 30               # Bars allowed to hit TP/SL (labels)
    min_confidence: float = 0.55    # Minimum confidence threshold for signals


class TrainResponse(BaseModel):
    """Response for a training job."""
    job_id: str
    status: str                     # pending | running | completed | failed
    symbol: Optional[str] = None
    timeframe: Optional[str] = None
    model_type: Optional[str] = None
    n_samples: Optional[int] = None
    n_features: Optional[int] = None
    wf_accuracy: Optional[float] = None   # Mean walk-forward accuracy
    wf_precision: Optional[float] = None
    label_dist: Optional[dict] = None     # LONG/SHORT/NEUTRAL distribution
    trained_at: Optional[str] = None
    error: Optional[str] = None


async def _run_train_task(job_id: str, request: TrainRequest) -> None:
    """Training task executed in the background."""
    _train_jobs[job_id]["status"] = "running"
    try:
        # Fetch historical data
        data_service = _get_data_service()
        df = await data_service.get_historical_data(
            symbol=request.symbol,
            timeframe=request.timeframe,
            period=request.period,
        )
        if df is None or len(df) < 200:
            raise ValueError(f"Insufficient data: {len(df) if df is not None else 0} bars (min 200)")

        # Train in a thread (CPU-bound operation)
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, _sync_train, df, request)

        _train_jobs[job_id].update({
            "status": "completed",
            "symbol": request.symbol,
            "timeframe": request.timeframe,
            "model_type": request.model_type,
            "n_samples": result.get("n_samples"),
            "n_features": result.get("n_features"),
            "wf_accuracy": result.get("wf_metrics", {}).get("avg_accuracy"),
            "wf_precision": result.get("wf_metrics", {}).get("avg_precision"),
            "label_dist": result.get("label_dist"),
            "trained_at": result.get("trained_at"),
        })

        # Auto-attach to the active ML strategy if one exists
        _attach_model_to_strategy(request)

    except Exception as exc:
        logger.error(f"ML training error, job {job_id}: {exc}", exc_info=True)
        _train_jobs[job_id]["status"] = "failed"
        _train_jobs[job_id]["error"] = str(exc)


def _sync_train(df, request: TrainRequest) -> dict:
    """Synchronous wrapper around MLStrategyModel.train() (run in a thread)."""
    from src.ml.ml_strategy_model import MLStrategyModel
    model = MLStrategyModel(
        symbol=request.symbol,
        timeframe=request.timeframe,
        model_type=request.model_type,
        tp_atr_mult=request.tp_atr_mult,
        sl_atr_mult=request.sl_atr_mult,
        horizon=request.horizon,
        min_confidence=request.min_confidence,
    )
    return model.train(df)


def _attach_model_to_strategy(request: TrainRequest) -> None:
    """Attaches the trained model to the active ML-driven strategy (paper trading)."""
    try:
        from src.ml.ml_strategy_model import MLStrategyModel
        from src.strategies.ml_driven import MLDrivenStrategy

        engine = _paper_state.get("engine")
        if engine and hasattr(engine, 'strategy_engine'):
            strat = engine.strategy_engine.strategies.get('ml_driven')
            if strat and isinstance(strat, MLDrivenStrategy):
                model = MLStrategyModel.load(request.symbol, request.timeframe, request.model_type)
                strat.attach_model(model)
                logger.info("ML model attached to the active ml_driven strategy")
    except Exception as e:
        logger.debug(f"Model auto-attach skipped: {e}")


@router.post("/train", response_model=TrainResponse, summary="Train the ML-Driven model")
async def start_train(request: TrainRequest, background_tasks: BackgroundTasks):
    """
    Starts ML-Driven model training in the background.

    The model learns to detect trading opportunities from classic
    technical indicators (RSI, MACD, S/R, pivots, candlesticks...).

    - Returns a `job_id` to poll via `GET /trading/train/{job_id}`
    - The model is saved to disk after training
    - If an ML-Driven paper trading session is active, the model is attached to it automatically
    """
    if request.model_type not in ("xgboost", "lightgbm", "random_forest"):
        raise HTTPException(400, detail="model_type must be: xgboost | lightgbm | random_forest")

    job_id = str(uuid.uuid4())
    _train_jobs[job_id] = {
        "status": "pending",
        "symbol": request.symbol,
        "timeframe": request.timeframe,
        "model_type": request.model_type,
    }

    background_tasks.add_task(_run_train_task, job_id, request)

    return TrainResponse(
        job_id=job_id,
        status="pending",
        symbol=request.symbol,
        timeframe=request.timeframe,
        model_type=request.model_type,
    )


@router.get("/train/{job_id}", response_model=TrainResponse, summary="ML training result")
def get_train_result(job_id: str):
    """Returns the status of an ML-Driven training job."""
    job = _train_jobs.get(job_id)
    if job is None:
        raise HTTPException(404, detail=f"Job {job_id} not found")
    return TrainResponse(job_id=job_id, **job)


@router.get("/ml-models", summary="List of trained ML-Driven models")
def list_ml_models():
    """
    Returns the list of all ML-Driven models available on disk,
    with their metrics (accuracy, training date, number of samples...).
    """
    from src.ml.ml_strategy_model import MLStrategyModel
    models = MLStrategyModel.list_trained_models()
    return {"models": models, "count": len(models)}


@router.get("/ml-models/{symbol}/{timeframe}/importance", summary="ML feature importance")
def get_feature_importance(symbol: str, timeframe: str, model_type: str = "xgboost", top_n: int = 20):
    """
    Returns the N most important features of the ML-Driven model.
    Useful for understanding which indicators the model relies on most.
    """
    from src.ml.ml_strategy_model import MLStrategyModel
    try:
        model = MLStrategyModel.load(symbol, timeframe, model_type)
        importance = model.get_feature_importance(top_n=top_n)
        return {
            "symbol": symbol,
            "timeframe": timeframe,
            "model_type": model_type,
            "importance": importance,
        }
    except FileNotFoundError:
        raise HTTPException(404, detail=f"Model not found for {symbol}/{timeframe}/{model_type}")
    except Exception as e:
        raise HTTPException(500, detail=str(e))
0 src/ml/features/__init__.py Normal file
180 src/ml/features/label_generator.py Normal file
@@ -0,0 +1,180 @@
"""
|
||||
Générateur de Labels — Supervision du modèle ML stratégie.
|
||||
|
||||
Pour chaque barre, on regarde ce qui se passe dans les N barres suivantes :
|
||||
- LONG (1) : le prix monte de `tp_pct` avant de baisser de `sl_pct`
|
||||
- SHORT (-1) : le prix baisse de `sl_pct` avant de monter de `tp_pct`
|
||||
- NEUTRAL (0) : aucune des deux conditions n'est atteinte
|
||||
|
||||
Cette approche "forward simulation" reproduit la logique réelle d'un trade :
|
||||
chaque label correspond à un trade qui aurait été rentable (TP atteint avant SL).
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Tuple
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LabelGenerator:
|
||||
"""
|
||||
Génère des labels de classification pour l'entraînement supervisé.
|
||||
|
||||
Args:
|
||||
tp_pct: Mouvement haussier minimal pour valider un LONG (ex: 0.003 = 0.3%)
|
||||
sl_pct: Mouvement baissier maximal avant stop-loss (ex: 0.002 = 0.2%)
|
||||
horizon: Nombre de barres max pour atteindre TP ou SL
|
||||
min_ratio: Ratio TP/SL minimum (ignore les configs déséquilibrées)
|
||||
|
||||
Exemple avec ATR dynamique:
|
||||
gen = LabelGenerator(tp_pct=None, sl_pct=None)
|
||||
labels = gen.generate_atr(df, atr_tp_mult=2.0, atr_sl_mult=1.0)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tp_pct: float = 0.003, # 0.3% TP par défaut
|
||||
sl_pct: float = 0.002, # 0.2% SL par défaut → R:R = 1.5
|
||||
horizon: int = 30, # Fenêtre de 30 barres
|
||||
min_ratio: float = 1.0, # TP doit être >= SL
|
||||
):
|
||||
self.tp_pct = tp_pct
|
||||
self.sl_pct = sl_pct
|
||||
self.horizon = horizon
|
||||
self.min_ratio = min_ratio
|
||||
|
||||
def generate(self, df: pd.DataFrame) -> pd.Series:
|
||||
"""
|
||||
Génère les labels à partir de seuils fixes en pourcentage.
|
||||
|
||||
Args:
|
||||
df: DataFrame OHLCV (colonnes en minuscules)
|
||||
|
||||
Returns:
|
||||
pd.Series avec valeurs 1 (LONG), -1 (SHORT), 0 (NEUTRAL)
|
||||
"""
|
||||
df = df.copy()
|
||||
df.columns = [c.lower() for c in df.columns]
|
||||
labels = self._forward_simulate(df, self.tp_pct, self.sl_pct)
|
||||
self._log_distribution(labels)
|
||||
return labels
|
||||
|
||||
def generate_atr_based(
|
||||
self,
|
||||
df: pd.DataFrame,
|
||||
atr_period: int = 14,
|
||||
atr_tp_mult: float = 2.0,
|
||||
atr_sl_mult: float = 1.0,
|
||||
) -> pd.Series:
|
||||
"""
|
||||
Génère les labels avec des seuils TP/SL dynamiques basés sur l'ATR.
|
||||
Plus adapté aux marchés forex où la volatilité varie beaucoup.
|
||||
|
||||
Args:
|
||||
df: DataFrame OHLCV
|
||||
atr_period: Période ATR
|
||||
atr_tp_mult: Multiplicateur ATR pour TP (ex: 2.0 = 2×ATR)
|
||||
atr_sl_mult: Multiplicateur ATR pour SL (ex: 1.0 = 1×ATR)
|
||||
|
||||
Returns:
|
||||
pd.Series labels 1 / -1 / 0
|
||||
"""
|
||||
df = df.copy()
|
||||
df.columns = [c.lower() for c in df.columns]
|
||||
|
||||
# Calcul ATR
|
||||
h, l, pc = df['high'], df['low'], df['close'].shift(1)
|
||||
tr = pd.concat([h - l, (h - pc).abs(), (l - pc).abs()], axis=1).max(axis=1)
|
||||
atr = tr.rolling(atr_period).mean()
|
||||
|
||||
labels = np.zeros(len(df), dtype=int)
|
||||
|
||||
for i in range(len(df) - self.horizon):
|
||||
close = df['close'].iloc[i]
|
||||
atr_i = atr.iloc[i]
|
||||
if np.isnan(atr_i) or atr_i <= 0:
|
||||
continue
|
||||
|
||||
tp_long = close + atr_tp_mult * atr_i
|
||||
sl_long = close - atr_sl_mult * atr_i
|
||||
tp_short = close - atr_tp_mult * atr_i
|
||||
sl_short = close + atr_sl_mult * atr_i
|
||||
|
||||
future = df.iloc[i + 1: i + 1 + self.horizon]
|
||||
label = self._classify_bar(future, tp_long, sl_long, tp_short, sl_short)
|
||||
labels[i] = label
|
||||
|
||||
result = pd.Series(labels, index=df.index)
|
||||
self._log_distribution(result)
|
||||
return result
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Internals
|
||||
# -------------------------------------------------------------------------
|
||||
def _forward_simulate(
|
||||
self, df: pd.DataFrame, tp_pct: float, sl_pct: float
|
||||
) -> pd.Series:
|
||||
"""Simule chaque barre : TP ou SL atteint en premier dans l'horizon."""
|
||||
labels = np.zeros(len(df), dtype=int)
|
||||
|
||||
for i in range(len(df) - self.horizon):
|
||||
close = df['close'].iloc[i]
|
||||
tp_long = close * (1 + tp_pct)
|
||||
sl_long = close * (1 - sl_pct)
|
||||
tp_short = close * (1 - tp_pct)
|
||||
sl_short = close * (1 + sl_pct)
|
||||
|
||||
future = df.iloc[i + 1: i + 1 + self.horizon]
|
||||
labels[i] = self._classify_bar(future, tp_long, sl_long, tp_short, sl_short)
|
||||
|
||||
return pd.Series(labels, index=df.index)
|
||||
|
||||
@staticmethod
|
||||
def _classify_bar(
|
||||
future: pd.DataFrame,
|
||||
tp_long: float,
|
||||
sl_long: float,
|
||||
tp_short: float,
|
||||
sl_short: float,
|
||||
) -> int:
|
||||
"""
|
||||
Parcourt les barres futures bar par bar et retourne le label.
|
||||
Vérifie HIGH pour TP LONG et LOW pour SL LONG (et inversement pour SHORT).
|
||||
"""
|
||||
for _, bar in future.iterrows():
|
||||
# LONG : TP atteint ?
|
||||
if bar['high'] >= tp_long and bar['low'] > sl_long:
|
||||
return 1
|
||||
# LONG : SL atteint en premier ?
|
||||
if bar['low'] <= sl_long:
|
||||
# Vérifie si TP atteint le même bar (candle ambiguë)
|
||||
if bar['high'] >= tp_long:
|
||||
return 0 # Ambigu → neutre
|
||||
return 0 # SL touché → pas de LONG
|
||||
|
||||
# SHORT : TP atteint ?
|
||||
if bar['low'] <= tp_short and bar['high'] < sl_short:
|
||||
return -1
|
||||
# SHORT : SL atteint en premier ?
|
||||
if bar['high'] >= sl_short:
|
||||
if bar['low'] <= tp_short:
|
||||
return 0
|
||||
return 0
|
||||
|
||||
return 0 # Ni TP ni SL atteint dans l'horizon
|
||||
|
||||
@staticmethod
|
||||
def _log_distribution(labels: pd.Series) -> None:
|
||||
total = len(labels)
|
||||
if total == 0:
|
||||
return
|
||||
n_long = (labels == 1).sum()
|
||||
n_short = (labels == -1).sum()
|
||||
n_neutral = (labels == 0).sum()
|
||||
logger.info(
|
||||
f"Distribution labels : LONG={n_long} ({n_long/total:.1%}), "
|
||||
f"SHORT={n_short} ({n_short/total:.1%}), "
|
||||
f"NEUTRAL={n_neutral} ({n_neutral/total:.1%})"
|
||||
)
|
||||
373 src/ml/features/technical_features.py Normal file
@@ -0,0 +1,373 @@
"""
|
||||
Features Techniques — Indicateurs utilisés par les traders humains.
|
||||
|
||||
Ce module calcule les indicateurs classiques du trading technique :
|
||||
- RSI (zones, divergences)
|
||||
- MACD (crossovers, histogramme)
|
||||
- Bollinger Bands (squeeze, position relative)
|
||||
- Supports / Résistances (niveaux pivots récents)
|
||||
- Points Pivots (classiques + Fibonacci)
|
||||
- ATR (volatilité normalisée)
|
||||
- Volume (ratio, pics, tendance)
|
||||
- Alignement des EMAs (tendance)
|
||||
- Patterns de chandeliers (marteau, étoile filante, engulfing...)
|
||||
|
||||
Ces features alimentent le MLStrategyModel (XGBoost/LightGBM).
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TechnicalFeatureBuilder:
|
||||
"""
|
||||
Construit un DataFrame de features techniques à partir de données OHLCV.
|
||||
|
||||
Usage:
|
||||
builder = TechnicalFeatureBuilder()
|
||||
features = builder.build(df_ohlcv)
|
||||
# features est un DataFrame aligné sur df_ohlcv, prêt pour XGBoost
|
||||
"""
|
||||
|
||||
# Périodes par défaut
|
||||
RSI_PERIOD = 14
|
||||
MACD_FAST = 12
|
||||
MACD_SLOW = 26
|
||||
MACD_SIGNAL = 9
|
||||
BB_PERIOD = 20
|
||||
BB_STD = 2.0
|
||||
ATR_PERIOD = 14
|
||||
EMA_PERIODS = [8, 21, 50, 200]
|
||||
SR_LOOKBACK = 50 # Barres pour détecter S/R
|
||||
SR_TOLERANCE = 0.001 # 0.1% de tolérance pour regrouper les niveaux
|
||||
PIVOT_LOOKBACK = 5 # Barres pour pivot high/low local
|
||||
|
||||
def build(self, data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Calcule toutes les features techniques.
|
||||
|
||||
Args:
|
||||
data: DataFrame OHLCV avec colonnes open/high/low/close/volume
|
||||
|
||||
Returns:
|
||||
DataFrame de features (même index que data, NaN supprimés)
|
||||
"""
|
||||
df = data.copy()
|
||||
df.columns = [c.lower() for c in df.columns]
|
||||
|
||||
features = pd.DataFrame(index=df.index)
|
||||
|
||||
# --- Indicateurs de base ---
|
||||
features = self._add_rsi(features, df)
|
||||
features = self._add_macd(features, df)
|
||||
features = self._add_bollinger(features, df)
|
||||
features = self._add_atr(features, df)
|
||||
features = self._add_ema_alignment(features, df)
|
||||
features = self._add_volume(features, df)
|
||||
|
||||
# --- Niveaux de prix ---
|
||||
features = self._add_pivot_points(features, df)
|
||||
features = self._add_support_resistance(features, df)
|
||||
|
||||
# --- Patterns de chandeliers ---
|
||||
features = self._add_candlestick_patterns(features, df)
|
||||
|
||||
# --- Features temporelles ---
|
||||
features = self._add_temporal(features, df)
|
||||
|
||||
# Supprimer les lignes avec trop de NaN (début de série)
|
||||
features = features.dropna(thresh=int(len(features.columns) * 0.7))
|
||||
|
||||
logger.info(f"Features construites : {len(features.columns)} colonnes, {len(features)} lignes")
|
||||
return features
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# RSI
|
||||
# -------------------------------------------------------------------------
|
||||
def _add_rsi(self, features: pd.DataFrame, df: pd.DataFrame) -> pd.DataFrame:
|
||||
rsi = self._rsi(df['close'], self.RSI_PERIOD)
|
||||
features['rsi'] = rsi
|
||||
features['rsi_oversold'] = (rsi < 30).astype(int)
|
||||
features['rsi_overbought'] = (rsi > 70).astype(int)
|
||||
features['rsi_neutral'] = ((rsi >= 40) & (rsi <= 60)).astype(int)
|
||||
# Pente RSI sur 3 barres
|
||||
features['rsi_slope'] = rsi.diff(3) / 3
|
||||
# Divergence haussière : prix fait lower low mais RSI fait higher low
|
||||
price_ll = df['close'].diff(5) < 0
|
||||
rsi_hl = rsi.diff(5) > 0
|
||||
features['rsi_bullish_div'] = (price_ll & rsi_hl).astype(int)
|
||||
# Divergence baissière
|
||||
price_hh = df['close'].diff(5) > 0
|
||||
rsi_lh = rsi.diff(5) < 0
|
||||
features['rsi_bearish_div'] = (price_hh & rsi_lh).astype(int)
|
||||
return features
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# MACD
|
||||
# -------------------------------------------------------------------------
|
||||
def _add_macd(self, features: pd.DataFrame, df: pd.DataFrame) -> pd.DataFrame:
|
||||
ema_fast = df['close'].ewm(span=self.MACD_FAST, adjust=False).mean()
|
||||
ema_slow = df['close'].ewm(span=self.MACD_SLOW, adjust=False).mean()
|
||||
macd_line = ema_fast - ema_slow
|
||||
signal = macd_line.ewm(span=self.MACD_SIGNAL, adjust=False).mean()
|
||||
histogram = macd_line - signal
|
||||
|
||||
features['macd'] = macd_line
|
||||
features['macd_signal'] = signal
|
||||
features['macd_hist'] = histogram
|
||||
features['macd_hist_slope'] = histogram.diff(2)
|
||||
# Crossover haussier : MACD passe au-dessus du signal
|
||||
features['macd_cross_up'] = ((macd_line > signal) & (macd_line.shift(1) <= signal.shift(1))).astype(int)
|
||||
# Crossover baissier
|
||||
features['macd_cross_down'] = ((macd_line < signal) & (macd_line.shift(1) >= signal.shift(1))).astype(int)
|
||||
# MACD au-dessus / en-dessous de zéro
|
||||
features['macd_above_zero'] = (macd_line > 0).astype(int)
|
||||
return features
|
||||
|
||||
    # -------------------------------------------------------------------------
    # Bollinger Bands
    # -------------------------------------------------------------------------
    def _add_bollinger(self, features: pd.DataFrame, df: pd.DataFrame) -> pd.DataFrame:
        sma = df['close'].rolling(self.BB_PERIOD).mean()
        std = df['close'].rolling(self.BB_PERIOD).std()
        upper = sma + self.BB_STD * std
        lower = sma - self.BB_STD * std
        bw = (upper - lower) / sma  # Normalized bandwidth

        # Relative position within the bands [0..1]
        features['bb_position'] = (df['close'] - lower) / (upper - lower + 1e-10)
        features['bb_bandwidth'] = bw
        # Squeeze: bandwidth below the 20th percentile of the last 50 bars
        features['bb_squeeze'] = (bw < bw.rolling(50).quantile(0.2)).astype(int)
        # Bullish / bearish breakout
        features['bb_break_up'] = (df['close'] > upper).astype(int)
        features['bb_break_down'] = (df['close'] < lower).astype(int)
        # Bounce off the lower band
        features['bb_bounce_low'] = ((df['close'].shift(1) < lower.shift(1)) & (df['close'] > lower)).astype(int)
        features['bb_bounce_high'] = ((df['close'].shift(1) > upper.shift(1)) & (df['close'] < upper)).astype(int)
        return features

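The `bb_position` feature maps the close into band-relative coordinates (0 at the lower band, 1 at the upper band). A standalone sketch with a hand-checkable value (the `bb_position` function name is illustrative):

```python
import pandas as pd

def bb_position(close: pd.Series, period: int = 20, n_std: float = 2.0) -> pd.Series:
    """Position of the close inside the Bollinger Bands: 0 = lower band, 1 = upper band."""
    sma = close.rolling(period).mean()
    std = close.rolling(period).std()
    upper = sma + n_std * std
    lower = sma - n_std * std
    return (close - lower) / (upper - lower + 1e-10)

# Steady uptrend 1..30: on the last bar, SMA(20) = 20.5 and sample std = sqrt(35),
# so position = (30 - (20.5 - 2*sqrt(35))) / (4*sqrt(35)) ~= 0.9014
close = pd.Series(range(1, 31), dtype=float)
pos = bb_position(close)
```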
    # -------------------------------------------------------------------------
    # ATR — Normalized volatility
    # -------------------------------------------------------------------------
    def _add_atr(self, features: pd.DataFrame, df: pd.DataFrame) -> pd.DataFrame:
        atr = self._atr(df, self.ATR_PERIOD)
        features['atr'] = atr
        features['atr_pct'] = atr / df['close']  # ATR as % of price
        features['atr_ratio'] = atr / atr.rolling(50).mean()  # Ratio vs its own average
        features['volatility_high'] = (features['atr_ratio'] > 1.5).astype(int)
        features['volatility_low'] = (features['atr_ratio'] < 0.7).astype(int)
        return features

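The `_atr` helper used here (defined further down with the other helpers) is a plain rolling mean of the true range. A minimal standalone sketch with a hand-checkable result:

```python
import pandas as pd

def atr(df: pd.DataFrame, period: int) -> pd.Series:
    """Average True Range: rolling mean of max(H-L, |H-prevC|, |L-prevC|)."""
    h, l, pc = df['high'], df['low'], df['close'].shift(1)
    tr = pd.concat([h - l, (h - pc).abs(), (l - pc).abs()], axis=1).max(axis=1)
    return tr.rolling(period).mean()

df = pd.DataFrame({
    'high':  [10.0, 12.0, 11.0, 13.0],
    'low':   [ 8.0,  9.0, 10.0, 11.0],
    'close': [ 9.0, 11.0, 10.5, 12.0],
})
# True ranges: [2.0, 3.0, 1.0, 2.5] -> ATR(3) on the last bar = (3.0 + 1.0 + 2.5) / 3
value = atr(df, 3).iloc[-1]
```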
    # -------------------------------------------------------------------------
    # EMA alignment — Trend
    # -------------------------------------------------------------------------
    def _add_ema_alignment(self, features: pd.DataFrame, df: pd.DataFrame) -> pd.DataFrame:
        emas = {}
        for p in self.EMA_PERIODS:
            emas[p] = df['close'].ewm(span=p, adjust=False).mean()
            features[f'ema_{p}_dist'] = (df['close'] - emas[p]) / emas[p]  # % distance from price

        # Bullish alignment: EMA8 > EMA21 > EMA50
        if 8 in emas and 21 in emas and 50 in emas:
            features['trend_bull'] = ((emas[8] > emas[21]) & (emas[21] > emas[50])).astype(int)
            features['trend_bear'] = ((emas[8] < emas[21]) & (emas[21] < emas[50])).astype(int)

        # EMA21 slope over 5 bars
        if 21 in emas:
            features['ema21_slope'] = emas[21].diff(5) / emas[21].shift(5)

        # Price above / below EMA200
        if 200 in emas:
            features['above_ema200'] = (df['close'] > emas[200]).astype(int)

        return features

    # -------------------------------------------------------------------------
    # Volume
    # -------------------------------------------------------------------------
    def _add_volume(self, features: pd.DataFrame, df: pd.DataFrame) -> pd.DataFrame:
        if 'volume' not in df.columns or df['volume'].sum() == 0:
            # Forex: no real volume — use 1.0 as a neutral value
            features['volume_ratio'] = 1.0
            features['volume_spike'] = 0
            features['volume_trend'] = 0.0
            return features

        vol_ma = df['volume'].rolling(20).mean()
        features['volume_ratio'] = df['volume'] / (vol_ma + 1e-10)
        features['volume_spike'] = (features['volume_ratio'] > 2.0).astype(int)
        features['volume_trend'] = df['volume'].rolling(5).mean().diff(5)
        # Simplified OBV (on-balance volume)
        obv = (np.sign(df['close'].diff()) * df['volume']).cumsum()
        features['obv_slope'] = obv.diff(5)
        return features

    # -------------------------------------------------------------------------
    # Pivot points (classic + Fibonacci)
    # -------------------------------------------------------------------------
    def _add_pivot_points(self, features: pd.DataFrame, df: pd.DataFrame) -> pd.DataFrame:
        """
        Computes the previous session's pivots from (H+L+C)/3.
        For intraday data, the previous bar is used as a proxy.
        """
        prev_high = df['high'].shift(1)
        prev_low = df['low'].shift(1)
        prev_close = df['close'].shift(1)

        pivot = (prev_high + prev_low + prev_close) / 3

        r1 = 2 * pivot - prev_low
        s1 = 2 * pivot - prev_high
        r2 = pivot + (prev_high - prev_low)
        s2 = pivot - (prev_high - prev_low)

        # Fibonacci
        diff = prev_high - prev_low
        r1f = pivot + 0.382 * diff
        r2f = pivot + 0.618 * diff
        r3f = pivot + 1.000 * diff
        s1f = pivot - 0.382 * diff
        s2f = pivot - 0.618 * diff
        s3f = pivot - 1.000 * diff

        close = df['close']
        atr = self._atr(df, self.ATR_PERIOD)

        # ATR-normalized distance to each level
        for name, level in [('pivot', pivot), ('r1', r1), ('s1', s1),
                            ('r2', r2), ('s2', s2), ('r1f', r1f), ('s1f', s1f)]:
            features[f'dist_{name}'] = (close - level) / (atr + 1e-10)

        # Close to a key pivot (< 0.5 ATR)
        features['near_pivot'] = (abs(close - pivot) < 0.5 * atr).astype(int)
        features['near_r1'] = (abs(close - r1) < 0.5 * atr).astype(int)
        features['near_s1'] = (abs(close - s1) < 0.5 * atr).astype(int)

        return features

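The pivot formulas above are plain arithmetic on one prior session's high/low/close; a worked example with the classic and Fibonacci variants:

```python
# Classic floor-trader pivots from one prior session (H, L, C)
prev_high, prev_low, prev_close = 1.1050, 1.1000, 1.1030

pivot = (prev_high + prev_low + prev_close) / 3
r1 = 2 * pivot - prev_low            # first resistance
s1 = 2 * pivot - prev_high           # first support
r2 = pivot + (prev_high - prev_low)  # second levels use the full prior range
s2 = pivot - (prev_high - prev_low)
# Fibonacci variants scale the prior range instead
r1_fib = pivot + 0.382 * (prev_high - prev_low)
s1_fib = pivot - 0.382 * (prev_high - prev_low)
```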
    # -------------------------------------------------------------------------
    # Supports / Resistances (local pivot levels)
    # -------------------------------------------------------------------------
    def _add_support_resistance(self, features: pd.DataFrame, df: pd.DataFrame) -> pd.DataFrame:
        """
        Detects S/R levels as recent local highs/lows.
        Computes the distance from the current price to the nearest levels.
        """
        atr = self._atr(df, self.ATR_PERIOD)
        n = self.PIVOT_LOOKBACK

        # Pivot high: local high (higher than the n bars before and after).
        # NB: the centered rolling window looks ahead n bars, so a pivot is
        # only confirmed n bars after it forms.
        pivot_highs = df['high'][(df['high'] == df['high'].rolling(2 * n + 1, center=True).max())]
        pivot_lows = df['low'][(df['low'] == df['low'].rolling(2 * n + 1, center=True).min())]

        dist_to_res = pd.Series(np.nan, index=df.index)
        dist_to_sup = pd.Series(np.nan, index=df.index)

        for i in range(len(df)):
            close_val = df['close'].iloc[i]
            atr_val = atr.iloc[i] if not np.isnan(atr.iloc[i]) else 0.001

            # Resistances above the price among the last SR_LOOKBACK pivots.
            # NB: the slice is positional within the pivot series, not in bars,
            # so it may include pivots detected after bar i.
            past_highs = pivot_highs.iloc[max(0, i - self.SR_LOOKBACK):i]
            resistances = past_highs[past_highs > close_val]
            if not resistances.empty:
                nearest_res = resistances.min()
                dist_to_res.iloc[i] = (nearest_res - close_val) / atr_val

            # Supports below
            past_lows = pivot_lows.iloc[max(0, i - self.SR_LOOKBACK):i]
            supports = past_lows[past_lows < close_val]
            if not supports.empty:
                nearest_sup = supports.max()
                dist_to_sup.iloc[i] = (close_val - nearest_sup) / atr_val

        features['dist_to_resistance'] = dist_to_res
        features['dist_to_support'] = dist_to_sup
        # Bounce: price close to support (< 1 ATR)
        features['bounce_from_support'] = (dist_to_sup < 1.0).astype(int)
        features['rejection_at_resistance'] = (dist_to_res < 1.0).astype(int)

        return features

    # -------------------------------------------------------------------------
    # Candlestick patterns
    # -------------------------------------------------------------------------
    def _add_candlestick_patterns(self, features: pd.DataFrame, df: pd.DataFrame) -> pd.DataFrame:
        o, h, l, c = df['open'], df['high'], df['low'], df['close']
        body = abs(c - o)
        candle = h - l
        atr = self._atr(df, self.ATR_PERIOD)

        # Relative body size
        features['body_ratio'] = body / (candle + 1e-10)  # 0 = doji, 1 = marubozu
        features['is_bullish'] = (c > o).astype(int)
        features['is_bearish'] = (c < o).astype(int)

        # Doji (body < 10% of the full range)
        features['doji'] = (body < 0.1 * candle).astype(int)

        # Hammer: small body at the top, long lower wick
        lower_wick = pd.Series(np.where(c > o, o - l, c - l), index=df.index)
        upper_wick = pd.Series(np.where(c > o, h - c, h - o), index=df.index)
        features['hammer'] = ((lower_wick > 2 * body) & (upper_wick < 0.3 * body)).astype(int)
        # Shooting star: small body at the bottom, long upper wick
        features['shooting_star'] = ((upper_wick > 2 * body) & (lower_wick < 0.3 * body)).astype(int)

        # Bullish engulfing
        prev_bear = df['close'].shift(1) < df['open'].shift(1)
        features['bullish_engulfing'] = (
            prev_bear & (c > o) & (o < df['close'].shift(1)) & (c > df['open'].shift(1))
        ).astype(int)
        # Bearish engulfing
        prev_bull = df['close'].shift(1) > df['open'].shift(1)
        features['bearish_engulfing'] = (
            prev_bull & (c < o) & (o > df['close'].shift(1)) & (c < df['open'].shift(1))
        ).astype(int)

        # Candle size relative to ATR
        features['candle_size_ratio'] = candle / (atr + 1e-10)

        return features

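The engulfing test above is four vectorized comparisons; a standalone sketch on a two-bar example (the `bullish_engulfing` helper is illustrative):

```python
import pandas as pd

def bullish_engulfing(df: pd.DataFrame) -> pd.Series:
    """Previous candle bearish, current candle bullish, and the current
    body fully engulfs the previous body."""
    o, c = df['open'], df['close']
    prev_o, prev_c = o.shift(1), c.shift(1)
    return ((prev_c < prev_o) & (c > o) & (o < prev_c) & (c > prev_o)).astype(int)

# Bar 0: bearish 10 -> 9. Bar 1: opens below 9, closes above 10 -> engulfing.
df = pd.DataFrame({'open': [10.0, 8.8], 'close': [9.0, 10.2]})
flags = bullish_engulfing(df)
```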
    # -------------------------------------------------------------------------
    # Temporal features
    # -------------------------------------------------------------------------
    def _add_temporal(self, features: pd.DataFrame, df: pd.DataFrame) -> pd.DataFrame:
        idx = df.index
        if hasattr(idx, 'hour'):
            features['hour'] = idx.hour
            features['is_london'] = ((idx.hour >= 8) & (idx.hour < 16)).astype(int)
            features['is_ny'] = ((idx.hour >= 13) & (idx.hour < 21)).astype(int)
            features['is_overlap'] = ((idx.hour >= 13) & (idx.hour < 16)).astype(int)  # London + NY
        if hasattr(idx, 'dayofweek'):
            features['day_of_week'] = idx.dayofweek
            features['is_monday'] = (idx.dayofweek == 0).astype(int)
            features['is_friday'] = (idx.dayofweek == 4).astype(int)
        return features

    # -------------------------------------------------------------------------
    # Helpers
    # -------------------------------------------------------------------------
    @staticmethod
    def _rsi(series: pd.Series, period: int) -> pd.Series:
        # SMA-based RSI (Cutler's variant), not Wilder's smoothing
        delta = series.diff()
        gain = delta.clip(lower=0).rolling(period).mean()
        loss = (-delta.clip(upper=0)).rolling(period).mean()
        rs = gain / (loss + 1e-10)
        return 100 - (100 / (1 + rs))

    @staticmethod
    def _atr(df: pd.DataFrame, period: int) -> pd.Series:
        h, l, pc = df['high'], df['low'], df['close'].shift(1)
        tr = pd.concat([h - l, (h - pc).abs(), (l - pc).abs()], axis=1).max(axis=1)
        return tr.rolling(period).mean()

    def get_feature_names(self) -> List[str]:
        """Returns the names of all features (after a first build)."""
        return self._last_feature_names if hasattr(self, '_last_feature_names') else []
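The `_rsi` helper sanity-checks easily: on a monotonically rising series there are no losses, so RSI saturates near 100 (the small epsilon in the denominator avoids division by zero). A standalone sketch:

```python
import pandas as pd

def rsi(series: pd.Series, period: int = 14) -> pd.Series:
    """SMA-based RSI (Cutler's variant)."""
    delta = series.diff()
    gain = delta.clip(lower=0).rolling(period).mean()
    loss = (-delta.clip(upper=0)).rolling(period).mean()
    rs = gain / (loss + 1e-10)
    return 100 - (100 / (1 + rs))

rising = pd.Series(range(1, 31), dtype=float)
values = rsi(rising)
```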
452
src/ml/ml_strategy_model.py
Normal file
@@ -0,0 +1,452 @@
"""
ML Strategy Model — XGBoost/LightGBM model that learns to trade.

This module trains a supervised classifier on classic technical features
(RSI, MACD, S/R, pivots, candlestick patterns...) to predict whether the
next opportunity is LONG, SHORT or NEUTRAL.

Pipeline:
1. Load OHLCV data (via DataService or from file)
2. Build features (TechnicalFeatureBuilder)
3. Generate labels (LabelGenerator — forward simulation)
4. Train XGBoost (or LightGBM)
5. Evaluate: precision, recall, F1 per class + walk-forward
6. Save model (joblib) + JSON metadata

Usage:
    model = MLStrategyModel(symbol='EURUSD', timeframe='1h')
    result = model.train(df_ohlcv)
    signal = model.predict(df_recent)  # returns 1, -1 or 0 + confidence
"""

import json
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

from src.ml.features.technical_features import TechnicalFeatureBuilder
from src.ml.features.label_generator import LabelGenerator

logger = logging.getLogger(__name__)

# Directory where trained models are saved
MODELS_DIR = Path(__file__).parent.parent.parent / "models" / "ml_strategy"

try:
    import xgboost as xgb
    XGBOOST_AVAILABLE = True
except ImportError:
    XGBOOST_AVAILABLE = False
    logger.warning("XGBoost not available — falling back to RandomForest")

try:
    import lightgbm as lgb
    LIGHTGBM_AVAILABLE = True
except ImportError:
    LIGHTGBM_AVAILABLE = False

try:
    from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
    from sklearn.preprocessing import StandardScaler
    from sklearn.metrics import classification_report, precision_recall_fscore_support
    from sklearn.model_selection import TimeSeriesSplit
    import joblib
    SKLEARN_AVAILABLE = True
except ImportError:
    SKLEARN_AVAILABLE = False
    logger.error("scikit-learn not available — MLStrategyModel cannot run")


class MLStrategyModel:
    """
    ML model that learns human trading patterns.

    The model:
    - Learns from classic TA features (RSI, MACD, S/R, pivots...)
    - Predicts LONG (1) / SHORT (-1) / NEUTRAL (0)
    - Returns a confidence score [0..1] per prediction
    - Persists to disk so it can be reloaded without retraining

    Args:
        symbol: Traded pair (e.g. 'EURUSD')
        timeframe: Timeframe (e.g. '1h', '15m')
        model_type: 'xgboost', 'lightgbm' or 'random_forest'
        tp_atr_mult: ATR multiplier for the TP when generating labels
        sl_atr_mult: ATR multiplier for the SL when generating labels
        horizon: Number of bars over which TP/SL are evaluated
    """

    def __init__(
        self,
        symbol: str = 'EURUSD',
        timeframe: str = '1h',
        model_type: str = 'xgboost',
        tp_atr_mult: float = 2.0,
        sl_atr_mult: float = 1.0,
        horizon: int = 30,
        min_confidence: float = 0.55,
    ):
        self.symbol = symbol
        self.timeframe = timeframe
        self.model_type = model_type
        self.tp_atr_mult = tp_atr_mult
        self.sl_atr_mult = sl_atr_mult
        self.horizon = horizon
        self.min_confidence = min_confidence

        self.model = None
        self.scaler = None
        self.feature_names: List[str] = []
        self.is_trained = False
        self.metadata: Dict = {}

        MODELS_DIR.mkdir(parents=True, exist_ok=True)

    # -------------------------------------------------------------------------
    # Training
    # -------------------------------------------------------------------------
    def train(self, data: pd.DataFrame) -> Dict:
        """
        Trains the model on the supplied OHLCV data.

        Args:
            data: OHLCV DataFrame (at least 200 bars recommended)

        Returns:
            Dict with training and validation metrics
        """
        if not SKLEARN_AVAILABLE:
            return {'error': 'scikit-learn not available'}

        logger.info(f"Starting MLStrategyModel training for {self.symbol}/{self.timeframe}")
        logger.info(f"  Data: {len(data)} bars, type={self.model_type}")

        # 1. Features
        builder = TechnicalFeatureBuilder()
        features = builder.build(data)
        if len(features) < 100:
            return {'error': f'Not enough data after feature engineering: {len(features)} bars'}

        # 2. Labels (ATR-based so they adapt to volatility)
        gen = LabelGenerator(horizon=self.horizon)
        labels = gen.generate_atr_based(
            data,
            atr_tp_mult=self.tp_atr_mult,
            atr_sl_mult=self.sl_atr_mult,
        )

        # Align features and labels on the same index
        common_idx = features.index.intersection(labels.index)
        X = features.loc[common_idx].fillna(0)
        y = labels.loc[common_idx]

        # Drop the trailing bars (labels unreliable: truncated horizon)
        X = X.iloc[:-self.horizon]
        y = y.iloc[:-self.horizon]

        if len(X) < 50:
            return {'error': f'Too few samples: {len(X)}'}

        self.feature_names = list(X.columns)
        logger.info(f"  {len(X)} samples, {len(self.feature_names)} features")
        logger.info(f"  Distribution: LONG={(y == 1).sum()}, SHORT={(y == -1).sum()}, NEUTRAL={(y == 0).sum()}")

        # 3. Walk-forward cross-validation (3 temporal folds)
        wf_metrics = self._walk_forward_eval(X, y, n_splits=3)

        # 4. Train on the full dataset
        self.scaler = StandardScaler()
        X_scaled = self.scaler.fit_transform(X)

        self.model = self._build_model()
        self.model.fit(X_scaled, y)
        self.is_trained = True

        # 5. Final evaluation (in-sample — indicative only)
        y_pred = self.model.predict(X_scaled)
        report = classification_report(y, y_pred, labels=[-1, 0, 1],
                                       target_names=['SHORT', 'NEUTRAL', 'LONG'],
                                       output_dict=True, zero_division=0)

        self.metadata = {
            'symbol': self.symbol,
            'timeframe': self.timeframe,
            'model_type': self.model_type,
            'trained_at': datetime.utcnow().isoformat(),
            'n_samples': len(X),
            'n_features': len(self.feature_names),
            'tp_atr_mult': self.tp_atr_mult,
            'sl_atr_mult': self.sl_atr_mult,
            'horizon': self.horizon,
            'label_dist': {
                'long': int((y == 1).sum()),
                'short': int((y == -1).sum()),
                'neutral': int((y == 0).sum()),
            },
            'wf_metrics': wf_metrics,
            'train_report': report,
        }

        # 6. Save
        self.save()

        logger.info(f"Training done. WF accuracy={wf_metrics.get('avg_accuracy', 0):.2%}")
        return self.metadata

    # -------------------------------------------------------------------------
    # Prediction
    # -------------------------------------------------------------------------
    def predict(self, data: pd.DataFrame) -> Dict:
        """
        Predicts the signal for the latest available bar.

        Args:
            data: Recent OHLCV DataFrame (at least 200 bars for the indicators)

        Returns:
            Dict: {
                'signal': 1 (LONG) / -1 (SHORT) / 0 (NEUTRAL),
                'confidence': float [0..1],
                'probas': {'long': float, 'short': float, 'neutral': float},
                'tradeable': bool (confidence >= min_confidence),
            }
        """
        if not self.is_trained or self.model is None:
            return {'signal': 0, 'confidence': 0.0, 'tradeable': False, 'error': 'Model not trained'}

        builder = TechnicalFeatureBuilder()
        features = builder.build(data)

        if features.empty:
            return {'signal': 0, 'confidence': 0.0, 'tradeable': False, 'error': 'Empty features'}

        # Last row = current bar (copy to avoid mutating a view of `features`)
        last = features.iloc[[-1]].copy()

        # Align columns (missing features -> 0)
        for col in self.feature_names:
            if col not in last.columns:
                last[col] = 0.0
        last = last[self.feature_names].fillna(0)

        X_scaled = self.scaler.transform(last)
        pred = self.model.predict(X_scaled)[0]

        # Probabilities when available
        probas = {'long': 0.0, 'short': 0.0, 'neutral': 1.0}
        confidence = 0.0
        if hasattr(self.model, 'predict_proba'):
            proba_arr = self.model.predict_proba(X_scaled)[0]
            classes = list(self.model.classes_)
            prob_map = {c: p for c, p in zip(classes, proba_arr)}
            probas = {
                'long': float(prob_map.get(1, 0.0)),
                'short': float(prob_map.get(-1, 0.0)),
                'neutral': float(prob_map.get(0, 1.0)),
            }
            confidence = float(max(probas['long'], probas['short']))

        return {
            'signal': int(pred),
            'confidence': confidence,
            'probas': probas,
            'tradeable': confidence >= self.min_confidence and pred != 0,
        }

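The probability-to-decision step at the end of `predict()` is plain dictionary arithmetic. A dependency-free sketch of the confidence gate (assumptions: the `decide` helper is illustrative and picks the stronger directional probability, whereas the class takes the direction from `model.predict`; the 0.55 threshold mirrors the default `min_confidence`):

```python
def decide(prob_map: dict, min_confidence: float = 0.55) -> dict:
    """Map per-class probabilities keyed by {-1, 0, 1} to a gated decision."""
    probas = {
        'long': prob_map.get(1, 0.0),
        'short': prob_map.get(-1, 0.0),
        'neutral': prob_map.get(0, 1.0),
    }
    confidence = max(probas['long'], probas['short'])
    signal = 1 if probas['long'] >= probas['short'] else -1
    if confidence < min_confidence:
        signal = 0  # not confident enough to trade
    return {'signal': signal, 'confidence': confidence,
            'tradeable': confidence >= min_confidence and signal != 0}

weak = decide({-1: 0.20, 0: 0.30, 1: 0.50})    # below the 0.55 gate
strong = decide({-1: 0.10, 0: 0.25, 1: 0.65})  # confident long
```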
    # -------------------------------------------------------------------------
    # Save / Load
    # -------------------------------------------------------------------------
    def save(self) -> Path:
        """Saves the model and its metadata to disk."""
        if not SKLEARN_AVAILABLE or not self.is_trained:
            raise RuntimeError("Model not trained")

        model_id = f"{self.symbol}_{self.timeframe}_{self.model_type}"
        model_path = MODELS_DIR / f"{model_id}.joblib"
        meta_path = MODELS_DIR / f"{model_id}_meta.json"

        joblib.dump({
            'model': self.model,
            'scaler': self.scaler,
            'feature_names': self.feature_names,
            'metadata': self.metadata,
            'config': {
                'symbol': self.symbol,
                'timeframe': self.timeframe,
                'model_type': self.model_type,
                'tp_atr_mult': self.tp_atr_mult,
                'sl_atr_mult': self.sl_atr_mult,
                'horizon': self.horizon,
                'min_confidence': self.min_confidence,
            },
        }, model_path)

        with open(meta_path, 'w') as f:
            json.dump(self.metadata, f, indent=2, default=str)

        logger.info(f"Model saved: {model_path}")
        return model_path

    @classmethod
    def load(cls, symbol: str, timeframe: str, model_type: str = 'xgboost') -> 'MLStrategyModel':
        """
        Loads an existing model from disk.

        Args:
            symbol: Pair (e.g. 'EURUSD')
            timeframe: Timeframe (e.g. '1h')
            model_type: Model type

        Returns:
            MLStrategyModel instance ready to predict

        Raises:
            FileNotFoundError: if the model does not exist
        """
        if not SKLEARN_AVAILABLE:
            raise RuntimeError("scikit-learn not available")

        model_id = f"{symbol}_{timeframe}_{model_type}"
        model_path = MODELS_DIR / f"{model_id}.joblib"

        if not model_path.exists():
            raise FileNotFoundError(f"Model not found: {model_path}")

        data = joblib.load(model_path)
        cfg = data.get('config', {})

        instance = cls(
            symbol=cfg.get('symbol', symbol),
            timeframe=cfg.get('timeframe', timeframe),
            model_type=cfg.get('model_type', model_type),
            tp_atr_mult=cfg.get('tp_atr_mult', 2.0),
            sl_atr_mult=cfg.get('sl_atr_mult', 1.0),
            horizon=cfg.get('horizon', 30),
            min_confidence=cfg.get('min_confidence', 0.55),
        )
        instance.model = data['model']
        instance.scaler = data['scaler']
        instance.feature_names = data['feature_names']
        instance.metadata = data.get('metadata', {})
        instance.is_trained = True

        logger.info(f"Model loaded from {model_path}")
        return instance

    @staticmethod
    def list_trained_models() -> List[Dict]:
        """Returns the list of available trained models."""
        if not MODELS_DIR.exists():
            return []
        models = []
        for f in MODELS_DIR.glob("*_meta.json"):
            try:
                with open(f) as fp:
                    meta = json.load(fp)
                parts = f.stem.replace('_meta', '').split('_')
                models.append({
                    'symbol': meta.get('symbol', parts[0] if parts else '?'),
                    'timeframe': meta.get('timeframe', parts[1] if len(parts) > 1 else '?'),
                    'model_type': meta.get('model_type', parts[2] if len(parts) > 2 else '?'),
                    'trained_at': meta.get('trained_at', '?'),
                    'n_samples': meta.get('n_samples', 0),
                    'wf_accuracy': meta.get('wf_metrics', {}).get('avg_accuracy', 0),
                })
            except Exception:
                logger.debug(f"Skipping unreadable metadata file: {f}")
        return models

    # -------------------------------------------------------------------------
    # Walk-forward evaluation
    # -------------------------------------------------------------------------
    def _walk_forward_eval(self, X: pd.DataFrame, y: pd.Series, n_splits: int = 3) -> Dict:
        """Evaluates the model with time-series cross-validation."""
        tscv = TimeSeriesSplit(n_splits=n_splits)
        accuracies, precisions, recalls = [], [], []

        for fold, (train_idx, test_idx) in enumerate(tscv.split(X)):
            X_tr, X_te = X.iloc[train_idx], X.iloc[test_idx]
            y_tr, y_te = y.iloc[train_idx], y.iloc[test_idx]

            scaler = StandardScaler()
            X_tr_s = scaler.fit_transform(X_tr)
            X_te_s = scaler.transform(X_te)

            m = self._build_model()
            m.fit(X_tr_s, y_tr)
            y_pred = m.predict(X_te_s)

            acc = (y_pred == y_te.values).mean()
            # Precision/recall on directional signals only
            mask = (y_te != 0) | (y_pred != 0)
            prec, rec, _, _ = precision_recall_fscore_support(
                y_te[mask], y_pred[mask], average='macro', zero_division=0
            ) if mask.sum() > 0 else (0, 0, 0, 0)

            accuracies.append(acc)
            precisions.append(prec)
            recalls.append(rec)
            logger.info(f"  Fold {fold + 1}/{n_splits}: acc={acc:.2%}, prec={prec:.2%}, rec={rec:.2%}")

        return {
            'avg_accuracy': float(np.mean(accuracies)),
            'avg_precision': float(np.mean(precisions)),
            'avg_recall': float(np.mean(recalls)),
            'fold_accuracies': [float(a) for a in accuracies],
        }

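The expanding-window splits that `TimeSeriesSplit` produces can be sketched without scikit-learn. A hand-rolled version under the same convention (equal-size test folds, training data always strictly earlier; the `walk_forward_splits` helper is illustrative and follows scikit-learn's default test size of `n_samples // (n_splits + 1)`):

```python
def walk_forward_splits(n_samples: int, n_splits: int = 3):
    """Yield (train_indices, test_indices) with an expanding training window."""
    test_size = n_samples // (n_splits + 1)
    for k in range(n_splits):
        test_start = n_samples - (n_splits - k) * test_size
        train = list(range(0, test_start))
        test = list(range(test_start, test_start + test_size))
        yield train, test

# 10 samples, 3 folds -> test folds [4,5], [6,7], [8,9] with growing train windows
splits = list(walk_forward_splits(10, n_splits=3))
```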
    # -------------------------------------------------------------------------
    # Model construction
    # -------------------------------------------------------------------------
    def _build_model(self):
        """Instantiates the model for the configured type."""
        if self.model_type == 'xgboost' and XGBOOST_AVAILABLE:
            # NB: recent XGBoost versions require class labels in [0, n);
            # with labels {-1, 0, 1} a LabelEncoder-style mapping may be needed.
            return xgb.XGBClassifier(
                n_estimators=300,
                max_depth=5,
                learning_rate=0.05,
                subsample=0.8,
                colsample_bytree=0.8,
                eval_metric='mlogloss',
                random_state=42,
                verbosity=0,
            )
        elif self.model_type == 'lightgbm' and LIGHTGBM_AVAILABLE:
            return lgb.LGBMClassifier(
                n_estimators=300,
                max_depth=5,
                learning_rate=0.05,
                subsample=0.8,
                colsample_bytree=0.8,
                random_state=42,
                verbose=-1,
            )
        elif SKLEARN_AVAILABLE:
            logger.warning(f"'{self.model_type}' not available — falling back to RandomForest")
            return RandomForestClassifier(
                n_estimators=200,
                max_depth=8,
                min_samples_split=10,
                random_state=42,
                n_jobs=-1,
            )
        else:
            raise RuntimeError("No ML algorithm available")

    def get_feature_importance(self, top_n: int = 20) -> List[Dict]:
        """Returns the top N most important features."""
        if not self.is_trained or not hasattr(self.model, 'feature_importances_'):
            return []
        importances = self.model.feature_importances_
        pairs = sorted(
            zip(self.feature_names, importances),
            key=lambda x: x[1],
            reverse=True
        )
        return [{'feature': f, 'importance': float(i)} for f, i in pairs[:top_n]]
3
src/strategies/ml_driven/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .ml_strategy import MLDrivenStrategy

__all__ = ['MLDrivenStrategy']
227
src/strategies/ml_driven/ml_strategy.py
Normal file
@@ -0,0 +1,227 @@
"""
ML-Driven Strategy — trading strategy driven by machine learning.

This strategy replaces hard-coded rules with an ML model
(XGBoost/LightGBM) trained on classic technical features:
RSI, MACD, supports/resistances, pivots, candlestick patterns, etc.

How it works:
1. The model is loaded from disk (trained via POST /trading/train)
2. On each bar, features are computed from the historical data
3. The model predicts LONG / SHORT / NEUTRAL with a confidence score
4. If confidence >= min_confidence, a signal is emitted with ATR-based SL/TP

Integration:
- Compatible with StrategyEngine (same interface as ScalpingStrategy)
- Loaded automatically if a trained model exists for the symbol/timeframe
- The RiskManager applies the same controls as for classic strategies
"""

import logging
from datetime import datetime, timezone
from typing import Dict, Optional

import numpy as np
import pandas as pd

from src.strategies.base_strategy import BaseStrategy, Signal, StrategyConfig
from src.ml.ml_strategy_model import MLStrategyModel

logger = logging.getLogger(__name__)


class MLDrivenStrategy(BaseStrategy):
    """
    Trading strategy driven by a pre-trained ML model.

    The model learns the TA patterns used by human traders:
    - Bounces off supports/resistances
    - RSI / MACD divergences
    - Bollinger squeeze + expansion
    - EMA alignment (trend)
    - Candlestick patterns (hammer, engulfing...)
    - Proximity to pivots

    Args:
        config: Configuration dict (timeframe, risk_per_trade, symbol, etc.)

    Additional (optional) config keys:
        model_type: 'xgboost' | 'lightgbm' | 'random_forest' (default: 'xgboost')
        min_confidence: Minimum confidence threshold [0..1] (default: 0.55)
        tp_atr_mult: ATR multiplier for TP (default: 2.0)
        sl_atr_mult: ATR multiplier for SL (default: 1.0)
        auto_load: Automatically load an existing model (default: True)
    """

    STRATEGY_NAME = 'ml_driven'

    def __init__(self, config: Dict):
        super().__init__(config)

        self.symbol = config.get('symbol', 'EURUSD')
        self.model_type = config.get('model_type', 'xgboost')
        self.min_confidence = config.get('min_confidence', 0.55)
        self.tp_atr_mult = config.get('tp_atr_mult', 2.0)
        self.sl_atr_mult = config.get('sl_atr_mult', 1.0)

        self.ml_model: Optional[MLStrategyModel] = None

        # Try to auto-load an existing model
        if config.get('auto_load', True):
            self._try_load_model()

# -------------------------------------------------------------------------
|
||||
# Interface BaseStrategy
|
||||
# -------------------------------------------------------------------------
|
||||
def analyze(self, market_data: pd.DataFrame) -> Optional[Signal]:
|
||||
"""
|
||||
Génère un signal de trading via le modèle ML.
|
||||
|
||||
Args:
|
||||
market_data: DataFrame OHLCV (minimum 200 barres)
|
||||
|
||||
Returns:
|
||||
Signal si le modèle est confiant, None sinon
|
||||
"""
|
||||
if self.ml_model is None or not self.ml_model.is_trained:
|
||||
logger.debug("ML Strategy : modèle non chargé, aucun signal")
|
||||
return None
|
||||
|
||||
if len(market_data) < 50:
|
||||
return None
|
||||
|
||||
try:
|
||||
result = self.ml_model.predict(market_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"ML Strategy predict error : {e}")
|
||||
return None
|
||||
|
||||
if not result.get('tradeable', False):
|
||||
return None
|
||||
|
||||
signal_dir = result['signal'] # 1 = LONG, -1 = SHORT
|
||||
confidence = result['confidence']
|
||||
|
||||
# Prix et ATR pour SL/TP
|
||||
last_close = float(market_data['close'].iloc[-1])
|
||||
atr = self._compute_atr(market_data)
|
||||
if atr <= 0:
|
||||
return None
|
||||
|
||||
if signal_dir == 1:
|
||||
direction = 'LONG'
|
||||
stop_loss = last_close - self.sl_atr_mult * atr
|
||||
take_profit = last_close + self.tp_atr_mult * atr
|
||||
elif signal_dir == -1:
|
||||
direction = 'SHORT'
|
||||
stop_loss = last_close + self.sl_atr_mult * atr
|
||||
take_profit = last_close - self.tp_atr_mult * atr
|
||||
else:
|
||||
return None
|
||||
|
||||
signal = Signal(
|
||||
symbol = self.symbol,
|
||||
direction = direction,
|
||||
entry_price = last_close,
|
||||
stop_loss = stop_loss,
|
||||
take_profit = take_profit,
|
||||
confidence = confidence,
|
||||
timestamp = datetime.now(timezone.utc),
|
||||
strategy = self.STRATEGY_NAME,
|
||||
metadata = {
|
||||
'probas': result.get('probas', {}),
|
||||
'model_type': self.model_type,
|
||||
'atr': atr,
|
||||
'tp_atr_mult': self.tp_atr_mult,
|
||||
'sl_atr_mult': self.sl_atr_mult,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"ML Signal : {direction} {self.symbol} | "
|
||||
f"entry={last_close:.5f} SL={stop_loss:.5f} TP={take_profit:.5f} | "
|
||||
f"confidence={confidence:.2%}"
|
||||
)
|
||||
return signal
|
||||

    def calculate_indicators(self, data: pd.DataFrame) -> pd.DataFrame:
        """Return the data as-is — features are computed inside predict()."""
        return data

    def update_params(self, params: Dict) -> None:
        """Dynamic parameter update (from Optuna or the API)."""
        if 'min_confidence' in params:
            self.min_confidence = params['min_confidence']
            if self.ml_model:
                self.ml_model.min_confidence = params['min_confidence']
        if 'tp_atr_mult' in params:
            self.tp_atr_mult = params['tp_atr_mult']
        if 'sl_atr_mult' in params:
            self.sl_atr_mult = params['sl_atr_mult']
        logger.info(f"ML Strategy params updated: {params}")

    # -------------------------------------------------------------------------
    # Model management
    # -------------------------------------------------------------------------
    def load_model(self, symbol: Optional[str] = None, timeframe: Optional[str] = None) -> bool:
        """
        Load a model from disk.

        Args:
            symbol: Pair (default: self.symbol)
            timeframe: Timeframe (default: self.config.timeframe)

        Returns:
            True if the model was loaded successfully
        """
        sym = symbol or self.symbol
        tf = timeframe or self.config.timeframe
        try:
            self.ml_model = MLStrategyModel.load(sym, tf, self.model_type)
            logger.info(f"ML model loaded: {sym}/{tf}/{self.model_type}")
            return True
        except FileNotFoundError:
            logger.info(f"No ML model found for {sym}/{tf}/{self.model_type}")
            return False
        except Exception as e:
            logger.error(f"Model loading error: {e}")
            return False

    def attach_model(self, model: MLStrategyModel) -> None:
        """Attach a model directly (after training via the API)."""
        self.ml_model = model
        self.symbol = model.symbol
        logger.info(f"ML model attached: {model.symbol}/{model.timeframe}")

    def is_ready(self) -> bool:
        """Return True if the model is loaded and trained."""
        return self.ml_model is not None and self.ml_model.is_trained

    def get_model_info(self) -> Dict:
        """Return the active model's metadata."""
        if not self.is_ready():
            return {'status': 'not trained'}
        meta = self.ml_model.metadata.copy()
        meta['is_ready'] = True
        meta['feature_importance'] = self.ml_model.get_feature_importance(top_n=10)
        return meta

    # -------------------------------------------------------------------------
    # Helpers
    # -------------------------------------------------------------------------
    def _try_load_model(self) -> None:
        """Attempt a silent model load at startup."""
        try:
            self.load_model()
        except Exception:
            pass

    @staticmethod
    def _compute_atr(df: pd.DataFrame, period: int = 14) -> float:
        """Compute the average true range over the last `period` bars."""
        if len(df) < period + 1:
            return float(df['high'].iloc[-1] - df['low'].iloc[-1])
        h, l, pc = df['high'], df['low'], df['close'].shift(1)
        tr = pd.concat([h - l, (h - pc).abs(), (l - pc).abs()], axis=1).max(axis=1)
        atr = tr.rolling(period).mean().iloc[-1]
        return float(atr) if not np.isnan(atr) else float(df['high'].iloc[-1] - df['low'].iloc[-1])
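A standalone sanity check of the true-range / ATR math implemented by `_compute_atr` above — a sketch outside the strategy module, using synthetic, illustrative OHLC bars and a short 3-bar window so the tiny series is enough:

```python
import pandas as pd

# Five synthetic OHLC bars (illustrative values only)
bars = pd.DataFrame({
    'high':  [1.10, 1.12, 1.11, 1.13, 1.12],
    'low':   [1.08, 1.09, 1.09, 1.10, 1.10],
    'close': [1.09, 1.11, 1.10, 1.12, 1.11],
})
h, l, pc = bars['high'], bars['low'], bars['close'].shift(1)
# True range = max(high-low, |high-prev_close|, |low-prev_close|), per bar
tr = pd.concat([h - l, (h - pc).abs(), (l - pc).abs()], axis=1).max(axis=1)
# ATR = rolling mean of the true range (3-bar window for this short series)
atr = float(tr.rolling(3).mean().iloc[-1])
print(round(atr, 4))  # → 0.0233
```

The last three true ranges are 0.02, 0.03 and 0.02, so the 3-bar ATR is their mean, ≈ 0.0233; with the strategy's default multipliers, a LONG at 1.11 would place the stop ≈ 1 ATR below entry and the target ≈ 2 ATR above it.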