Files
trader-ml/src/ml/ensemble/ensemble_model.py
Tika acc3338213 feat: Phase 4c — CNN + Ensemble architecture (multi-signal trading)
## Nouveaux modules

### src/ml/cnn/
- candlestick_encoder.py : CandlestickEncoder, fenêtres OHLCV z-score (N, 64, 5)
- cnn_model.py : TradingCNN — 3 blocs Conv1D(5→32→64→128) + BN + ReLU + GlobalAvgPool
- cnn_strategy_model.py : CNNStrategyModel, API identique à MLStrategyModel (train/predict/save/load)

### src/ml/ensemble/
- ensemble_model.py : EnsembleModel, poids {xgboost:0.40, cnn:0.60}, accord requis entre modèles

### src/strategies/cnn_driven/
- cnn_strategy.py : CNNDrivenStrategy(BaseStrategy), SL/TP ATR-based, fallback CNN_AVAILABLE=False

### src/strategies/ensemble/
- ensemble_strategy.py : EnsembleStrategy(BaseStrategy), auto-load XGBoost + CNN au démarrage

## Modifications

- trading.py : routes POST /train-cnn, GET /train-cnn/{job_id}, GET /cnn-models,
  POST /ensemble/configure, GET /ensemble/status + fix bugs (logging, _get_data_service, period_map)
- strategy_engine.py : support 'ml_driven' dans load_strategy()
- docker/requirements/api.txt : ajout torch>=2.0.0 + dépendances ML (scikit-learn, xgboost, lightgbm)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-10 19:34:41 +00:00

253 lines
9.4 KiB
Python

