diff --git a/docker/requirements/api.txt b/docker/requirements/api.txt index 4efc7b8..9a8a71b 100644 --- a/docker/requirements/api.txt +++ b/docker/requirements/api.txt @@ -29,3 +29,7 @@ joblib>=1.3.0 # ML — Deep Learning (CNN pour patterns chandeliers) torch>=2.0.0 + +# Visualisation — graphiques financiers et traitement d'images +mplfinance>=0.12.10b0 +Pillow>=10.0.0 diff --git a/src/api/routers/trading.py b/src/api/routers/trading.py index c2a60e3..1cdf677 100644 --- a/src/api/routers/trading.py +++ b/src/api/routers/trading.py @@ -1230,3 +1230,191 @@ async def get_ensemble_status(): status["ensemble_info"] = {"error": "Impossible de récupérer le statut"} return status + + +# ============================================================================= +# CNN IMAGE STRATEGY — Entraînement et gestion des modèles CNN-Image +# ============================================================================= + +try: + from src.ml.cnn_image import CNNImageStrategyModel + CNN_IMAGE_AVAILABLE = True +except ImportError: + CNN_IMAGE_AVAILABLE = False + +# Stockage en mémoire des jobs d'entraînement CNN Image +_cnn_image_train_jobs: Dict[str, dict] = {} + + +class CNNImageTrainRequest(BaseModel): + """Requête d'entraînement du modèle CNN Image (Conv2D vision).""" + symbol: str = "EURUSD" + timeframe: str = "1h" + period: str = "2y" + seq_len: int = 64 + tp_atr_mult: float = 2.0 + sl_atr_mult: float = 1.0 + horizon: int = 30 + min_confidence: float = 0.55 + + +class CNNImageTrainResponse(BaseModel): + """Réponse d'un job d'entraînement CNN Image.""" + job_id: str + status: str + symbol: str + timeframe: str + wf_accuracy: Optional[float] = None + wf_precision: Optional[float] = None + label_dist: Optional[dict] = None + n_samples: Optional[int] = None + trained_at: Optional[str] = None + error: Optional[str] = None + + +async def _run_cnn_image_train_task(job_id: str, request: CNNImageTrainRequest) -> None: + """Tâche d'entraînement CNN Image exécutée en arrière-plan.""" + _cnn_image_train_jobs[job_id]["status"] = "running" + try: + from src.data.data_service import DataService + from src.utils.config_loader import ConfigLoader + from datetime import timedelta + + config = ConfigLoader.load_all() + data_service = DataService(config) + + end_date = datetime.now() + period_map = {'y': 365, 'm': 30, 'd': 1} + unit = request.period[-1] + value = int(request.period[:-1]) + start_date = end_date - timedelta(days=value * period_map.get(unit, 1)) + + df = await data_service.get_historical_data( + symbol = request.symbol, + timeframe = request.timeframe, + start_date = start_date, + end_date = end_date, + ) + if df is None or len(df) < 200: + raise ValueError( + f"Données insuffisantes : {len(df) if df is not None else 0} barres (min 200)" + ) + + # Entraînement dans un thread (opération CPU-bound / GPU-bound) + loop = asyncio.get_event_loop() + result = await loop.run_in_executor(None, _sync_cnn_image_train, df, request) + + _cnn_image_train_jobs[job_id].update({ + "status": "completed", + "symbol": request.symbol, + "timeframe": request.timeframe, + "n_samples": result.get("n_samples"), + "wf_accuracy": result.get("wf_metrics", {}).get("avg_accuracy"), + "wf_precision": result.get("wf_metrics", {}).get("avg_precision"), + "label_dist": result.get("label_dist"), + "trained_at": result.get("trained_at"), + }) + + # Auto-attachement à la stratégie cnn_image_driven active si elle existe + _attach_cnn_image_model_to_strategy(request) + + except Exception as exc: + logger.error(f"Erreur entraînement CNN Image job {job_id} : {exc}", exc_info=True) + _cnn_image_train_jobs[job_id]["status"] = "failed" + _cnn_image_train_jobs[job_id]["error"] = str(exc) + + +def _sync_cnn_image_train(df, request: CNNImageTrainRequest) -> dict: + """Wrapper synchrone pour CNNImageStrategyModel.train() (exécuté dans un thread).""" + from src.ml.cnn_image import CNNImageStrategyModel + model = CNNImageStrategyModel( + symbol = request.symbol, + timeframe = request.timeframe, + seq_len = request.seq_len, + tp_atr_mult = request.tp_atr_mult, + sl_atr_mult = request.sl_atr_mult, + horizon = request.horizon, + min_confidence = request.min_confidence, + ) + return model.train(df) + + +def _attach_cnn_image_model_to_strategy(request: CNNImageTrainRequest) -> None: + """Attache le modèle CNN Image entraîné à la stratégie cnn_image_driven active (paper trading).""" + try: + from src.ml.cnn_image import CNNImageStrategyModel + from src.strategies.cnn_image_driven import CNNImageDrivenStrategy + + engine = _paper_state.get("engine") + if engine and hasattr(engine, 'strategy_engine'): + strat = engine.strategy_engine.strategies.get('cnn_image_driven') + if strat and isinstance(strat, CNNImageDrivenStrategy): + model = CNNImageStrategyModel.load(request.symbol, request.timeframe) + strat.attach_model(model) + logger.info("Modèle CNN Image attaché à la stratégie cnn_image_driven active") + except Exception as e: + logger.debug(f"Auto-attach modèle CNN Image ignoré : {e}") + + +@router.post("/train-cnn-image", response_model=CNNImageTrainResponse, + summary="Entraîner le modèle CNN Image (Conv2D vision)") +async def train_cnn_image(request: CNNImageTrainRequest, background_tasks: BackgroundTasks): + """ + Lance l'entraînement du modèle CNN Image en arrière-plan. + + Le CNN Image Conv2D apprend les patterns visuels des graphiques de chandeliers + (bougies marteau, double top/bottom, rebonds S/R, momentum...) depuis des + images 128×128 RGB générées par mplfinance — aucune feature pré-calculée. + + - Retourne un `job_id` à interroger via `GET /trading/train-cnn-image/{job_id}` + - Le modèle est sauvegardé sur disque après entraînement + - Si un paper trading cnn_image_driven est actif, le modèle lui est automatiquement attaché + """ + if not CNN_IMAGE_AVAILABLE: + raise HTTPException( + 503, + detail="PyTorch + mplfinance requis — rebuilder le container trading-api" + ) + + job_id = str(uuid.uuid4()) + _cnn_image_train_jobs[job_id] = { + "status": "pending", + "symbol": request.symbol, + "timeframe": request.timeframe, + } + + background_tasks.add_task(_run_cnn_image_train_task, job_id, request) + + return CNNImageTrainResponse( + job_id = job_id, + status = "pending", + symbol = request.symbol, + timeframe = request.timeframe, + ) + + +@router.get("/train-cnn-image/{job_id}", response_model=CNNImageTrainResponse, + summary="Résultat entraînement CNN Image") +def get_cnn_image_train_status(job_id: str): + """Retourne l'état d'un job d'entraînement CNN Image.""" + job = _cnn_image_train_jobs.get(job_id) + if job is None: + raise HTTPException(404, detail=f"Job {job_id} introuvable") + return CNNImageTrainResponse(job_id=job_id, **job) + + +@router.get("/cnn-image-models", summary="Liste des modèles CNN Image entraînés") +def list_cnn_image_models(): + """ + Retourne la liste de tous les modèles CNN Image disponibles sur disque, + avec leurs métriques (accuracy, date d'entraînement, nombre de samples...). + """ + if not CNN_IMAGE_AVAILABLE: + return { + "error": "PyTorch + mplfinance requis — rebuilder le container trading-api", + "models": [], + "count": 0, + } + from src.ml.cnn_image import CNNImageStrategyModel + models = CNNImageStrategyModel.list_trained_models() + return {"models": models, "count": len(models)} diff --git a/src/ml/cnn_image/__init__.py b/src/ml/cnn_image/__init__.py new file mode 100644 index 0000000..fe508d2 --- /dev/null +++ b/src/ml/cnn_image/__init__.py @@ -0,0 +1,11 @@ +""" +Module CNN Image-Based — Analyse visuelle de graphiques en chandeliers. + +Ce module implémente un CNN Conv2D qui analyse des graphiques de chandeliers +rendus comme de vraies images (128×128 RGB), pour reconnaître les patterns +visuels exactement comme un trader humain devant TradingView. +""" + +from src.ml.cnn_image.cnn_image_strategy_model import CNNImageStrategyModel + +__all__ = ['CNNImageStrategyModel'] diff --git a/src/ml/cnn_image/chart_renderer.py b/src/ml/cnn_image/chart_renderer.py new file mode 100644 index 0000000..edff582 --- /dev/null +++ b/src/ml/cnn_image/chart_renderer.py @@ -0,0 +1,322 @@ +""" +CandlestickImageRenderer — Convertit des données OHLCV en images de graphiques. + +Ce module transforme des séquences de bougies OHLCV en images 128×128 RGB +qui peuvent être passées à un CNN Conv2D pour l'analyse visuelle des patterns. + +Rendu : +- Fond noir (#0d1117), style TradingView +- Bougies vertes (#26a69a) pour la hausse, rouges (#ef5350) pour la baisse +- Volume en bas de l'image (via mplfinance) +- Pas d'axes, pas de labels, pas de titre +- Taille fixe : 128×128 pixels, 3 canaux RGB + +Si mplfinance ou PIL ne sont pas disponibles, un rendu de fallback basique +encode les données OHLCV numériquement sous forme d'image 2D. +""" + +import io +import logging +from typing import Optional + +import numpy as np +import pandas as pd + +logger = logging.getLogger(__name__) + +# --- Détection optionnelle de mplfinance et PIL --- +try: + import mplfinance as mpf + import matplotlib + matplotlib.use('Agg') # Backend non-interactif (pas d'écran requis) + import matplotlib.pyplot as plt + MPLFINANCE_AVAILABLE = True + logger.debug("mplfinance disponible — rendu haute qualité activé") +except ImportError: + MPLFINANCE_AVAILABLE = False + logger.warning("mplfinance non disponible — utilisation du rendu fallback") + +try: + from PIL import Image + PIL_AVAILABLE = True +except ImportError: + PIL_AVAILABLE = False + logger.warning("Pillow (PIL) non disponible — rendu fallback uniquement") + + +# Taille cible des images en pixels +IMAGE_SIZE = 128 + + +class CandlestickImageRenderer: + """ + Convertit des données OHLCV en images de graphiques en chandeliers. + + Chaque image est un instantané visuel de `seq_len` bougies consécutives, + rendu avec mplfinance (style fond noir, bougies colorées, volume en bas). + Le résultat est normalisé en float32 dans [0, 1]. + + Args: + image_size: Taille carrée de l'image en pixels (défaut 128) + """ + + def __init__(self, image_size: int = IMAGE_SIZE): + self.image_size = image_size + + # ------------------------------------------------------------------------- + # Interface publique + # ------------------------------------------------------------------------- + + def encode(self, df: pd.DataFrame, seq_len: int = 64) -> np.ndarray: + """ + Encode toutes les fenêtres glissantes du DataFrame en images. + + Produit N = len(df) - seq_len fenêtres, chacune rendue en image + 128×128 RGB normalisée [0, 1]. + + Args: + df: DataFrame OHLCV avec colonnes open/high/low/close/volume + seq_len: Nombre de bougies par fenêtre (défaut 64) + + Returns: + np.ndarray de forme (N, 3, 128, 128), dtype float32, valeurs [0, 1] + Retourne un tableau vide (0, 3, 128, 128) si df est trop court. + """ + df = self._prepare_df(df) + n_windows = len(df) - seq_len + + if n_windows <= 0: + logger.warning( + f"DataFrame trop court ({len(df)} barres) pour seq_len={seq_len}" + ) + return np.zeros((0, 3, self.image_size, self.image_size), dtype=np.float32) + + images = np.zeros( + (n_windows, 3, self.image_size, self.image_size), dtype=np.float32 + ) + + for i in range(n_windows): + window = df.iloc[i: i + seq_len] + try: + img = self._render_single(window) + images[i] = img + except Exception as e: + # Fenêtre problématique → on laisse des zéros pour cette position + logger.debug(f"Erreur rendu fenêtre {i} : {e}") + + return images + + def encode_last(self, df: pd.DataFrame, seq_len: int = 64) -> np.ndarray: + """ + Encode uniquement la dernière fenêtre du DataFrame. + + Utilisé pour la prédiction en temps réel : retourne (1, 3, 128, 128). + + Args: + df: DataFrame OHLCV + seq_len: Nombre de bougies pour la dernière fenêtre (défaut 64) + + Returns: + np.ndarray de forme (1, 3, 128, 128), dtype float32, valeurs [0, 1] + + Raises: + ValueError si df est trop court pour seq_len bougies + """ + df = self._prepare_df(df) + + if len(df) < seq_len: + raise ValueError( + f"DataFrame insuffisant : {len(df)} barres < seq_len={seq_len}" + ) + + window = df.iloc[-seq_len:] + img = self._render_single(window) + # Ajouter dimension batch + return img[np.newaxis, ...] # (1, 3, H, W) + + # ------------------------------------------------------------------------- + # Rendu d'une fenêtre + # ------------------------------------------------------------------------- + + def _render_single(self, df_window: pd.DataFrame) -> np.ndarray: + """ + Rend une fenêtre OHLCV en image numpy (3, 128, 128). + + Utilise mplfinance si disponible, sinon fallback encodage numérique. + + Args: + df_window: DataFrame OHLCV pour une fenêtre (seq_len barres) + + Returns: + np.ndarray de forme (3, 128, 128), float32, valeurs [0, 1] + """ + if MPLFINANCE_AVAILABLE and PIL_AVAILABLE: + return self._render_with_mplfinance(df_window) + else: + return self._render_fallback(df_window) + + def _render_with_mplfinance(self, df_window: pd.DataFrame) -> np.ndarray: + """ + Rendu haute qualité via mplfinance. + + Style : fond noir (#0d1117), bougies vertes/rouges, volume en bas, + aucun axe ni label ni titre. Image 128×128 RGB. + """ + # Style personnalisé : fond noir, bougies colorées TradingView + style = mpf.make_mpf_style( + base_mpf_style='nightclouds', + marketcolors=mpf.make_marketcolors( + up='#26a69a', # vert TradingView (hausse) + down='#ef5350', # rouge TradingView (baisse) + wick={'up': '#26a69a', 'down': '#ef5350'}, + edge={'up': '#26a69a', 'down': '#ef5350'}, + volume={'up': '#26a69a', 'down': '#ef5350'}, + ), + facecolor='#0d1117', # Fond noir GitHub-style + figcolor='#0d1117', + gridcolor='#0d1117', + ) + + # Rendu en mémoire via BytesIO + buf = io.BytesIO() + + fig, axes = mpf.plot( + df_window, + type='candle', + style=style, + volume=True, + axisoff=True, # Pas d'axes + tight_layout=True, + returnfig=True, + figsize=(1.28, 1.28), # 128 DPI × 1.28 inch = 128 pixels + ) + + # Suppression de tous les éléments décoratifs + for ax in axes: + ax.set_axis_off() + ax.set_facecolor('#0d1117') + for spine in ax.spines.values(): + spine.set_visible(False) + + fig.savefig( + buf, + format='png', + dpi=100, + bbox_inches='tight', + pad_inches=0, + facecolor='#0d1117', + ) + plt.close(fig) + + buf.seek(0) + pil_img = Image.open(buf).convert('RGB') + pil_img = pil_img.resize( + (self.image_size, self.image_size), Image.LANCZOS + ) + + # Conversion en array numpy (H, W, 3) → (3, H, W), float32, [0,1] + arr = np.array(pil_img, dtype=np.float32) / 255.0 + arr = arr.transpose(2, 0, 1) # HWC → CHW + + return arr + + def _render_fallback(self, df_window: pd.DataFrame) -> np.ndarray: + """ + Rendu de secours sans mplfinance : encode les données OHLCV + directement en image 2D normalisée. + + Chaque colonne de l'image correspond à une bougie : + - Canal 0 (R) : position relative du close dans le range [low, high] + - Canal 1 (G) : amplitude de la bougie (high - low normalisé) + - Canal 2 (B) : volume normalisé + + L'image est ensuite redimensionnée à image_size × image_size. + """ + cols = df_window[['open', 'high', 'low', 'close', 'volume']].copy() + n = len(cols) + + if n == 0: + return np.zeros((3, self.image_size, self.image_size), dtype=np.float32) + + # Normalisation par colonne + highs = cols['high'].values.astype(np.float32) + lows = cols['low'].values.astype(np.float32) + closes = cols['close'].values.astype(np.float32) + opens = cols['open'].values.astype(np.float32) + vols = cols['volume'].values.astype(np.float32) + + price_range = highs - lows + price_range = np.where(price_range == 0, 1e-8, price_range) + + # Canal R : position du close dans le range de la bougie [0, 1] + close_pos = (closes - lows) / price_range + + # Canal G : corps de la bougie (|close - open| / range) + body = np.abs(closes - opens) / price_range + + # Canal B : volume normalisé [0, 1] + vol_max = vols.max() + vol_norm = vols / (vol_max if vol_max > 0 else 1.0) + + # Construction image : 3 × 1 × n → 3 × image_size × image_size + # On crée une image de height=image_size, width=n puis on redimensionne + img = np.zeros((3, 1, n), dtype=np.float32) + img[0, 0, :] = close_pos + img[1, 0, :] = body + img[2, 0, :] = vol_norm + + # Redimensionnement vers (3, image_size, image_size) via répétition + # On étire chaque canal sur les deux dimensions + img_resized = np.zeros( + (3, self.image_size, self.image_size), dtype=np.float32 + ) + for c in range(3): + # Répétition en hauteur (axe 0 du canal) et interpolation en largeur + channel = img[c, 0, :] # (n,) + # Interpolation 1D vers image_size + x_orig = np.linspace(0, 1, n) + x_new = np.linspace(0, 1, self.image_size) + channel_resized = np.interp(x_new, x_orig, channel).astype(np.float32) + # Étirer sur toute la hauteur + img_resized[c] = np.tile(channel_resized, (self.image_size, 1)) + + return img_resized + + # ------------------------------------------------------------------------- + # Utilitaires + # ------------------------------------------------------------------------- + + @staticmethod + def _prepare_df(df: pd.DataFrame) -> pd.DataFrame: + """ + Normalise le DataFrame : colonnes en minuscules, index DatetimeIndex. + + Args: + df: DataFrame OHLCV brut + + Returns: + DataFrame nettoyé avec colonnes open/high/low/close/volume + """ + df = df.copy() + df.columns = [c.lower() for c in df.columns] + + # S'assurer que l'index est un DatetimeIndex (requis par mplfinance) + if not isinstance(df.index, pd.DatetimeIndex): + try: + df.index = pd.to_datetime(df.index) + except Exception: + # Créer un index artificiel si la conversion échoue + df.index = pd.date_range( + start='2020-01-01', periods=len(df), freq='1h' + ) + + # Conserver uniquement les colonnes OHLCV + required = ['open', 'high', 'low', 'close', 'volume'] + for col in required: + if col not in df.columns: + df[col] = 0.0 + + df = df[required].dropna(subset=['open', 'high', 'low', 'close']) + df = df.ffill().bfill() + + return df diff --git a/src/ml/cnn_image/cnn_image_model.py b/src/ml/cnn_image/cnn_image_model.py new file mode 100644 index 0000000..c096b6f --- /dev/null +++ b/src/ml/cnn_image/cnn_image_model.py @@ -0,0 +1,163 @@ +""" +CandlestickCNN — Réseau de neurones convolutif pour l'analyse visuelle de graphiques. + +Architecture Conv2D 4-blocs conçue pour reconnaître les patterns visuels dans +des images de chandeliers 128×128 RGB (bougies, volumes, supports/résistances). + +Input : (batch, 3, 128, 128) — images RGB normalisées [0, 1] +Output : (batch, 3) — logits pour 3 classes : SHORT(0), NEUTRAL(1), LONG(2) + +Classes encodées : + 0 → SHORT (-1) + 1 → NEUTRAL ( 0) + 2 → LONG (+1) + +L'encodage +1 est cohérent avec MLStrategyModel et CNNStrategyModel. +""" + +import logging +from typing import Optional + +import numpy as np + +logger = logging.getLogger(__name__) + +# --- Détection optionnelle de PyTorch --- +try: + import torch + import torch.nn as nn + import torch.nn.functional as F + TORCH_AVAILABLE = True + logger.debug("PyTorch disponible — CandlestickCNN activé") +except ImportError: + TORCH_AVAILABLE = False + # Stubs pour permettre l'import du module sans PyTorch + nn = None + torch = None + logger.warning("PyTorch non disponible — CandlestickCNN désactivé") + + +# ───────────────────────────────────────────────────────────────────────────── +# Définition conditionnelle du modèle +# ───────────────────────────────────────────────────────────────────────────── + +if TORCH_AVAILABLE: + + class CandlestickCNN(nn.Module): + """ + CNN vision pour graphiques chandeliers 128×128 RGB. + + Architecture : + Bloc 1 : Conv2d(3→32, 3×3) + BatchNorm2d(32) + ReLU + MaxPool2d(2) → (32, 64, 64) + Bloc 2 : Conv2d(32→64, 3×3) + BatchNorm2d(64) + ReLU + MaxPool2d(2) → (64, 32, 32) + Bloc 3 : Conv2d(64→128,3×3) + BatchNorm2d(128) + ReLU + MaxPool2d(2) → (128, 16, 16) + Bloc 4 : Conv2d(128→256,3×3)+ BatchNorm2d(256) + ReLU + AdaptiveAvgPool2d(1) → (256,) + + Classifieur : + Linear(256→128) + Dropout(0.4) + ReLU + Linear(128→3) + + Output : logits bruts (3 classes) — appliquer softmax pour les probabilités. + + Args: + n_classes: Nombre de classes de sortie (défaut 3 : SHORT/NEUTRAL/LONG) + dropout: Taux de dropout dans le classifieur (défaut 0.4) + """ + + def __init__(self, n_classes: int = 3, dropout: float = 0.4): + super(CandlestickCNN, self).__init__() + + # --- Bloc convolutif 1 : (3, 128, 128) → (32, 64, 64) --- + self.bloc1 = nn.Sequential( + nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2, stride=2), + ) + + # --- Bloc convolutif 2 : (32, 64, 64) → (64, 32, 32) --- + self.bloc2 = nn.Sequential( + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2, stride=2), + ) + + # --- Bloc convolutif 3 : (64, 32, 32) → (128, 16, 16) --- + self.bloc3 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2, stride=2), + ) + + # --- Bloc convolutif 4 : (128, 16, 16) → (256, 1, 1) → (256,) --- + self.bloc4 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.AdaptiveAvgPool2d(output_size=1), # Global average pooling → (256, 1, 1) + ) + + # --- Classifieur dense --- + self.classifieur = nn.Sequential( + nn.Linear(256, 128), + nn.Dropout(p=dropout), + nn.ReLU(inplace=True), + nn.Linear(128, n_classes), + ) + + def forward(self, x: 'torch.Tensor') -> 'torch.Tensor': + """ + Passe avant du réseau. + + Args: + x: Tensor (batch, 3, 128, 128) — images RGB normalisées + + Returns: + Tensor (batch, 3) — logits bruts + """ + x = self.bloc1(x) # (batch, 32, 64, 64) + x = self.bloc2(x) # (batch, 64, 32, 32) + x = self.bloc3(x) # (batch, 128, 16, 16) + x = self.bloc4(x) # (batch, 256, 1, 1) + x = x.flatten(1) # (batch, 256) + x = self.classifieur(x) # (batch, 3) + return x + + def predict_proba(self, x: 'torch.Tensor') -> np.ndarray: + """ + Calcule les probabilités softmax pour chaque classe. + + Args: + x: Tensor (batch, 3, 128, 128) — images RGB normalisées + + Returns: + np.ndarray de forme (batch, 3), valeurs dans [0, 1], somme = 1 + Ordre des colonnes : [P_SHORT, P_NEUTRAL, P_LONG] + """ + self.eval() + with torch.no_grad(): + logits = self.forward(x) # (batch, 3) + probas = F.softmax(logits, dim=1) # (batch, 3) + return probas.cpu().numpy().astype(np.float32) + +else: + # Stub si PyTorch n'est pas disponible + class CandlestickCNN: + """ + Stub de CandlestickCNN — PyTorch non disponible. + + Toutes les méthodes lèvent une RuntimeError explicite. + """ + + def __init__(self, n_classes: int = 3, dropout: float = 0.4): + raise RuntimeError( + "PyTorch non disponible. Installer torch>=2.0.0 pour utiliser CandlestickCNN." + ) + + def forward(self, x): + raise RuntimeError("PyTorch non disponible.") + + def predict_proba(self, x): + raise RuntimeError("PyTorch non disponible.") diff --git a/src/ml/cnn_image/cnn_image_strategy_model.py b/src/ml/cnn_image/cnn_image_strategy_model.py new file mode 100644 index 0000000..c2c40e0 --- /dev/null +++ b/src/ml/cnn_image/cnn_image_strategy_model.py @@ -0,0 +1,654 @@ +""" +CNNImageStrategyModel — Modèle CNN image-based pour la prédiction de signaux de trading. + +Ce module entraîne un CNN Conv2D (CandlestickCNN) sur des images de graphiques +en chandeliers pour prédire LONG (+1), NEUTRAL (0) ou SHORT (-1). + +Pipeline : + 1. Génération de labels ATR-based (LabelGenerator partagé) + 2. Rendu des fenêtres OHLCV en images 128×128 RGB (CandlestickImageRenderer) + 3. Walk-forward (2 folds temporels) pour évaluation out-of-sample + 4. Entraînement sur toutes les données (Adam, CrossEntropyLoss weighted) + 5. Sauvegarde : EURUSD_1h.pt (state_dict) + EURUSD_1h_meta.json + +Interface identique à MLStrategyModel et CNNStrategyModel : + model = CNNImageStrategyModel(symbol='EURUSD', timeframe='1h') + result = model.train(df_ohlcv) + signal = model.predict(df) # {signal, confidence, probas, tradeable} + model.save() + model.load(symbol, timeframe) + model.list_trained_models() + model.get_feature_importance() # retourne [] (CNN = boîte noire) +""" + +import json +import logging +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional + +import numpy as np +import pandas as pd + +from src.ml.features.label_generator import LabelGenerator +from src.ml.cnn_image.chart_renderer import CandlestickImageRenderer, MPLFINANCE_AVAILABLE +from src.ml.cnn_image.cnn_image_model import CandlestickCNN, TORCH_AVAILABLE + +logger = logging.getLogger(__name__) + +# Répertoire de sauvegarde des modèles CNN image +MODELS_DIR = Path(__file__).parent.parent.parent.parent / "models" / "cnn_image_strategy" + +# Décodage des classes encodées vers les signaux de trading +# Classe 0 → SHORT (-1), Classe 1 → NEUTRAL (0), Classe 2 → LONG (+1) +CLASS_MAP = {0: -1, 1: 0, 2: 1} + +# Imports conditionnels PyTorch +if TORCH_AVAILABLE: + import torch + import torch.nn as nn + from torch.utils.data import DataLoader, TensorDataset + from sklearn.model_selection import TimeSeriesSplit + from sklearn.metrics import precision_recall_fscore_support + SKLEARN_AVAILABLE = True +else: + torch = None + nn = None + DataLoader = None + TensorDataset = None + TimeSeriesSplit = None + precision_recall_fscore_support = None + SKLEARN_AVAILABLE = False + +try: + from sklearn.model_selection import TimeSeriesSplit + from sklearn.metrics import precision_recall_fscore_support + SKLEARN_AVAILABLE = True +except ImportError: + SKLEARN_AVAILABLE = False + + +class CNNImageStrategyModel: + """ + Modèle CNN image-based qui reconnaît les patterns visuels de trading. + + Le modèle : + - Convertit des bougies OHLCV en images 128×128 RGB (mplfinance) + - Reconnaît visuellement : marteaux, doji, double top/bottom, supports/résistances + - Prédit LONG (1) / SHORT (-1) / NEUTRAL (0) + - Donne un score de confiance [0..1] par prédiction + - Se sauvegarde en format PyTorch (.pt) pour rechargement sans ré-entraînement + + Args: + symbol: Paire tradée (ex: 'EURUSD') + timeframe: Timeframe (ex: '1h', '15m') + seq_len: Nombre de bougies par image (défaut 64) + tp_atr_mult: Multiplicateur ATR pour le TP (génération labels) + sl_atr_mult: Multiplicateur ATR pour le SL (génération labels) + horizon: Nombre de barres pour évaluer TP/SL + min_confidence: Seuil de confiance pour le signal tradeable + epochs: Nombre maximum d'époques d'entraînement + batch_size: Taille du batch (défaut 32) + lr: Taux d'apprentissage Adam (défaut 1e-3) + patience: Patience pour l'early stopping (défaut 7) + """ + + MODELS_DIR = MODELS_DIR + CLASS_MAP = CLASS_MAP + + def __init__( + self, + symbol: str = 'EURUSD', + timeframe: str = '1h', + seq_len: int = 64, + tp_atr_mult: float = 2.0, + sl_atr_mult: float = 1.0, + horizon: int = 30, + min_confidence: float = 0.55, + epochs: int = 50, + batch_size: int = 32, + lr: float = 1e-3, + patience: int = 7, + ): + self.symbol = symbol + self.timeframe = timeframe + self.seq_len = seq_len + self.tp_atr_mult = tp_atr_mult + self.sl_atr_mult = sl_atr_mult + self.horizon = horizon + self.min_confidence = min_confidence + self.epochs = epochs + self.batch_size = batch_size + self.lr = lr + self.patience = patience + + self.model: Optional['CandlestickCNN'] = None + self.is_trained = False + self.metadata: Dict = {} + + self.MODELS_DIR.mkdir(parents=True, exist_ok=True) + + # ------------------------------------------------------------------------- + # Entraînement + # ------------------------------------------------------------------------- + + def train(self, data: pd.DataFrame) -> Dict: + """ + Entraîne le modèle CNN image sur les données OHLCV fournies. + + Étapes : + 1. Génération des labels ATR-based (LabelGenerator) + 2. Encodage des fenêtres en images (CandlestickImageRenderer) + 3. Walk-forward évaluation (2 folds temporels) + 4. Entraînement final sur toutes les données + 5. Sauvegarde automatique + + Args: + data: DataFrame OHLCV (minimum 200 barres recommandées, + colonnes : open/high/low/close/volume) + + Returns: + Dict avec métriques : wf_accuracy, wf_precision, label_dist, + n_samples, trained_at, et error si échec. + """ + # Vérifications de disponibilité + if not TORCH_AVAILABLE: + return {'error': 'PyTorch non disponible — installer torch>=2.0.0'} + if not SKLEARN_AVAILABLE: + return {'error': 'scikit-learn non disponible'} + + logger.info( + f"Début entraînement CNNImageStrategyModel pour " + f"{self.symbol}/{self.timeframe} — seq_len={self.seq_len}" + ) + + # 1. Génération des labels (ATR-based, cohérent avec MLStrategyModel) + gen = LabelGenerator(horizon=self.horizon) + labels = gen.generate_atr_based( + data, + atr_tp_mult=self.tp_atr_mult, + atr_sl_mult=self.sl_atr_mult, + ) + + # Encodage [-1,0,1] → [0,1,2] pour CrossEntropyLoss + y_enc = (labels + 1).values # np.ndarray + + # 2. Encodage des images (fenêtres glissantes) + logger.info(f" Rendu des images (seq_len={self.seq_len})...") + renderer = CandlestickImageRenderer() + X = renderer.encode(data, seq_len=self.seq_len) # (N, 3, 128, 128) + + # 3. Alignement X et y + # encode() produit N = len(data) - seq_len images + # La i-ème image correspond à data.iloc[i : i + seq_len] + # Le label associé est y_enc[i + seq_len - 1] (dernier point de la fenêtre) + n_images = len(X) + if n_images == 0: + return { + 'error': f'Pas assez de données : {len(data)} barres < seq_len={self.seq_len}' + } + + # Index des labels pour chaque image : fenêtre i → label à l'indice i + seq_len - 1 + label_indices = np.arange(self.seq_len - 1, self.seq_len - 1 + n_images) + # Supprimer les barres de fin (labels non fiables, horizon tronqué) + valid_mask = label_indices < (len(y_enc) - self.horizon) + + X = X[valid_mask] + y = y_enc[label_indices[valid_mask]] + + n_samples = len(X) + if n_samples < 50: + return {'error': f'Trop peu d\'échantillons valides : {n_samples}'} + + logger.info( + f" {n_samples} images valides — " + f"LONG={(y==2).sum()}, SHORT={(y==0).sum()}, NEUTRAL={(y==1).sum()}" + ) + + # 4. Calcul des class_weights (inversement proportionnels à la fréquence) + class_weights = self._compute_class_weights(y) + + # 5. Walk-forward évaluation (2 folds) + wf_metrics = self._walk_forward_eval(X, y, class_weights, n_splits=2) + + # 6. Entraînement final sur toutes les données + logger.info(" Entraînement final sur toutes les données...") + self.model = CandlestickCNN(n_classes=3) + self._train_model(self.model, X, y, class_weights) + self.is_trained = True + + # Distribution des labels (décodage pour lisibilité) + label_dist = { + 'long': int((y == 2).sum()), + 'short': int((y == 0).sum()), + 'neutral': int((y == 1).sum()), + } + + self.metadata = { + 'symbol': self.symbol, + 'timeframe': self.timeframe, + 'seq_len': self.seq_len, + 'trained_at': datetime.utcnow().isoformat(), + 'n_samples': n_samples, + 'tp_atr_mult': self.tp_atr_mult, + 'sl_atr_mult': self.sl_atr_mult, + 'horizon': self.horizon, + 'label_dist': label_dist, + 'wf_metrics': wf_metrics, + 'mplfinance': MPLFINANCE_AVAILABLE, + } + + # 7. Sauvegarde + self.save() + + logger.info( + f"Entraînement terminé. WF accuracy={wf_metrics.get('avg_accuracy', 0):.2%}" + ) + return self.metadata + + # ------------------------------------------------------------------------- + # Prédiction + # ------------------------------------------------------------------------- + + def predict(self, data: pd.DataFrame) -> Dict: + """ + Prédit le signal pour la dernière fenêtre disponible. + + Args: + data: DataFrame OHLCV récent (minimum seq_len barres) + + Returns: + Dict : { + 'signal': 1 (LONG) / -1 (SHORT) / 0 (NEUTRAL), + 'confidence': float [0..1] — max(P_LONG, P_SHORT), + 'probas': {'long': float, 'short': float, 'neutral': float}, + 'tradeable': bool — confidence >= min_confidence et signal != 0, + } + """ + if not TORCH_AVAILABLE: + return { + 'signal': 0, 'confidence': 0.0, 'tradeable': False, + 'error': 'PyTorch non disponible' + } + if not self.is_trained or self.model is None: + return { + 'signal': 0, 'confidence': 0.0, 'tradeable': False, + 'error': 'Modèle non entraîné' + } + + try: + renderer = CandlestickImageRenderer() + img = renderer.encode_last(data, seq_len=self.seq_len) # (1, 3, 128, 128) + except ValueError as e: + return { + 'signal': 0, 'confidence': 0.0, 'tradeable': False, + 'error': str(e) + } + + # Conversion en Tensor + x_tensor = torch.from_numpy(img).float() # (1, 3, 128, 128) + + # Prédiction + proba_arr = self.model.predict_proba(x_tensor) # (1, 3) + proba = proba_arr[0] # (3,) : [P_SHORT, P_NEUTRAL, P_LONG] + + pred_class = int(np.argmax(proba)) + signal = self.CLASS_MAP[pred_class] # Décodage : 0→-1, 1→0, 2→1 + + probas = { + 'short': float(proba[0]), + 'neutral': float(proba[1]), + 'long': float(proba[2]), + } + confidence = float(max(probas['long'], probas['short'])) + + return { + 'signal': signal, + 'confidence': confidence, + 'probas': probas, + 'tradeable': confidence >= self.min_confidence and signal != 0, + } + + # ------------------------------------------------------------------------- + # Sauvegarde / Chargement + # ------------------------------------------------------------------------- + + def save(self) -> Path: + """ + Sauvegarde le modèle et ses métadonnées sur disque. + + Format : + {symbol}_{timeframe}.pt — state_dict PyTorch + {symbol}_{timeframe}_meta.json — métadonnées JSON + + Returns: + Path vers le fichier .pt sauvegardé + + Raises: + RuntimeError si le modèle n'est pas entraîné ou PyTorch indisponible + """ + if not TORCH_AVAILABLE: + raise RuntimeError("PyTorch non disponible") + if not self.is_trained or self.model is None: + raise RuntimeError("Modèle non entraîné — appeler train() avant save()") + + model_id = f"{self.symbol}_{self.timeframe}" + model_path = self.MODELS_DIR / f"{model_id}.pt" + meta_path = self.MODELS_DIR / f"{model_id}_meta.json" + + # Sauvegarde du state_dict + config pour pouvoir reconstruire le modèle + torch.save( + { + 'state_dict': self.model.state_dict(), + 'config': { + 'symbol': self.symbol, + 'timeframe': self.timeframe, + 'seq_len': self.seq_len, + 'tp_atr_mult': self.tp_atr_mult, + 'sl_atr_mult': self.sl_atr_mult, + 'horizon': self.horizon, + 'min_confidence': self.min_confidence, + 'epochs': self.epochs, + 'batch_size': self.batch_size, + 'lr': self.lr, + 'patience': self.patience, + }, + 'metadata': self.metadata, + }, + model_path + ) + + with open(meta_path, 'w') as f: + json.dump(self.metadata, f, indent=2, default=str) + + logger.info(f"Modèle CNN image sauvegardé : {model_path}") + return model_path + + @classmethod + def load(cls, symbol: str, timeframe: str) -> 'CNNImageStrategyModel': + """ + Charge un modèle existant depuis le disque. + + Args: + symbol: Paire (ex: 'EURUSD') + timeframe: Timeframe (ex: '1h') + + Returns: + Instance CNNImageStrategyModel prête à prédire + + Raises: + RuntimeError si PyTorch non disponible + FileNotFoundError si le modèle n'existe pas + """ + if not TORCH_AVAILABLE: + raise RuntimeError("PyTorch non disponible") + + model_path = MODELS_DIR / f"{symbol}_{timeframe}.pt" + + if not model_path.exists(): + raise FileNotFoundError(f"Modèle non trouvé : {model_path}") + + checkpoint = torch.load(model_path, map_location='cpu') + cfg = checkpoint.get('config', {}) + + instance = cls( + symbol = cfg.get('symbol', symbol), + timeframe = cfg.get('timeframe', timeframe), + seq_len = cfg.get('seq_len', 64), + tp_atr_mult = cfg.get('tp_atr_mult', 2.0), + sl_atr_mult = cfg.get('sl_atr_mult', 1.0), + horizon = cfg.get('horizon', 30), + min_confidence = cfg.get('min_confidence', 0.55), + epochs = cfg.get('epochs', 50), + batch_size = cfg.get('batch_size', 32), + lr = cfg.get('lr', 1e-3), + patience = cfg.get('patience', 7), + ) + + # Reconstruction du modèle et chargement des poids + instance.model = CandlestickCNN(n_classes=3) + instance.model.load_state_dict(checkpoint['state_dict']) + instance.model.eval() + instance.is_trained = True + instance.metadata = checkpoint.get('metadata', {}) + + logger.info(f"Modèle CNN image chargé depuis {model_path}") + return instance + + @staticmethod + def list_trained_models() -> List[Dict]: + """ + Retourne la liste des modèles entraînés disponibles. + + Returns: + Liste de dicts avec symbol, timeframe, trained_at, n_samples, wf_accuracy + """ + if not MODELS_DIR.exists(): + return [] + + models = [] + for f in MODELS_DIR.glob("*_meta.json"): + try: + with open(f) as fp: + meta = json.load(fp) + # Nom du fichier : EURUSD_1h_meta.json → EURUSD, 1h + stem = f.stem.replace('_meta', '') # ex: EURUSD_1h + parts = stem.split('_', 1) + models.append({ + 'symbol': meta.get('symbol', parts[0] if parts else '?'), + 'timeframe': meta.get('timeframe', parts[1] if len(parts) > 1 else '?'), + 'trained_at': meta.get('trained_at', '?'), + 'n_samples': meta.get('n_samples', 0), + 'seq_len': meta.get('seq_len', 64), + 'wf_accuracy': meta.get('wf_metrics', {}).get('avg_accuracy', 0), + 'mplfinance': meta.get('mplfinance', False), + }) + except Exception: + pass + + return models + + def get_feature_importance(self) -> List[Dict]: + """ + Retourne l'importance des features. + + Note : Les CNN sont des boîtes noires — pas d'importance de feature + interprétable comme XGBoost. Retourne une liste vide. + + Returns: + [] — CNN = boîte noire visuelle + """ + return [] + + # ------------------------------------------------------------------------- + # Entraînement interne + # ------------------------------------------------------------------------- + + def _train_model( + self, + model: 'CandlestickCNN', + X: np.ndarray, + y: np.ndarray, + class_weights: 'torch.Tensor', + ) -> None: + """ + Entraîne le modèle CNN sur les données fournies avec early stopping. + + Args: + model: Instance CandlestickCNN à entraîner + X: Images (N, 3, 128, 128) float32 + y: Labels encodés (N,) int, valeurs {0, 1, 2} + class_weights: Poids de classes (3,) pour CrossEntropyLoss + """ + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = model.to(device) + class_weights = class_weights.to(device) + + # Conversion en Tensors + X_t = torch.from_numpy(X).float() + y_t = torch.from_numpy(y.astype(np.int64)) + + # Split train/validation (80/20 temporel) + n_train = int(len(X_t) * 0.8) + X_tr, X_val = X_t[:n_train], X_t[n_train:] + y_tr, y_val = y_t[:n_train], y_t[n_train:] + + # DataLoaders + train_ds = TensorDataset(X_tr, y_tr) + val_ds = TensorDataset(X_val, y_val) + train_dl = DataLoader(train_ds, batch_size=self.batch_size, shuffle=False) + val_dl = DataLoader(val_ds, batch_size=self.batch_size, shuffle=False) + + # Optimiseur et critère + optimizer = torch.optim.Adam(model.parameters(), lr=self.lr) + criterion = nn.CrossEntropyLoss(weight=class_weights) + + best_val_loss = float('inf') + patience_count = 0 + best_state = None + + for epoch in range(self.epochs): + # --- Phase entraînement --- + model.train() + train_loss = 0.0 + for X_batch, y_batch in train_dl: + X_batch = X_batch.to(device) + y_batch = y_batch.to(device) + optimizer.zero_grad() + logits = model(X_batch) + loss = criterion(logits, y_batch) + loss.backward() + optimizer.step() + train_loss += loss.item() * len(X_batch) + train_loss /= max(len(train_ds), 1) + + # --- Phase validation --- + model.eval() + val_loss = 0.0 + with torch.no_grad(): + for X_batch, y_batch in val_dl: + X_batch = X_batch.to(device) + y_batch = y_batch.to(device) + logits = model(X_batch) + loss = criterion(logits, y_batch) + val_loss += loss.item() * len(X_batch) + val_loss /= max(len(val_ds), 1) + + logger.debug( + f" Epoch {epoch+1}/{self.epochs} — " + f"train_loss={train_loss:.4f}, val_loss={val_loss:.4f}" + ) + + # Early stopping + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_count = 0 + best_state = { + k: v.cpu().clone() for k, v in model.state_dict().items() + } + else: + patience_count += 1 + if patience_count >= self.patience: + logger.info( + f" Early stopping à l'époque {epoch+1} " + f"(patience={self.patience}, best_val_loss={best_val_loss:.4f})" + ) + break + + # Restauration des meilleurs poids + if best_state is not None: + model.load_state_dict(best_state) + + model.eval() + model.to('cpu') + + def _walk_forward_eval( + self, + X: np.ndarray, + y: np.ndarray, + class_weights: 'torch.Tensor', + n_splits: int = 2, + ) -> Dict: + """ + Évalue le modèle en cross-validation temporelle (walk-forward). + + Args: + X: Images (N, 3, 128, 128) + y: Labels encodés (N,) + class_weights: Poids de classes pour CrossEntropyLoss + n_splits: Nombre de folds temporels + + Returns: + Dict : avg_accuracy, avg_precision, fold_accuracies + """ + tscv = TimeSeriesSplit(n_splits=n_splits) + accuracies = [] + precisions = [] + + for fold, (train_idx, test_idx) in enumerate(tscv.split(X)): + X_tr, X_te = X[train_idx], X[test_idx] + y_tr, y_te = y[train_idx], y[test_idx] + + # Modèle temporaire pour ce fold + fold_model = CandlestickCNN(n_classes=3) + self._train_model(fold_model, X_tr, y_tr, class_weights) + + # Évaluation + x_tensor = torch.from_numpy(X_te).float() + proba_arr = fold_model.predict_proba(x_tensor) # (N_te, 3) + y_pred = np.argmax(proba_arr, axis=1) + + acc = float((y_pred == y_te).mean()) + accuracies.append(acc) + + # Précision sur signaux directionnels (LONG=2 et SHORT=0, exclure NEUTRAL=1) + mask = (y_te != 1) | (y_pred != 1) + if mask.sum() > 0 and SKLEARN_AVAILABLE: + prec, _, _, _ = precision_recall_fscore_support( + y_te[mask], y_pred[mask], + average='macro', zero_division=0 + ) + precisions.append(float(prec)) + else: + precisions.append(0.0) + + logger.info( + f" Fold {fold+1}/{n_splits} : acc={acc:.2%}, " + f"prec={precisions[-1]:.2%}" + ) + + return { + 'avg_accuracy': float(np.mean(accuracies)) if accuracies else 0.0, + 'avg_precision': float(np.mean(precisions)) if precisions else 0.0, + 'fold_accuracies': [float(a) for a in accuracies], + } + + @staticmethod + def _compute_class_weights(y: np.ndarray) -> 'torch.Tensor': + """ + Calcule les poids de classes inversement proportionnels à leur fréquence. + + Permet de compenser les déséquilibres (ex: beaucoup de NEUTRAL, peu de trades). + + Args: + y: Labels encodés {0, 1, 2} + + Returns: + torch.Tensor de forme (3,) — poids pour CrossEntropyLoss + """ + weights = np.ones(3, dtype=np.float32) + n_total = len(y) + + if n_total == 0: + return torch.from_numpy(weights) + + for cls in range(3): + count = (y == cls).sum() + if count > 0: + # Poids inversement proportionnel : classes rares → poids plus élevé + weights[cls] = n_total / (3.0 * count) + + # Normalisation pour que la somme des poids = 3 + weights = weights / weights.mean() + + return torch.from_numpy(weights) diff --git a/src/ml/ensemble/ensemble_model.py b/src/ml/ensemble/ensemble_model.py index 320bbc1..aa41874 100644 --- a/src/ml/ensemble/ensemble_model.py +++ b/src/ml/ensemble/ensemble_model.py @@ -34,9 +34,10 @@ class EnsembleModel: """ DEFAULT_WEIGHTS = { - 'xgboost': 0.40, - 'cnn': 0.60, - 'rl': 0.00, # Réservé Phase 4d + 'xgboost': 0.30, + 'cnn': 0.30, + 'cnn_image': 0.40, # CNN Vision — favorisé (patterns visuels bruts) + 'rl': 0.00, # Réservé Phase 4d } def __init__( @@ -52,6 +53,9 @@ class EnsembleModel: # Modèles attachés (duck typing : .predict(df), .is_trained) self._models: Dict[str, Any] = {} + # Référence directe au modèle CNN Image (accès rapide pour auto-attach) + self._cnn_image_model = None + logger.info( f"EnsembleModel initialisé — poids={self.weights}, " f"seuil={self.min_confidence}, accord_requis={self.require_agreement}" @@ -68,6 +72,11 @@ class EnsembleModel: """Attache un CNNStrategyModel entraîné.""" self._attach('cnn', model) + def attach_cnn_image(self, model) -> None: + """Attache un CNNImageStrategyModel entraîné (Phase 4c-bis).""" + self._cnn_image_model = model + self._attach('cnn_image', model) + def attach_rl(self, model) -> None: """Attache un agent RL (Phase 4d).""" self._attach('rl', model) @@ -198,12 +207,13 @@ class EnsembleModel: # Statut et configuration # ------------------------------------------------------------------ def is_ready(self) -> bool: - """True si au moins 2 modèles sont attachés et entraînés.""" + """True si au moins 2 modèles parmi {xgboost, cnn, cnn_image} sont attachés et entraînés.""" + eligible_names = {'xgboost', 'cnn', 'cnn_image'} trained = sum( - 1 for m in self._models.values() - if m.is_trained and self.weights.get( - next(k for k, v in self._models.items() if v is m), 0 - ) > 0 + 1 for name, model in self._models.items() + if name in eligible_names + and model.is_trained + and self.weights.get(name, 0.0) > 0 ) return trained >= 2 diff --git a/src/strategies/cnn_image_driven/__init__.py b/src/strategies/cnn_image_driven/__init__.py new file mode 100644 index 0000000..161041f --- /dev/null +++ b/src/strategies/cnn_image_driven/__init__.py @@ -0,0 +1,3 @@ +from src.strategies.cnn_image_driven.cnn_image_strategy import CNNImageDrivenStrategy + +__all__ = ['CNNImageDrivenStrategy'] diff --git a/src/strategies/cnn_image_driven/cnn_image_strategy.py b/src/strategies/cnn_image_driven/cnn_image_strategy.py new file mode 100644 index 0000000..31b2b44 --- /dev/null +++ b/src/strategies/cnn_image_driven/cnn_image_strategy.py @@ -0,0 +1,258 @@ +""" +CNN Image-Driven Strategy — Stratégie pilotée par un CNN à vision par ordinateur. + +Contrairement au CNN 1D qui analyse des séquences numériques brutes, cette +stratégie convertit les bougies OHLCV en images de graphiques (rendu mplfinance +128×128 RGB), puis utilise un réseau Conv2D pour reconnaître les patterns +visuels — exactement comme un trader devant TradingView. + +Patterns détectés sans feature engineering explicite : + - Bougies marteau, étoile filante, doji + - Double top, double bottom + - Rebonds sur supports/résistances visuels + - Momentum : grande bougie pleine après consolidation + - Divergences prix/volume visibles + +Fonctionnement : + 1. Le modèle CNN-Image est chargé depuis le disque (entraîné via POST /trading/train-cnn-image) + 2. À chaque barre, les seq_len dernières bougies sont rendues en image 128×128 + 3. Le modèle prédit LONG / SHORT / NEUTRAL avec un score de confiance + 4. Si confidence >= min_confidence, un signal est émis avec SL/TP basés sur ATR + +Intégration : + - Compatible avec StrategyEngine (même interface que CNNDrivenStrategy) + - Chargé automatiquement si un modèle entraîné existe pour le symbole/timeframe + - Requiert PyTorch + mplfinance + Pillow +""" + +import logging +from datetime import datetime, timezone +from typing import Dict, Optional + +import numpy as np +import pandas as pd + +from src.strategies.base_strategy import BaseStrategy, Signal, StrategyConfig + +try: + from src.ml.cnn_image import CNNImageStrategyModel + CNN_IMAGE_AVAILABLE = True +except ImportError: + CNN_IMAGE_AVAILABLE = False + +logger = logging.getLogger(__name__) + + +class CNNImageDrivenStrategy(BaseStrategy): + """ + Stratégie de trading pilotée par un CNN Conv2D pré-entraîné sur images de graphiques. + + Le modèle apprend les patterns visuels des chandeliers directement depuis + des images 128×128 RGB sans features pré-calculées : + - Double bottom / double top + - Configurations chandeliers (marteau, doji, étoile filante) + - Rebonds sur zones de support/résistance visibles + - Structures de momentum multi-barres + - Divergences volume/prix visuelles + + Args: + config: Dict de configuration (timeframe, risk_per_trade, symbol, etc.) + + Config keys supplémentaires (optionnelles) : + min_confidence: Seuil de confiance minimum [0..1] (défaut: 0.55) + tp_atr_mult: Multiplicateur ATR pour TP (défaut: 2.0) + sl_atr_mult: Multiplicateur ATR pour SL (défaut: 1.0) + seq_len: Nombre de bougies par image (défaut: 64) + auto_load: Charger automatiquement le modèle existant (défaut: True) + """ + + STRATEGY_NAME = 'cnn_image_driven' + + def __init__(self, config: Dict): + super().__init__(config) + + self.symbol = config.get('symbol', 'EURUSD') + self.min_confidence = config.get('min_confidence', 0.55) + self.tp_atr_mult = config.get('tp_atr_mult', 2.0) + self.sl_atr_mult = config.get('sl_atr_mult', 1.0) + self.seq_len = config.get('seq_len', 64) + + self.cnn_image_model: Optional['CNNImageStrategyModel'] = None + + if not CNN_IMAGE_AVAILABLE: + logger.warning("CNN Image non disponible (PyTorch/mplfinance requis)") + return + + # Tentative de chargement automatique du modèle existant + if config.get('auto_load', True): + self._try_load_model() + + # ------------------------------------------------------------------------- + # Interface BaseStrategy + # ------------------------------------------------------------------------- + def analyze(self, market_data: pd.DataFrame) -> Optional[Signal]: + """ + Génère un signal de trading via le modèle CNN Image. + + Args: + market_data: DataFrame OHLCV (minimum seq_len barres) + + Returns: + Signal si le modèle est confiant, None sinon + """ + if not CNN_IMAGE_AVAILABLE: + logger.debug("CNN Image Strategy : PyTorch/mplfinance non disponible, aucun signal") + return None + + if self.cnn_image_model is None or not self.cnn_image_model.is_trained: + logger.debug("CNN Image Strategy : modèle non chargé, aucun signal") + return None + + if len(market_data) < self.seq_len: + return None + + try: + result = self.cnn_image_model.predict(market_data) + except Exception as e: + logger.warning(f"CNN Image Strategy predict error : {e}") + return None + + if not result.get('tradeable', False): + return None + + signal_dir = result['signal'] # 1 = LONG, -1 = SHORT + confidence = result['confidence'] + + # Prix et ATR pour SL/TP + last_close = float(market_data['close'].iloc[-1]) + atr = self._compute_atr(market_data) + if atr <= 0: + return None + + if signal_dir == 1: + direction = 'LONG' + stop_loss = last_close - self.sl_atr_mult * atr + take_profit = last_close + self.tp_atr_mult * atr + elif signal_dir == -1: + direction = 'SHORT' + stop_loss = last_close + self.sl_atr_mult * atr + take_profit = last_close - self.tp_atr_mult * atr + else: + return None + + signal = Signal( + symbol = self.symbol, + direction = direction, + entry_price = last_close, + stop_loss = stop_loss, + take_profit = take_profit, + confidence = confidence, + timestamp = datetime.now(timezone.utc), + strategy = self.STRATEGY_NAME, + metadata = { + 'probas': result.get('probas', {}), + 'seq_len': self.seq_len, + 'atr': atr, + 'tp_atr_mult': self.tp_atr_mult, + 'sl_atr_mult': self.sl_atr_mult, + }, + ) + + logger.info( + f"CNN Image Signal : {direction} {self.symbol} | " + f"entry={last_close:.5f} SL={stop_loss:.5f} TP={take_profit:.5f} | " + f"confidence={confidence:.2%}" + ) + return signal + + def calculate_indicators(self, data: pd.DataFrame) -> pd.DataFrame: + """Retourne les données telles quelles — le CNN Image travaille sur des rendus d'images.""" + return data + + def update_params(self, params: Dict) -> None: + """Mise à jour dynamique des paramètres (depuis API ou Optuna).""" + if 'min_confidence' in params: + self.min_confidence = params['min_confidence'] + if self.cnn_image_model: + self.cnn_image_model.min_confidence = params['min_confidence'] + if 'tp_atr_mult' in params: + self.tp_atr_mult = params['tp_atr_mult'] + if 'sl_atr_mult' in params: + self.sl_atr_mult = params['sl_atr_mult'] + if 'seq_len' in params: + self.seq_len = params['seq_len'] + logger.info(f"CNN Image Strategy params mis à jour : {params}") + + # ------------------------------------------------------------------------- + # Gestion du modèle + # ------------------------------------------------------------------------- + def load_model(self, symbol: Optional[str] = None, timeframe: Optional[str] = None) -> bool: + """ + Charge un modèle CNN Image depuis le disque. + + Args: + symbol: Paire (défaut: self.symbol) + timeframe: Timeframe (défaut: self.config.timeframe) + + Returns: + True si chargement réussi + """ + if not CNN_IMAGE_AVAILABLE: + logger.warning("CNN Image non disponible (PyTorch/mplfinance requis)") + return False + + sym = symbol or self.symbol + tf = timeframe or self.config.timeframe + try: + self.cnn_image_model = CNNImageStrategyModel.load(sym, tf) + logger.info(f"Modèle CNN Image chargé : {sym}/{tf}") + return True + except FileNotFoundError: + logger.info(f"Aucun modèle CNN Image trouvé pour {sym}/{tf}") + return False + except Exception as e: + logger.error(f"Erreur chargement modèle CNN Image : {e}") + return False + + def attach_model(self, model: 'CNNImageStrategyModel') -> None: + """Attache directement un modèle CNN Image (après entraînement via API).""" + self.cnn_image_model = model + self.symbol = model.symbol + logger.info(f"Modèle CNN Image attaché : {model.symbol}/{model.timeframe}") + + def is_ready(self) -> bool: + """Retourne True si le modèle CNN Image est chargé et entraîné.""" + if not CNN_IMAGE_AVAILABLE: + return False + return self.cnn_image_model is not None and self.cnn_image_model.is_trained + + def get_model_info(self) -> Dict: + """Retourne les métadonnées du modèle CNN Image actif.""" + if not CNN_IMAGE_AVAILABLE: + return {'status': 'PyTorch/mplfinance non disponible'} + if not self.is_ready(): + return {'status': 'non entraîné'} + meta = self.cnn_image_model.metadata.copy() + meta['is_ready'] = True + meta['seq_len'] = self.seq_len + return meta + + # ------------------------------------------------------------------------- + # Helpers + # ------------------------------------------------------------------------- + def _try_load_model(self) -> None: + """Tente un chargement silencieux du modèle au démarrage.""" + try: + self.load_model() + except Exception: + pass + + @staticmethod + def _compute_atr(df: pd.DataFrame, period: int = 14) -> float: + """Calcule l'ATR moyen sur les dernières barres.""" + if len(df) < period + 1: + return float(df['high'].iloc[-1] - df['low'].iloc[-1]) + h, l, pc = df['high'], df['low'], df['close'].shift(1) + tr = pd.concat([h - l, (h - pc).abs(), (l - pc).abs()], axis=1).max(axis=1) + atr = tr.rolling(period).mean().iloc[-1] + return float(atr) if not np.isnan(atr) else float(df['high'].iloc[-1] - df['low'].iloc[-1])