feat: Phase 4c — CNN + Ensemble architecture (multi-signal trading)
## Nouveaux modules
### src/ml/cnn/
- candlestick_encoder.py : CandlestickEncoder, fenêtres OHLCV z-score (N, 64, 5)
- cnn_model.py : TradingCNN — 3 blocs Conv1D(5→32→64→128) + BN + ReLU + GlobalAvgPool
- cnn_strategy_model.py : CNNStrategyModel, API identique à MLStrategyModel (train/predict/save/load)
### src/ml/ensemble/
- ensemble_model.py : EnsembleModel, poids {xgboost:0.40, cnn:0.60}, accord requis entre modèles
### src/strategies/cnn_driven/
- cnn_strategy.py : CNNDrivenStrategy(BaseStrategy), SL/TP ATR-based, fallback CNN_AVAILABLE=False
### src/strategies/ensemble/
- ensemble_strategy.py : EnsembleStrategy(BaseStrategy), auto-load XGBoost + CNN au démarrage
## Modifications
- trading.py : routes POST /train-cnn, GET /train-cnn/{job_id}, GET /cnn-models,
POST /ensemble/configure, GET /ensemble/status + fix bugs (logging, _get_data_service, period_map)
- strategy_engine.py : support 'ml_driven' dans load_strategy()
- docker/requirements/api.txt : ajout torch>=2.0.0 + dépendances ML (scikit-learn, xgboost, lightgbm)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
252
src/ml/ensemble/ensemble_model.py
Normal file
252
src/ml/ensemble/ensemble_model.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""
|
||||
Ensemble Model — Combine plusieurs modèles ML pour un signal de trading robuste.
|
||||
|
||||
L'EnsembleModel agrège les prédictions de modèles indépendants (XGBoost, CNN,
|
||||
et plus tard RL) via une moyenne pondérée. Un signal n'est émis que si les
|
||||
modèles actifs sont en accord ET que le score pondéré dépasse un seuil.
|
||||
|
||||
Duck typing : ce module n'importe PAS directement MLStrategyModel ni
|
||||
CNNStrategyModel. Tout objet exposant `.predict(df)` → dict et `.is_trained`
|
||||
→ bool est compatible.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EnsembleModel:
|
||||
"""
|
||||
Combine plusieurs modèles ML pour produire un signal de trading robuste.
|
||||
|
||||
Logique :
|
||||
- Chaque modèle prédit indépendamment (signal + confidence)
|
||||
- Score final = somme pondérée des confidences pour les modèles en accord
|
||||
- Signal validé uniquement si :
|
||||
1. Au moins 2 modèles actifs sont en accord sur la direction
|
||||
2. Score pondéré >= min_confidence
|
||||
|
||||
Poids par défaut : xgboost=0.40, cnn=0.60 (CNN légèrement favorisé car
|
||||
il voit les données brutes sans biais de feature engineering)
|
||||
"""
|
||||
|
||||
DEFAULT_WEIGHTS = {
|
||||
'xgboost': 0.40,
|
||||
'cnn': 0.60,
|
||||
'rl': 0.00, # Réservé Phase 4d
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weights: Optional[Dict[str, float]] = None,
|
||||
min_confidence: float = 0.60,
|
||||
require_agreement: bool = True,
|
||||
):
|
||||
self.weights = dict(weights) if weights else dict(self.DEFAULT_WEIGHTS)
|
||||
self.min_confidence = min_confidence
|
||||
self.require_agreement = require_agreement
|
||||
|
||||
# Modèles attachés (duck typing : .predict(df), .is_trained)
|
||||
self._models: Dict[str, Any] = {}
|
||||
|
||||
logger.info(
|
||||
f"EnsembleModel initialisé — poids={self.weights}, "
|
||||
f"seuil={self.min_confidence}, accord_requis={self.require_agreement}"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Attachement des modèles
|
||||
# ------------------------------------------------------------------
|
||||
def attach_xgboost(self, model) -> None:
|
||||
"""Attache un MLStrategyModel entraîné."""
|
||||
self._attach('xgboost', model)
|
||||
|
||||
def attach_cnn(self, model) -> None:
|
||||
"""Attache un CNNStrategyModel entraîné."""
|
||||
self._attach('cnn', model)
|
||||
|
||||
def attach_rl(self, model) -> None:
|
||||
"""Attache un agent RL (Phase 4d)."""
|
||||
self._attach('rl', model)
|
||||
|
||||
def _attach(self, name: str, model) -> None:
|
||||
"""Attache un modèle générique avec vérification duck typing."""
|
||||
if not hasattr(model, 'predict') or not callable(model.predict):
|
||||
raise ValueError(f"Le modèle '{name}' doit exposer une méthode predict()")
|
||||
if not hasattr(model, 'is_trained'):
|
||||
raise ValueError(f"Le modèle '{name}' doit exposer un attribut is_trained")
|
||||
self._models[name] = model
|
||||
# Ajouter le poids par défaut s'il n'existe pas
|
||||
if name not in self.weights:
|
||||
self.weights[name] = 0.0
|
||||
logger.warning(f"Poids pour '{name}' non défini — initialisé à 0.0")
|
||||
logger.info(f"Modèle '{name}' attaché (is_trained={model.is_trained})")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Prédiction combinée
|
||||
# ------------------------------------------------------------------
|
||||
def predict(self, df: pd.DataFrame) -> Dict:
|
||||
"""
|
||||
Prédit le signal combiné à partir de tous les modèles actifs.
|
||||
|
||||
Returns:
|
||||
{
|
||||
'signal': int, # 1 LONG, -1 SHORT, 0 NEUTRAL
|
||||
'confidence': float, # score pondéré [0..1]
|
||||
'tradeable': bool,
|
||||
'agreement': bool, # True si tous les modèles actifs concordent
|
||||
'components': dict, # résultats individuels par modèle
|
||||
}
|
||||
"""
|
||||
components: Dict[str, Dict] = {}
|
||||
neutral_result = {
|
||||
'signal': 0, 'confidence': 0.0, 'tradeable': False,
|
||||
'agreement': False, 'components': components,
|
||||
}
|
||||
|
||||
# 1. Collecter les prédictions des modèles disponibles et entraînés
|
||||
for name, model in self._models.items():
|
||||
if not model.is_trained:
|
||||
logger.debug(f"Ensemble : modèle '{name}' non entraîné, ignoré")
|
||||
continue
|
||||
if self.weights.get(name, 0.0) <= 0.0:
|
||||
logger.debug(f"Ensemble : modèle '{name}' poids=0, ignoré")
|
||||
continue
|
||||
try:
|
||||
result = model.predict(df)
|
||||
components[name] = {
|
||||
'signal': result.get('signal', 0),
|
||||
'confidence': result.get('confidence', 0.0),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Ensemble : erreur predict '{name}' — {e}")
|
||||
continue
|
||||
|
||||
if not components:
|
||||
logger.debug("Ensemble : aucun modèle actif n'a produit de prédiction")
|
||||
neutral_result['components'] = components
|
||||
return neutral_result
|
||||
|
||||
# 2. Filtrer les signaux non-neutres
|
||||
directional = {
|
||||
k: v for k, v in components.items() if v['signal'] != 0
|
||||
}
|
||||
|
||||
if not directional:
|
||||
# Tous les modèles sont neutres
|
||||
return {
|
||||
'signal': 0, 'confidence': 0.0, 'tradeable': False,
|
||||
'agreement': True, 'components': components,
|
||||
}
|
||||
|
||||
# 3. Vérifier l'accord entre modèles directionnels
|
||||
directions = set(v['signal'] for v in directional.values())
|
||||
agreement = len(directions) == 1
|
||||
|
||||
if self.require_agreement and not agreement:
|
||||
logger.debug(
|
||||
f"Ensemble : désaccord entre modèles — {directional}"
|
||||
)
|
||||
return {
|
||||
'signal': 0, 'confidence': 0.0, 'tradeable': False,
|
||||
'agreement': False, 'components': components,
|
||||
}
|
||||
|
||||
# 4. Vérifier qu'au moins 2 modèles actifs sont en accord
|
||||
if len(directional) < 2:
|
||||
logger.debug("Ensemble : un seul modèle directionnel, signal insuffisant")
|
||||
return {
|
||||
'signal': 0, 'confidence': 0.0, 'tradeable': False,
|
||||
'agreement': True, 'components': components,
|
||||
}
|
||||
|
||||
# 5. Calculer le score pondéré (normalisé sur les modèles actifs)
|
||||
consensus_dir = directions.pop() # direction unique
|
||||
total_weight = sum(self.weights.get(k, 0.0) for k in directional)
|
||||
|
||||
if total_weight <= 0:
|
||||
return {
|
||||
'signal': 0, 'confidence': 0.0, 'tradeable': False,
|
||||
'agreement': agreement, 'components': components,
|
||||
}
|
||||
|
||||
weighted_score = sum(
|
||||
self.weights.get(k, 0.0) * v['confidence']
|
||||
for k, v in directional.items()
|
||||
) / total_weight
|
||||
|
||||
# 6. Signal final
|
||||
tradeable = weighted_score >= self.min_confidence
|
||||
|
||||
logger.info(
|
||||
f"Ensemble : direction={'LONG' if consensus_dir == 1 else 'SHORT'} | "
|
||||
f"score={weighted_score:.2%} | accord={agreement} | tradeable={tradeable}"
|
||||
)
|
||||
|
||||
return {
|
||||
'signal': consensus_dir,
|
||||
'confidence': weighted_score,
|
||||
'tradeable': tradeable,
|
||||
'agreement': agreement,
|
||||
'components': components,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Statut et configuration
|
||||
# ------------------------------------------------------------------
|
||||
def is_ready(self) -> bool:
|
||||
"""True si au moins 2 modèles sont attachés et entraînés."""
|
||||
trained = sum(
|
||||
1 for m in self._models.values()
|
||||
if m.is_trained and self.weights.get(
|
||||
next(k for k, v in self._models.items() if v is m), 0
|
||||
) > 0
|
||||
)
|
||||
return trained >= 2
|
||||
|
||||
def get_status(self) -> Dict:
|
||||
"""Statut de chaque composant + poids actifs."""
|
||||
status = {
|
||||
'ready': self.is_ready(),
|
||||
'min_confidence': self.min_confidence,
|
||||
'require_agreement': self.require_agreement,
|
||||
'weights': dict(self.weights),
|
||||
'models': {},
|
||||
}
|
||||
for name, model in self._models.items():
|
||||
status['models'][name] = {
|
||||
'attached': True,
|
||||
'is_trained': model.is_trained,
|
||||
'weight': self.weights.get(name, 0.0),
|
||||
}
|
||||
# Modèles non attachés mais présents dans les poids
|
||||
for name in self.weights:
|
||||
if name not in status['models']:
|
||||
status['models'][name] = {
|
||||
'attached': False,
|
||||
'is_trained': False,
|
||||
'weight': self.weights[name],
|
||||
}
|
||||
return status
|
||||
|
||||
def update_weights(self, weights: Dict[str, float]) -> None:
|
||||
"""
|
||||
Mise à jour dynamique des poids.
|
||||
|
||||
Si la somme != 1.0, normalise automatiquement et log un warning.
|
||||
"""
|
||||
total = sum(weights.values())
|
||||
if total <= 0:
|
||||
raise ValueError("La somme des poids doit être > 0")
|
||||
|
||||
if abs(total - 1.0) > 1e-6:
|
||||
logger.warning(
|
||||
f"Somme des poids = {total:.4f} != 1.0 — normalisation automatique"
|
||||
)
|
||||
weights = {k: v / total for k, v in weights.items()}
|
||||
|
||||
self.weights.update(weights)
|
||||
logger.info(f"Poids mis à jour : {self.weights}")
|
||||
Reference in New Issue
Block a user