"""
Ensemble Model — Combine plusieurs modèles ML pour un signal de trading robuste.
L'EnsembleModel agrège les prédictions de modèles indépendants (XGBoost, CNN,
et plus tard RL) via une moyenne pondérée. Un signal n'est émis que si les
modèles actifs sont en accord ET que le score pondéré dépasse un seuil.
Duck typing : ce module n'importe PAS directement MLStrategyModel ni
CNNStrategyModel. Tout objet exposant `.predict(df)` → dict et `.is_trained`
→ bool est compatible.
"""
import logging
from typing import Any, Dict, Optional
import pandas as pd
logger = logging.getLogger(__name__)
class EnsembleModel:
"""
Combine plusieurs modèles ML pour produire un signal de trading robuste.
Logique :
- Chaque modèle prédit indépendamment (signal + confidence)
- Score final = somme pondérée des confidences pour les modèles en accord
- Signal validé uniquement si :
1. Au moins 2 modèles actifs sont en accord sur la direction
2. Score pondéré >= min_confidence
Poids par défaut : xgboost=0.40, cnn=0.60 (CNN légèrement favorisé car
il voit les données brutes sans biais de feature engineering)
"""
DEFAULT_WEIGHTS = {
'xgboost': 0.40,
'cnn': 0.60,
'rl': 0.00, # Réservé Phase 4d
}
def __init__(
self,
weights: Optional[Dict[str, float]] = None,
min_confidence: float = 0.60,
require_agreement: bool = True,
):
self.weights = dict(weights) if weights else dict(self.DEFAULT_WEIGHTS)
self.min_confidence = min_confidence
self.require_agreement = require_agreement
# Modèles attachés (duck typing : .predict(df), .is_trained)
self._models: Dict[str, Any] = {}
logger.info(
f"EnsembleModel initialisé — poids={self.weights}, "
f"seuil={self.min_confidence}, accord_requis={self.require_agreement}"
)
# ------------------------------------------------------------------
# Attachement des modèles
# ------------------------------------------------------------------
def attach_xgboost(self, model) -> None:
"""Attache un MLStrategyModel entraîné."""
self._attach('xgboost', model)
def attach_cnn(self, model) -> None:
"""Attache un CNNStrategyModel entraîné."""
self._attach('cnn', model)
def attach_rl(self, model) -> None:
"""Attache un agent RL (Phase 4d)."""
self._attach('rl', model)
def _attach(self, name: str, model) -> None:
"""Attache un modèle générique avec vérification duck typing."""
if not hasattr(model, 'predict') or not callable(model.predict):
raise ValueError(f"Le modèle '{name}' doit exposer une méthode predict()")
if not hasattr(model, 'is_trained'):
raise ValueError(f"Le modèle '{name}' doit exposer un attribut is_trained")
self._models[name] = model
# Ajouter le poids par défaut s'il n'existe pas
if name not in self.weights:
self.weights[name] = 0.0
logger.warning(f"Poids pour '{name}' non défini — initialisé à 0.0")
logger.info(f"Modèle '{name}' attaché (is_trained={model.is_trained})")
# ------------------------------------------------------------------
# Prédiction combinée
# ------------------------------------------------------------------
def predict(self, df: pd.DataFrame) -> Dict:
"""
Prédit le signal combiné à partir de tous les modèles actifs.
Returns:
{
'signal': int, # 1 LONG, -1 SHORT, 0 NEUTRAL
'confidence': float, # score pondéré [0..1]
'tradeable': bool,
'agreement': bool, # True si tous les modèles actifs concordent
'components': dict, # résultats individuels par modèle
}
"""
components: Dict[str, Dict] = {}
neutral_result = {
'signal': 0, 'confidence': 0.0, 'tradeable': False,
'agreement': False, 'components': components,
}
# 1. Collecter les prédictions des modèles disponibles et entraînés
for name, model in self._models.items():
if not model.is_trained:
logger.debug(f"Ensemble : modèle '{name}' non entraîné, ignoré")
continue
if self.weights.get(name, 0.0) <= 0.0:
logger.debug(f"Ensemble : modèle '{name}' poids=0, ignoré")
continue
try:
result = model.predict(df)
components[name] = {
'signal': result.get('signal', 0),
'confidence': result.get('confidence', 0.0),
}
except Exception as e:
logger.warning(f"Ensemble : erreur predict '{name}'{e}")
continue
if not components:
logger.debug("Ensemble : aucun modèle actif n'a produit de prédiction")
neutral_result['components'] = components
return neutral_result
# 2. Filtrer les signaux non-neutres
directional = {
k: v for k, v in components.items() if v['signal'] != 0
}
if not directional:
# Tous les modèles sont neutres
return {
'signal': 0, 'confidence': 0.0, 'tradeable': False,
'agreement': True, 'components': components,
}
# 3. Vérifier l'accord entre modèles directionnels
directions = set(v['signal'] for v in directional.values())
agreement = len(directions) == 1
if self.require_agreement and not agreement:
logger.debug(
f"Ensemble : désaccord entre modèles — {directional}"
)
return {
'signal': 0, 'confidence': 0.0, 'tradeable': False,
'agreement': False, 'components': components,
}
# 4. Vérifier qu'au moins 2 modèles actifs sont en accord
if len(directional) < 2:
logger.debug("Ensemble : un seul modèle directionnel, signal insuffisant")
return {
'signal': 0, 'confidence': 0.0, 'tradeable': False,
'agreement': True, 'components': components,
}
# 5. Calculer le score pondéré (normalisé sur les modèles actifs)
consensus_dir = directions.pop() # direction unique
total_weight = sum(self.weights.get(k, 0.0) for k in directional)
if total_weight <= 0:
return {
'signal': 0, 'confidence': 0.0, 'tradeable': False,
'agreement': agreement, 'components': components,
}
weighted_score = sum(
self.weights.get(k, 0.0) * v['confidence']
for k, v in directional.items()
) / total_weight
# 6. Signal final
tradeable = weighted_score >= self.min_confidence
logger.info(
f"Ensemble : direction={'LONG' if consensus_dir == 1 else 'SHORT'} | "
f"score={weighted_score:.2%} | accord={agreement} | tradeable={tradeable}"
)
return {
'signal': consensus_dir,
'confidence': weighted_score,
'tradeable': tradeable,
'agreement': agreement,
'components': components,
}
# ------------------------------------------------------------------
# Statut et configuration
# ------------------------------------------------------------------
def is_ready(self) -> bool:
"""True si au moins 2 modèles sont attachés et entraînés."""
trained = sum(
1 for m in self._models.values()
if m.is_trained and self.weights.get(
next(k for k, v in self._models.items() if v is m), 0
) > 0
)
return trained >= 2
def get_status(self) -> Dict:
"""Statut de chaque composant + poids actifs."""
status = {
'ready': self.is_ready(),
'min_confidence': self.min_confidence,
'require_agreement': self.require_agreement,
'weights': dict(self.weights),
'models': {},
}
for name, model in self._models.items():
status['models'][name] = {
'attached': True,
'is_trained': model.is_trained,
'weight': self.weights.get(name, 0.0),
}
# Modèles non attachés mais présents dans les poids
for name in self.weights:
if name not in status['models']:
status['models'][name] = {
'attached': False,
'is_trained': False,
'weight': self.weights[name],
}
return status
def update_weights(self, weights: Dict[str, float]) -> None:
"""
Mise à jour dynamique des poids.
Si la somme != 1.0, normalise automatiquement et log un warning.
"""
total = sum(weights.values())
if total <= 0:
raise ValueError("La somme des poids doit être > 0")
if abs(total - 1.0) > 1e-6:
logger.warning(
f"Somme des poids = {total:.4f} != 1.0 — normalisation automatique"
)
weights = {k: v / total for k, v in weights.items()}
self.weights.update(weights)
logger.info(f"Poids mis à jour : {self.weights}")