feat: Phase 4c-bis — CNN image-based (analyse visuelle graphiques chandeliers)

## Nouveaux modules

### src/ml/cnn_image/
- chart_renderer.py : CandlestickImageRenderer — OHLCV → images 128×128 RGB (mplfinance)
  Fond #0d1117, bougies vertes/rouges, volume, sans axes, rendu en mémoire
  Fallback 2D si mplfinance absent
- cnn_image_model.py : CandlestickCNN — Conv2D 4-blocs (3→32→64→128→256) + AvgPool + Dense(3)
- cnn_image_strategy_model.py : CNNImageStrategyModel — même interface que MLStrategyModel

### src/strategies/cnn_image_driven/
- cnn_image_strategy.py : CNNImageDrivenStrategy(BaseStrategy), SL/TP ATR, seq_len=64

## Modifications

- ensemble_model.py : attach_cnn_image(), poids XGB=0.30/CNN1D=0.30/CNNImage=0.40
- trading.py : POST /train-cnn-image, GET /train-cnn-image/{id}, GET /cnn-image-models
- docker/requirements/api.txt : mplfinance>=0.12.10b0, Pillow>=10.0.0

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Tika
2026-03-10 20:22:41 +00:00
parent e9d4c440d9
commit 80e1308a1e
9 changed files with 1621 additions and 8 deletions

View File

@@ -29,3 +29,7 @@ joblib>=1.3.0
# ML — Deep Learning (CNN pour patterns chandeliers)
torch>=2.0.0
# Visualisation — graphiques financiers et traitement d'images
mplfinance>=0.12.10b0
Pillow>=10.0.0

View File

@@ -1230,3 +1230,191 @@ async def get_ensemble_status():
status["ensemble_info"] = {"error": "Impossible de récupérer le statut"}
return status
# =============================================================================
# CNN IMAGE STRATEGY — Entraînement et gestion des modèles CNN-Image
# =============================================================================
try:
from src.ml.cnn_image import CNNImageStrategyModel
CNN_IMAGE_AVAILABLE = True
except ImportError:
CNN_IMAGE_AVAILABLE = False
# Stockage en mémoire des jobs d'entraînement CNN Image
_cnn_image_train_jobs: Dict[str, dict] = {}
class CNNImageTrainRequest(BaseModel):
"""Requête d'entraînement du modèle CNN Image (Conv2D vision)."""
symbol: str = "EURUSD"
timeframe: str = "1h"
period: str = "2y"
seq_len: int = 64
tp_atr_mult: float = 2.0
sl_atr_mult: float = 1.0
horizon: int = 30
min_confidence: float = 0.55
class CNNImageTrainResponse(BaseModel):
"""Réponse d'un job d'entraînement CNN Image."""
job_id: str
status: str
symbol: str
timeframe: str
wf_accuracy: Optional[float] = None
wf_precision: Optional[float] = None
label_dist: Optional[dict] = None
n_samples: Optional[int] = None
trained_at: Optional[str] = None
error: Optional[str] = None
async def _run_cnn_image_train_task(job_id: str, request: CNNImageTrainRequest) -> None:
"""Tâche d'entraînement CNN Image exécutée en arrière-plan."""
_cnn_image_train_jobs[job_id]["status"] = "running"
try:
from src.data.data_service import DataService
from src.utils.config_loader import ConfigLoader
from datetime import timedelta
config = ConfigLoader.load_all()
data_service = DataService(config)
end_date = datetime.now()
period_map = {'y': 365, 'm': 30, 'd': 1}
unit = request.period[-1]
value = int(request.period[:-1])
start_date = end_date - timedelta(days=value * period_map.get(unit, 1))
df = await data_service.get_historical_data(
symbol = request.symbol,
timeframe = request.timeframe,
start_date = start_date,
end_date = end_date,
)
if df is None or len(df) < 200:
raise ValueError(
f"Données insuffisantes : {len(df) if df is not None else 0} barres (min 200)"
)
# Entraînement dans un thread (opération CPU-bound / GPU-bound)
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(None, _sync_cnn_image_train, df, request)
_cnn_image_train_jobs[job_id].update({
"status": "completed",
"symbol": request.symbol,
"timeframe": request.timeframe,
"n_samples": result.get("n_samples"),
"wf_accuracy": result.get("wf_metrics", {}).get("avg_accuracy"),
"wf_precision": result.get("wf_metrics", {}).get("avg_precision"),
"label_dist": result.get("label_dist"),
"trained_at": result.get("trained_at"),
})
# Auto-attachement à la stratégie cnn_image_driven active si elle existe
_attach_cnn_image_model_to_strategy(request)
except Exception as exc:
logger.error(f"Erreur entraînement CNN Image job {job_id} : {exc}", exc_info=True)
_cnn_image_train_jobs[job_id]["status"] = "failed"
_cnn_image_train_jobs[job_id]["error"] = str(exc)
def _sync_cnn_image_train(df, request: CNNImageTrainRequest) -> dict:
"""Wrapper synchrone pour CNNImageStrategyModel.train() (exécuté dans un thread)."""
from src.ml.cnn_image import CNNImageStrategyModel
model = CNNImageStrategyModel(
symbol = request.symbol,
timeframe = request.timeframe,
seq_len = request.seq_len,
tp_atr_mult = request.tp_atr_mult,
sl_atr_mult = request.sl_atr_mult,
horizon = request.horizon,
min_confidence = request.min_confidence,
)
return model.train(df)
def _attach_cnn_image_model_to_strategy(request: CNNImageTrainRequest) -> None:
"""Attache le modèle CNN Image entraîné à la stratégie cnn_image_driven active (paper trading)."""
try:
from src.ml.cnn_image import CNNImageStrategyModel
from src.strategies.cnn_image_driven import CNNImageDrivenStrategy
engine = _paper_state.get("engine")
if engine and hasattr(engine, 'strategy_engine'):
strat = engine.strategy_engine.strategies.get('cnn_image_driven')
if strat and isinstance(strat, CNNImageDrivenStrategy):
model = CNNImageStrategyModel.load(request.symbol, request.timeframe)
strat.attach_model(model)
logger.info("Modèle CNN Image attaché à la stratégie cnn_image_driven active")
except Exception as e:
logger.debug(f"Auto-attach modèle CNN Image ignoré : {e}")
@router.post("/train-cnn-image", response_model=CNNImageTrainResponse,
summary="Entraîner le modèle CNN Image (Conv2D vision)")
async def train_cnn_image(request: CNNImageTrainRequest, background_tasks: BackgroundTasks):
"""
Lance l'entraînement du modèle CNN Image en arrière-plan.
Le CNN Image Conv2D apprend les patterns visuels des graphiques de chandeliers
(bougies marteau, double top/bottom, rebonds S/R, momentum...) depuis des
images 128×128 RGB générées par mplfinance — aucune feature pré-calculée.
- Retourne un `job_id` à interroger via `GET /trading/train-cnn-image/{job_id}`
- Le modèle est sauvegardé sur disque après entraînement
- Si un paper trading cnn_image_driven est actif, le modèle lui est automatiquement attaché
"""
if not CNN_IMAGE_AVAILABLE:
raise HTTPException(
503,
detail="PyTorch + mplfinance requis — rebuilder le container trading-api"
)
job_id = str(uuid.uuid4())
_cnn_image_train_jobs[job_id] = {
"status": "pending",
"symbol": request.symbol,
"timeframe": request.timeframe,
}
background_tasks.add_task(_run_cnn_image_train_task, job_id, request)
return CNNImageTrainResponse(
job_id = job_id,
status = "pending",
symbol = request.symbol,
timeframe = request.timeframe,
)
@router.get("/train-cnn-image/{job_id}", response_model=CNNImageTrainResponse,
summary="Résultat entraînement CNN Image")
def get_cnn_image_train_status(job_id: str):
"""Retourne l'état d'un job d'entraînement CNN Image."""
job = _cnn_image_train_jobs.get(job_id)
if job is None:
raise HTTPException(404, detail=f"Job {job_id} introuvable")
return CNNImageTrainResponse(job_id=job_id, **job)
@router.get("/cnn-image-models", summary="Liste des modèles CNN Image entraînés")
def list_cnn_image_models():
"""
Retourne la liste de tous les modèles CNN Image disponibles sur disque,
avec leurs métriques (accuracy, date d'entraînement, nombre de samples...).
"""
if not CNN_IMAGE_AVAILABLE:
return {
"error": "PyTorch + mplfinance requis — rebuilder le container trading-api",
"models": [],
"count": 0,
}
from src.ml.cnn_image import CNNImageStrategyModel
models = CNNImageStrategyModel.list_trained_models()
return {"models": models, "count": len(models)}

View File

@@ -0,0 +1,11 @@
"""
Module CNN Image-Based — Analyse visuelle de graphiques en chandeliers.
Ce module implémente un CNN Conv2D qui analyse des graphiques de chandeliers
rendus comme de vraies images (128×128 RGB), pour reconnaître les patterns
visuels exactement comme un trader humain devant TradingView.
"""
from src.ml.cnn_image.cnn_image_strategy_model import CNNImageStrategyModel
__all__ = ['CNNImageStrategyModel']

View File

@@ -0,0 +1,322 @@
"""
CandlestickImageRenderer — Convertit des données OHLCV en images de graphiques.
Ce module transforme des séquences de bougies OHLCV en images 128×128 RGB
qui peuvent être passées à un CNN Conv2D pour l'analyse visuelle des patterns.
Rendu :
- Fond noir (#0d1117), style TradingView
- Bougies vertes (#26a69a) pour la hausse, rouges (#ef5350) pour la baisse
- Volume en bas de l'image (via mplfinance)
- Pas d'axes, pas de labels, pas de titre
- Taille fixe : 128×128 pixels, 3 canaux RGB
Si mplfinance ou PIL ne sont pas disponibles, un rendu de fallback basique
encode les données OHLCV numériquement sous forme d'image 2D.
"""
import io
import logging
from typing import Optional
import numpy as np
import pandas as pd
logger = logging.getLogger(__name__)
# --- Détection optionnelle de mplfinance et PIL ---
try:
import mplfinance as mpf
import matplotlib
matplotlib.use('Agg') # Backend non-interactif (pas d'écran requis)
import matplotlib.pyplot as plt
MPLFINANCE_AVAILABLE = True
logger.debug("mplfinance disponible — rendu haute qualité activé")
except ImportError:
MPLFINANCE_AVAILABLE = False
logger.warning("mplfinance non disponible — utilisation du rendu fallback")
try:
from PIL import Image
PIL_AVAILABLE = True
except ImportError:
PIL_AVAILABLE = False
logger.warning("Pillow (PIL) non disponible — rendu fallback uniquement")
# Taille cible des images en pixels
IMAGE_SIZE = 128
class CandlestickImageRenderer:
"""
Convertit des données OHLCV en images de graphiques en chandeliers.
Chaque image est un instantané visuel de `seq_len` bougies consécutives,
rendu avec mplfinance (style fond noir, bougies colorées, volume en bas).
Le résultat est normalisé en float32 dans [0, 1].
Args:
image_size: Taille carrée de l'image en pixels (défaut 128)
"""
def __init__(self, image_size: int = IMAGE_SIZE):
self.image_size = image_size
# -------------------------------------------------------------------------
# Interface publique
# -------------------------------------------------------------------------
def encode(self, df: pd.DataFrame, seq_len: int = 64) -> np.ndarray:
"""
Encode toutes les fenêtres glissantes du DataFrame en images.
Produit N = len(df) - seq_len fenêtres, chacune rendue en image
128×128 RGB normalisée [0, 1].
Args:
df: DataFrame OHLCV avec colonnes open/high/low/close/volume
seq_len: Nombre de bougies par fenêtre (défaut 64)
Returns:
np.ndarray de forme (N, 3, 128, 128), dtype float32, valeurs [0, 1]
Retourne un tableau vide (0, 3, 128, 128) si df est trop court.
"""
df = self._prepare_df(df)
n_windows = len(df) - seq_len
if n_windows <= 0:
logger.warning(
f"DataFrame trop court ({len(df)} barres) pour seq_len={seq_len}"
)
return np.zeros((0, 3, self.image_size, self.image_size), dtype=np.float32)
images = np.zeros(
(n_windows, 3, self.image_size, self.image_size), dtype=np.float32
)
for i in range(n_windows):
window = df.iloc[i: i + seq_len]
try:
img = self._render_single(window)
images[i] = img
except Exception as e:
# Fenêtre problématique → on laisse des zéros pour cette position
logger.debug(f"Erreur rendu fenêtre {i} : {e}")
return images
def encode_last(self, df: pd.DataFrame, seq_len: int = 64) -> np.ndarray:
"""
Encode uniquement la dernière fenêtre du DataFrame.
Utilisé pour la prédiction en temps réel : retourne (1, 3, 128, 128).
Args:
df: DataFrame OHLCV
seq_len: Nombre de bougies pour la dernière fenêtre (défaut 64)
Returns:
np.ndarray de forme (1, 3, 128, 128), dtype float32, valeurs [0, 1]
Raises:
ValueError si df est trop court pour seq_len bougies
"""
df = self._prepare_df(df)
if len(df) < seq_len:
raise ValueError(
f"DataFrame insuffisant : {len(df)} barres < seq_len={seq_len}"
)
window = df.iloc[-seq_len:]
img = self._render_single(window)
# Ajouter dimension batch
return img[np.newaxis, ...] # (1, 3, H, W)
# -------------------------------------------------------------------------
# Rendu d'une fenêtre
# -------------------------------------------------------------------------
def _render_single(self, df_window: pd.DataFrame) -> np.ndarray:
"""
Rend une fenêtre OHLCV en image numpy (3, 128, 128).
Utilise mplfinance si disponible, sinon fallback encodage numérique.
Args:
df_window: DataFrame OHLCV pour une fenêtre (seq_len barres)
Returns:
np.ndarray de forme (3, 128, 128), float32, valeurs [0, 1]
"""
if MPLFINANCE_AVAILABLE and PIL_AVAILABLE:
return self._render_with_mplfinance(df_window)
else:
return self._render_fallback(df_window)
def _render_with_mplfinance(self, df_window: pd.DataFrame) -> np.ndarray:
"""
Rendu haute qualité via mplfinance.
Style : fond noir (#0d1117), bougies vertes/rouges, volume en bas,
aucun axe ni label ni titre. Image 128×128 RGB.
"""
# Style personnalisé : fond noir, bougies colorées TradingView
style = mpf.make_mpf_style(
base_mpf_style='nightclouds',
marketcolors=mpf.make_marketcolors(
up='#26a69a', # vert TradingView (hausse)
down='#ef5350', # rouge TradingView (baisse)
wick={'up': '#26a69a', 'down': '#ef5350'},
edge={'up': '#26a69a', 'down': '#ef5350'},
volume={'up': '#26a69a', 'down': '#ef5350'},
),
facecolor='#0d1117', # Fond noir GitHub-style
figcolor='#0d1117',
gridcolor='#0d1117',
)
# Rendu en mémoire via BytesIO
buf = io.BytesIO()
fig, axes = mpf.plot(
df_window,
type='candle',
style=style,
volume=True,
axisoff=True, # Pas d'axes
tight_layout=True,
returnfig=True,
figsize=(1.28, 1.28), # 128 DPI × 1.28 inch = 128 pixels
)
# Suppression de tous les éléments décoratifs
for ax in axes:
ax.set_axis_off()
ax.set_facecolor('#0d1117')
for spine in ax.spines.values():
spine.set_visible(False)
fig.savefig(
buf,
format='png',
dpi=100,
bbox_inches='tight',
pad_inches=0,
facecolor='#0d1117',
)
plt.close(fig)
buf.seek(0)
pil_img = Image.open(buf).convert('RGB')
pil_img = pil_img.resize(
(self.image_size, self.image_size), Image.LANCZOS
)
# Conversion en array numpy (H, W, 3) → (3, H, W), float32, [0,1]
arr = np.array(pil_img, dtype=np.float32) / 255.0
arr = arr.transpose(2, 0, 1) # HWC → CHW
return arr
def _render_fallback(self, df_window: pd.DataFrame) -> np.ndarray:
"""
Rendu de secours sans mplfinance : encode les données OHLCV
directement en image 2D normalisée.
Chaque colonne de l'image correspond à une bougie :
- Canal 0 (R) : position relative du close dans le range [low, high]
- Canal 1 (G) : amplitude de la bougie (high - low normalisé)
- Canal 2 (B) : volume normalisé
L'image est ensuite redimensionnée à image_size × image_size.
"""
cols = df_window[['open', 'high', 'low', 'close', 'volume']].copy()
n = len(cols)
if n == 0:
return np.zeros((3, self.image_size, self.image_size), dtype=np.float32)
# Normalisation par colonne
highs = cols['high'].values.astype(np.float32)
lows = cols['low'].values.astype(np.float32)
closes = cols['close'].values.astype(np.float32)
opens = cols['open'].values.astype(np.float32)
vols = cols['volume'].values.astype(np.float32)
price_range = highs - lows
price_range = np.where(price_range == 0, 1e-8, price_range)
# Canal R : position du close dans le range de la bougie [0, 1]
close_pos = (closes - lows) / price_range
# Canal G : corps de la bougie (|close - open| / range)
body = np.abs(closes - opens) / price_range
# Canal B : volume normalisé [0, 1]
vol_max = vols.max()
vol_norm = vols / (vol_max if vol_max > 0 else 1.0)
# Construction image : 3 × 1 × n → 3 × image_size × image_size
# On crée une image de height=image_size, width=n puis on redimensionne
img = np.zeros((3, 1, n), dtype=np.float32)
img[0, 0, :] = close_pos
img[1, 0, :] = body
img[2, 0, :] = vol_norm
# Redimensionnement vers (3, image_size, image_size) via répétition
# On étire chaque canal sur les deux dimensions
img_resized = np.zeros(
(3, self.image_size, self.image_size), dtype=np.float32
)
for c in range(3):
# Répétition en hauteur (axe 0 du canal) et interpolation en largeur
channel = img[c, 0, :] # (n,)
# Interpolation 1D vers image_size
x_orig = np.linspace(0, 1, n)
x_new = np.linspace(0, 1, self.image_size)
channel_resized = np.interp(x_new, x_orig, channel).astype(np.float32)
# Étirer sur toute la hauteur
img_resized[c] = np.tile(channel_resized, (self.image_size, 1))
return img_resized
# -------------------------------------------------------------------------
# Utilitaires
# -------------------------------------------------------------------------
@staticmethod
def _prepare_df(df: pd.DataFrame) -> pd.DataFrame:
"""
Normalise le DataFrame : colonnes en minuscules, index DatetimeIndex.
Args:
df: DataFrame OHLCV brut
Returns:
DataFrame nettoyé avec colonnes open/high/low/close/volume
"""
df = df.copy()
df.columns = [c.lower() for c in df.columns]
# S'assurer que l'index est un DatetimeIndex (requis par mplfinance)
if not isinstance(df.index, pd.DatetimeIndex):
try:
df.index = pd.to_datetime(df.index)
except Exception:
# Créer un index artificiel si la conversion échoue
df.index = pd.date_range(
start='2020-01-01', periods=len(df), freq='1h'
)
# Conserver uniquement les colonnes OHLCV
required = ['open', 'high', 'low', 'close', 'volume']
for col in required:
if col not in df.columns:
df[col] = 0.0
df = df[required].dropna(subset=['open', 'high', 'low', 'close'])
df = df.ffill().bfill()
return df

View File

@@ -0,0 +1,163 @@
"""
CandlestickCNN — Réseau de neurones convolutif pour l'analyse visuelle de graphiques.
Architecture Conv2D 4-blocs conçue pour reconnaître les patterns visuels dans
des images de chandeliers 128×128 RGB (bougies, volumes, supports/résistances).
Input : (batch, 3, 128, 128) — images RGB normalisées [0, 1]
Output : (batch, 3) — logits pour 3 classes : SHORT(0), NEUTRAL(1), LONG(2)
Classes encodées :
0 → SHORT (-1)
1 → NEUTRAL ( 0)
2 → LONG (+1)
L'encodage +1 est cohérent avec MLStrategyModel et CNNStrategyModel.
"""
import logging
from typing import Optional
import numpy as np
logger = logging.getLogger(__name__)
# --- Détection optionnelle de PyTorch ---
try:
import torch
import torch.nn as nn
import torch.nn.functional as F
TORCH_AVAILABLE = True
logger.debug("PyTorch disponible — CandlestickCNN activé")
except ImportError:
TORCH_AVAILABLE = False
# Stubs pour permettre l'import du module sans PyTorch
nn = None
torch = None
logger.warning("PyTorch non disponible — CandlestickCNN désactivé")
# ─────────────────────────────────────────────────────────────────────────────
# Définition conditionnelle du modèle
# ─────────────────────────────────────────────────────────────────────────────
if TORCH_AVAILABLE:
class CandlestickCNN(nn.Module):
"""
CNN vision pour graphiques chandeliers 128×128 RGB.
Architecture :
Bloc 1 : Conv2d(3→32, 3×3) + BatchNorm2d(32) + ReLU + MaxPool2d(2) → (32, 64, 64)
Bloc 2 : Conv2d(32→64, 3×3) + BatchNorm2d(64) + ReLU + MaxPool2d(2) → (64, 32, 32)
Bloc 3 : Conv2d(64→128,3×3) + BatchNorm2d(128) + ReLU + MaxPool2d(2) → (128, 16, 16)
Bloc 4 : Conv2d(128→256,3×3)+ BatchNorm2d(256) + ReLU + AdaptiveAvgPool2d(1) → (256,)
Classifieur :
Linear(256→128) + Dropout(0.4) + ReLU
Linear(128→3)
Output : logits bruts (3 classes) — appliquer softmax pour les probabilités.
Args:
n_classes: Nombre de classes de sortie (défaut 3 : SHORT/NEUTRAL/LONG)
dropout: Taux de dropout dans le classifieur (défaut 0.4)
"""
def __init__(self, n_classes: int = 3, dropout: float = 0.4):
super(CandlestickCNN, self).__init__()
# --- Bloc convolutif 1 : (3, 128, 128) → (32, 64, 64) ---
self.bloc1 = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
)
# --- Bloc convolutif 2 : (32, 64, 64) → (64, 32, 32) ---
self.bloc2 = nn.Sequential(
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
)
# --- Bloc convolutif 3 : (64, 32, 32) → (128, 16, 16) ---
self.bloc3 = nn.Sequential(
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
)
# --- Bloc convolutif 4 : (128, 16, 16) → (256, 1, 1) → (256,) ---
self.bloc4 = nn.Sequential(
nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d(output_size=1), # Global average pooling → (256, 1, 1)
)
# --- Classifieur dense ---
self.classifieur = nn.Sequential(
nn.Linear(256, 128),
nn.Dropout(p=dropout),
nn.ReLU(inplace=True),
nn.Linear(128, n_classes),
)
def forward(self, x: 'torch.Tensor') -> 'torch.Tensor':
"""
Passe avant du réseau.
Args:
x: Tensor (batch, 3, 128, 128) — images RGB normalisées
Returns:
Tensor (batch, 3) — logits bruts
"""
x = self.bloc1(x) # (batch, 32, 64, 64)
x = self.bloc2(x) # (batch, 64, 32, 32)
x = self.bloc3(x) # (batch, 128, 16, 16)
x = self.bloc4(x) # (batch, 256, 1, 1)
x = x.flatten(1) # (batch, 256)
x = self.classifieur(x) # (batch, 3)
return x
def predict_proba(self, x: 'torch.Tensor') -> np.ndarray:
"""
Calcule les probabilités softmax pour chaque classe.
Args:
x: Tensor (batch, 3, 128, 128) — images RGB normalisées
Returns:
np.ndarray de forme (batch, 3), valeurs dans [0, 1], somme = 1
Ordre des colonnes : [P_SHORT, P_NEUTRAL, P_LONG]
"""
self.eval()
with torch.no_grad():
logits = self.forward(x) # (batch, 3)
probas = F.softmax(logits, dim=1) # (batch, 3)
return probas.cpu().numpy().astype(np.float32)
else:
# Stub si PyTorch n'est pas disponible
class CandlestickCNN:
"""
Stub de CandlestickCNN — PyTorch non disponible.
Toutes les méthodes lèvent une RuntimeError explicite.
"""
def __init__(self, n_classes: int = 3, dropout: float = 0.4):
raise RuntimeError(
"PyTorch non disponible. Installer torch>=2.0.0 pour utiliser CandlestickCNN."
)
def forward(self, x):
raise RuntimeError("PyTorch non disponible.")
def predict_proba(self, x):
raise RuntimeError("PyTorch non disponible.")

View File

@@ -0,0 +1,654 @@
"""
CNNImageStrategyModel — Modèle CNN image-based pour la prédiction de signaux de trading.
Ce module entraîne un CNN Conv2D (CandlestickCNN) sur des images de graphiques
en chandeliers pour prédire LONG (+1), NEUTRAL (0) ou SHORT (-1).
Pipeline :
1. Génération de labels ATR-based (LabelGenerator partagé)
2. Rendu des fenêtres OHLCV en images 128×128 RGB (CandlestickImageRenderer)
3. Walk-forward (2 folds temporels) pour évaluation out-of-sample
4. Entraînement sur toutes les données (Adam, CrossEntropyLoss weighted)
5. Sauvegarde : EURUSD_1h.pt (state_dict) + EURUSD_1h_meta.json
Interface identique à MLStrategyModel et CNNStrategyModel :
model = CNNImageStrategyModel(symbol='EURUSD', timeframe='1h')
result = model.train(df_ohlcv)
signal = model.predict(df) # {signal, confidence, probas, tradeable}
model.save()
model.load(symbol, timeframe)
model.list_trained_models()
model.get_feature_importance() # retourne [] (CNN = boîte noire)
"""
import json
import logging
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
import pandas as pd
from src.ml.features.label_generator import LabelGenerator
from src.ml.cnn_image.chart_renderer import CandlestickImageRenderer, MPLFINANCE_AVAILABLE
from src.ml.cnn_image.cnn_image_model import CandlestickCNN, TORCH_AVAILABLE
logger = logging.getLogger(__name__)
# Répertoire de sauvegarde des modèles CNN image
MODELS_DIR = Path(__file__).parent.parent.parent.parent / "models" / "cnn_image_strategy"
# Décodage des classes encodées vers les signaux de trading
# Classe 0 → SHORT (-1), Classe 1 → NEUTRAL (0), Classe 2 → LONG (+1)
CLASS_MAP = {0: -1, 1: 0, 2: 1}
# Imports conditionnels PyTorch
if TORCH_AVAILABLE:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import TimeSeriesSplit
from sklearn.metrics import precision_recall_fscore_support
SKLEARN_AVAILABLE = True
else:
torch = None
nn = None
DataLoader = None
TensorDataset = None
TimeSeriesSplit = None
precision_recall_fscore_support = None
SKLEARN_AVAILABLE = False
try:
from sklearn.model_selection import TimeSeriesSplit
from sklearn.metrics import precision_recall_fscore_support
SKLEARN_AVAILABLE = True
except ImportError:
SKLEARN_AVAILABLE = False
class CNNImageStrategyModel:
"""
Modèle CNN image-based qui reconnaît les patterns visuels de trading.
Le modèle :
- Convertit des bougies OHLCV en images 128×128 RGB (mplfinance)
- Reconnaît visuellement : marteaux, doji, double top/bottom, supports/résistances
- Prédit LONG (1) / SHORT (-1) / NEUTRAL (0)
- Donne un score de confiance [0..1] par prédiction
- Se sauvegarde en format PyTorch (.pt) pour rechargement sans ré-entraînement
Args:
symbol: Paire tradée (ex: 'EURUSD')
timeframe: Timeframe (ex: '1h', '15m')
seq_len: Nombre de bougies par image (défaut 64)
tp_atr_mult: Multiplicateur ATR pour le TP (génération labels)
sl_atr_mult: Multiplicateur ATR pour le SL (génération labels)
horizon: Nombre de barres pour évaluer TP/SL
min_confidence: Seuil de confiance pour le signal tradeable
epochs: Nombre maximum d'époques d'entraînement
batch_size: Taille du batch (défaut 32)
lr: Taux d'apprentissage Adam (défaut 1e-3)
patience: Patience pour l'early stopping (défaut 7)
"""
MODELS_DIR = MODELS_DIR
CLASS_MAP = CLASS_MAP
def __init__(
self,
symbol: str = 'EURUSD',
timeframe: str = '1h',
seq_len: int = 64,
tp_atr_mult: float = 2.0,
sl_atr_mult: float = 1.0,
horizon: int = 30,
min_confidence: float = 0.55,
epochs: int = 50,
batch_size: int = 32,
lr: float = 1e-3,
patience: int = 7,
):
self.symbol = symbol
self.timeframe = timeframe
self.seq_len = seq_len
self.tp_atr_mult = tp_atr_mult
self.sl_atr_mult = sl_atr_mult
self.horizon = horizon
self.min_confidence = min_confidence
self.epochs = epochs
self.batch_size = batch_size
self.lr = lr
self.patience = patience
self.model: Optional['CandlestickCNN'] = None
self.is_trained = False
self.metadata: Dict = {}
self.MODELS_DIR.mkdir(parents=True, exist_ok=True)
# -------------------------------------------------------------------------
# Entraînement
# -------------------------------------------------------------------------
def train(self, data: pd.DataFrame) -> Dict:
"""
Entraîne le modèle CNN image sur les données OHLCV fournies.
Étapes :
1. Génération des labels ATR-based (LabelGenerator)
2. Encodage des fenêtres en images (CandlestickImageRenderer)
3. Walk-forward évaluation (2 folds temporels)
4. Entraînement final sur toutes les données
5. Sauvegarde automatique
Args:
data: DataFrame OHLCV (minimum 200 barres recommandées,
colonnes : open/high/low/close/volume)
Returns:
Dict avec métriques : wf_accuracy, wf_precision, label_dist,
n_samples, trained_at, et error si échec.
"""
# Vérifications de disponibilité
if not TORCH_AVAILABLE:
return {'error': 'PyTorch non disponible — installer torch>=2.0.0'}
if not SKLEARN_AVAILABLE:
return {'error': 'scikit-learn non disponible'}
logger.info(
f"Début entraînement CNNImageStrategyModel pour "
f"{self.symbol}/{self.timeframe} — seq_len={self.seq_len}"
)
# 1. Génération des labels (ATR-based, cohérent avec MLStrategyModel)
gen = LabelGenerator(horizon=self.horizon)
labels = gen.generate_atr_based(
data,
atr_tp_mult=self.tp_atr_mult,
atr_sl_mult=self.sl_atr_mult,
)
# Encodage [-1,0,1] → [0,1,2] pour CrossEntropyLoss
y_enc = (labels + 1).values # np.ndarray
# 2. Encodage des images (fenêtres glissantes)
logger.info(f" Rendu des images (seq_len={self.seq_len})...")
renderer = CandlestickImageRenderer()
X = renderer.encode(data, seq_len=self.seq_len) # (N, 3, 128, 128)
# 3. Alignement X et y
# encode() produit N = len(data) - seq_len images
# La i-ème image correspond à data.iloc[i : i + seq_len]
# Le label associé est y_enc[i + seq_len - 1] (dernier point de la fenêtre)
n_images = len(X)
if n_images == 0:
return {
'error': f'Pas assez de données : {len(data)} barres < seq_len={self.seq_len}'
}
# Index des labels pour chaque image : fenêtre i → label à l'indice i + seq_len - 1
label_indices = np.arange(self.seq_len - 1, self.seq_len - 1 + n_images)
# Supprimer les barres de fin (labels non fiables, horizon tronqué)
valid_mask = label_indices < (len(y_enc) - self.horizon)
X = X[valid_mask]
y = y_enc[label_indices[valid_mask]]
n_samples = len(X)
if n_samples < 50:
return {'error': f'Trop peu d\'échantillons valides : {n_samples}'}
logger.info(
f" {n_samples} images valides — "
f"LONG={(y==2).sum()}, SHORT={(y==0).sum()}, NEUTRAL={(y==1).sum()}"
)
# 4. Calcul des class_weights (inversement proportionnels à la fréquence)
class_weights = self._compute_class_weights(y)
# 5. Walk-forward évaluation (2 folds)
wf_metrics = self._walk_forward_eval(X, y, class_weights, n_splits=2)
# 6. Entraînement final sur toutes les données
logger.info(" Entraînement final sur toutes les données...")
self.model = CandlestickCNN(n_classes=3)
self._train_model(self.model, X, y, class_weights)
self.is_trained = True
# Distribution des labels (décodage pour lisibilité)
label_dist = {
'long': int((y == 2).sum()),
'short': int((y == 0).sum()),
'neutral': int((y == 1).sum()),
}
self.metadata = {
'symbol': self.symbol,
'timeframe': self.timeframe,
'seq_len': self.seq_len,
'trained_at': datetime.utcnow().isoformat(),
'n_samples': n_samples,
'tp_atr_mult': self.tp_atr_mult,
'sl_atr_mult': self.sl_atr_mult,
'horizon': self.horizon,
'label_dist': label_dist,
'wf_metrics': wf_metrics,
'mplfinance': MPLFINANCE_AVAILABLE,
}
# 7. Sauvegarde
self.save()
logger.info(
f"Entraînement terminé. WF accuracy={wf_metrics.get('avg_accuracy', 0):.2%}"
)
return self.metadata
# -------------------------------------------------------------------------
# Prédiction
# -------------------------------------------------------------------------
def predict(self, data: pd.DataFrame) -> Dict:
"""
Prédit le signal pour la dernière fenêtre disponible.
Args:
data: DataFrame OHLCV récent (minimum seq_len barres)
Returns:
Dict : {
'signal': 1 (LONG) / -1 (SHORT) / 0 (NEUTRAL),
'confidence': float [0..1] — max(P_LONG, P_SHORT),
'probas': {'long': float, 'short': float, 'neutral': float},
'tradeable': bool — confidence >= min_confidence et signal != 0,
}
"""
if not TORCH_AVAILABLE:
return {
'signal': 0, 'confidence': 0.0, 'tradeable': False,
'error': 'PyTorch non disponible'
}
if not self.is_trained or self.model is None:
return {
'signal': 0, 'confidence': 0.0, 'tradeable': False,
'error': 'Modèle non entraîné'
}
try:
renderer = CandlestickImageRenderer()
img = renderer.encode_last(data, seq_len=self.seq_len) # (1, 3, 128, 128)
except ValueError as e:
return {
'signal': 0, 'confidence': 0.0, 'tradeable': False,
'error': str(e)
}
# Conversion en Tensor
x_tensor = torch.from_numpy(img).float() # (1, 3, 128, 128)
# Prédiction
proba_arr = self.model.predict_proba(x_tensor) # (1, 3)
proba = proba_arr[0] # (3,) : [P_SHORT, P_NEUTRAL, P_LONG]
pred_class = int(np.argmax(proba))
signal = self.CLASS_MAP[pred_class] # Décodage : 0→-1, 1→0, 2→1
probas = {
'short': float(proba[0]),
'neutral': float(proba[1]),
'long': float(proba[2]),
}
confidence = float(max(probas['long'], probas['short']))
return {
'signal': signal,
'confidence': confidence,
'probas': probas,
'tradeable': confidence >= self.min_confidence and signal != 0,
}
# -------------------------------------------------------------------------
# Sauvegarde / Chargement
# -------------------------------------------------------------------------
def save(self) -> Path:
"""
Sauvegarde le modèle et ses métadonnées sur disque.
Format :
{symbol}_{timeframe}.pt — state_dict PyTorch
{symbol}_{timeframe}_meta.json — métadonnées JSON
Returns:
Path vers le fichier .pt sauvegardé
Raises:
RuntimeError si le modèle n'est pas entraîné ou PyTorch indisponible
"""
if not TORCH_AVAILABLE:
raise RuntimeError("PyTorch non disponible")
if not self.is_trained or self.model is None:
raise RuntimeError("Modèle non entraîné — appeler train() avant save()")
model_id = f"{self.symbol}_{self.timeframe}"
model_path = self.MODELS_DIR / f"{model_id}.pt"
meta_path = self.MODELS_DIR / f"{model_id}_meta.json"
# Sauvegarde du state_dict + config pour pouvoir reconstruire le modèle
torch.save(
{
'state_dict': self.model.state_dict(),
'config': {
'symbol': self.symbol,
'timeframe': self.timeframe,
'seq_len': self.seq_len,
'tp_atr_mult': self.tp_atr_mult,
'sl_atr_mult': self.sl_atr_mult,
'horizon': self.horizon,
'min_confidence': self.min_confidence,
'epochs': self.epochs,
'batch_size': self.batch_size,
'lr': self.lr,
'patience': self.patience,
},
'metadata': self.metadata,
},
model_path
)
with open(meta_path, 'w') as f:
json.dump(self.metadata, f, indent=2, default=str)
logger.info(f"Modèle CNN image sauvegardé : {model_path}")
return model_path
@classmethod
def load(cls, symbol: str, timeframe: str) -> 'CNNImageStrategyModel':
"""
Charge un modèle existant depuis le disque.
Args:
symbol: Paire (ex: 'EURUSD')
timeframe: Timeframe (ex: '1h')
Returns:
Instance CNNImageStrategyModel prête à prédire
Raises:
RuntimeError si PyTorch non disponible
FileNotFoundError si le modèle n'existe pas
"""
if not TORCH_AVAILABLE:
raise RuntimeError("PyTorch non disponible")
model_path = MODELS_DIR / f"{symbol}_{timeframe}.pt"
if not model_path.exists():
raise FileNotFoundError(f"Modèle non trouvé : {model_path}")
checkpoint = torch.load(model_path, map_location='cpu')
cfg = checkpoint.get('config', {})
instance = cls(
symbol = cfg.get('symbol', symbol),
timeframe = cfg.get('timeframe', timeframe),
seq_len = cfg.get('seq_len', 64),
tp_atr_mult = cfg.get('tp_atr_mult', 2.0),
sl_atr_mult = cfg.get('sl_atr_mult', 1.0),
horizon = cfg.get('horizon', 30),
min_confidence = cfg.get('min_confidence', 0.55),
epochs = cfg.get('epochs', 50),
batch_size = cfg.get('batch_size', 32),
lr = cfg.get('lr', 1e-3),
patience = cfg.get('patience', 7),
)
# Reconstruction du modèle et chargement des poids
instance.model = CandlestickCNN(n_classes=3)
instance.model.load_state_dict(checkpoint['state_dict'])
instance.model.eval()
instance.is_trained = True
instance.metadata = checkpoint.get('metadata', {})
logger.info(f"Modèle CNN image chargé depuis {model_path}")
return instance
@staticmethod
def list_trained_models() -> List[Dict]:
"""
Retourne la liste des modèles entraînés disponibles.
Returns:
Liste de dicts avec symbol, timeframe, trained_at, n_samples, wf_accuracy
"""
if not MODELS_DIR.exists():
return []
models = []
for f in MODELS_DIR.glob("*_meta.json"):
try:
with open(f) as fp:
meta = json.load(fp)
# Nom du fichier : EURUSD_1h_meta.json → EURUSD, 1h
stem = f.stem.replace('_meta', '') # ex: EURUSD_1h
parts = stem.split('_', 1)
models.append({
'symbol': meta.get('symbol', parts[0] if parts else '?'),
'timeframe': meta.get('timeframe', parts[1] if len(parts) > 1 else '?'),
'trained_at': meta.get('trained_at', '?'),
'n_samples': meta.get('n_samples', 0),
'seq_len': meta.get('seq_len', 64),
'wf_accuracy': meta.get('wf_metrics', {}).get('avg_accuracy', 0),
'mplfinance': meta.get('mplfinance', False),
})
except Exception:
pass
return models
def get_feature_importance(self) -> List[Dict]:
"""
Retourne l'importance des features.
Note : Les CNN sont des boîtes noires — pas d'importance de feature
interprétable comme XGBoost. Retourne une liste vide.
Returns:
[] — CNN = boîte noire visuelle
"""
return []
# -------------------------------------------------------------------------
# Entraînement interne
# -------------------------------------------------------------------------
def _train_model(
self,
model: 'CandlestickCNN',
X: np.ndarray,
y: np.ndarray,
class_weights: 'torch.Tensor',
) -> None:
"""
Entraîne le modèle CNN sur les données fournies avec early stopping.
Args:
model: Instance CandlestickCNN à entraîner
X: Images (N, 3, 128, 128) float32
y: Labels encodés (N,) int, valeurs {0, 1, 2}
class_weights: Poids de classes (3,) pour CrossEntropyLoss
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
class_weights = class_weights.to(device)
# Conversion en Tensors
X_t = torch.from_numpy(X).float()
y_t = torch.from_numpy(y.astype(np.int64))
# Split train/validation (80/20 temporel)
n_train = int(len(X_t) * 0.8)
X_tr, X_val = X_t[:n_train], X_t[n_train:]
y_tr, y_val = y_t[:n_train], y_t[n_train:]
# DataLoaders
train_ds = TensorDataset(X_tr, y_tr)
val_ds = TensorDataset(X_val, y_val)
train_dl = DataLoader(train_ds, batch_size=self.batch_size, shuffle=False)
val_dl = DataLoader(val_ds, batch_size=self.batch_size, shuffle=False)
# Optimiseur et critère
optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)
criterion = nn.CrossEntropyLoss(weight=class_weights)
best_val_loss = float('inf')
patience_count = 0
best_state = None
for epoch in range(self.epochs):
# --- Phase entraînement ---
model.train()
train_loss = 0.0
for X_batch, y_batch in train_dl:
X_batch = X_batch.to(device)
y_batch = y_batch.to(device)
optimizer.zero_grad()
logits = model(X_batch)
loss = criterion(logits, y_batch)
loss.backward()
optimizer.step()
train_loss += loss.item() * len(X_batch)
train_loss /= max(len(train_ds), 1)
# --- Phase validation ---
model.eval()
val_loss = 0.0
with torch.no_grad():
for X_batch, y_batch in val_dl:
X_batch = X_batch.to(device)
y_batch = y_batch.to(device)
logits = model(X_batch)
loss = criterion(logits, y_batch)
val_loss += loss.item() * len(X_batch)
val_loss /= max(len(val_ds), 1)
logger.debug(
f" Epoch {epoch+1}/{self.epochs}"
f"train_loss={train_loss:.4f}, val_loss={val_loss:.4f}"
)
# Early stopping
if val_loss < best_val_loss:
best_val_loss = val_loss
patience_count = 0
best_state = {
k: v.cpu().clone() for k, v in model.state_dict().items()
}
else:
patience_count += 1
if patience_count >= self.patience:
logger.info(
f" Early stopping à l'époque {epoch+1} "
f"(patience={self.patience}, best_val_loss={best_val_loss:.4f})"
)
break
# Restauration des meilleurs poids
if best_state is not None:
model.load_state_dict(best_state)
model.eval()
model.to('cpu')
def _walk_forward_eval(
self,
X: np.ndarray,
y: np.ndarray,
class_weights: 'torch.Tensor',
n_splits: int = 2,
) -> Dict:
"""
Évalue le modèle en cross-validation temporelle (walk-forward).
Args:
X: Images (N, 3, 128, 128)
y: Labels encodés (N,)
class_weights: Poids de classes pour CrossEntropyLoss
n_splits: Nombre de folds temporels
Returns:
Dict : avg_accuracy, avg_precision, fold_accuracies
"""
tscv = TimeSeriesSplit(n_splits=n_splits)
accuracies = []
precisions = []
for fold, (train_idx, test_idx) in enumerate(tscv.split(X)):
X_tr, X_te = X[train_idx], X[test_idx]
y_tr, y_te = y[train_idx], y[test_idx]
# Modèle temporaire pour ce fold
fold_model = CandlestickCNN(n_classes=3)
self._train_model(fold_model, X_tr, y_tr, class_weights)
# Évaluation
x_tensor = torch.from_numpy(X_te).float()
proba_arr = fold_model.predict_proba(x_tensor) # (N_te, 3)
y_pred = np.argmax(proba_arr, axis=1)
acc = float((y_pred == y_te).mean())
accuracies.append(acc)
# Précision sur signaux directionnels (LONG=2 et SHORT=0, exclure NEUTRAL=1)
mask = (y_te != 1) | (y_pred != 1)
if mask.sum() > 0 and SKLEARN_AVAILABLE:
prec, _, _, _ = precision_recall_fscore_support(
y_te[mask], y_pred[mask],
average='macro', zero_division=0
)
precisions.append(float(prec))
else:
precisions.append(0.0)
logger.info(
f" Fold {fold+1}/{n_splits} : acc={acc:.2%}, "
f"prec={precisions[-1]:.2%}"
)
return {
'avg_accuracy': float(np.mean(accuracies)) if accuracies else 0.0,
'avg_precision': float(np.mean(precisions)) if precisions else 0.0,
'fold_accuracies': [float(a) for a in accuracies],
}
@staticmethod
def _compute_class_weights(y: np.ndarray) -> 'torch.Tensor':
"""
Calcule les poids de classes inversement proportionnels à leur fréquence.
Permet de compenser les déséquilibres (ex: beaucoup de NEUTRAL, peu de trades).
Args:
y: Labels encodés {0, 1, 2}
Returns:
torch.Tensor de forme (3,) — poids pour CrossEntropyLoss
"""
weights = np.ones(3, dtype=np.float32)
n_total = len(y)
if n_total == 0:
return torch.from_numpy(weights)
for cls in range(3):
count = (y == cls).sum()
if count > 0:
# Poids inversement proportionnel : classes rares → poids plus élevé
weights[cls] = n_total / (3.0 * count)
# Normalisation pour que la somme des poids = 3
weights = weights / weights.mean()
return torch.from_numpy(weights)

View File

@@ -34,8 +34,9 @@ class EnsembleModel:
"""
DEFAULT_WEIGHTS = {
'xgboost': 0.40,
'cnn': 0.60,
'xgboost': 0.30,
'cnn': 0.30,
'cnn_image': 0.40, # CNN Vision — favorisé (patterns visuels bruts)
'rl': 0.00, # Réservé Phase 4d
}
@@ -52,6 +53,9 @@ class EnsembleModel:
# Modèles attachés (duck typing : .predict(df), .is_trained)
self._models: Dict[str, Any] = {}
# Référence directe au modèle CNN Image (accès rapide pour auto-attach)
self._cnn_image_model = None
logger.info(
f"EnsembleModel initialisé — poids={self.weights}, "
f"seuil={self.min_confidence}, accord_requis={self.require_agreement}"
@@ -68,6 +72,11 @@ class EnsembleModel:
"""Attache un CNNStrategyModel entraîné."""
self._attach('cnn', model)
def attach_cnn_image(self, model) -> None:
"""Attache un CNNImageStrategyModel entraîné (Phase 4c-bis)."""
self._cnn_image_model = model
self._attach('cnn_image', model)
def attach_rl(self, model) -> None:
"""Attache un agent RL (Phase 4d)."""
self._attach('rl', model)
@@ -198,12 +207,13 @@ class EnsembleModel:
# Statut et configuration
# ------------------------------------------------------------------
def is_ready(self) -> bool:
"""True si au moins 2 modèles sont attachés et entraînés."""
"""True si au moins 2 modèles parmi {xgboost, cnn, cnn_image} sont attachés et entraînés."""
eligible_names = {'xgboost', 'cnn', 'cnn_image'}
trained = sum(
1 for m in self._models.values()
if m.is_trained and self.weights.get(
next(k for k, v in self._models.items() if v is m), 0
) > 0
1 for name, model in self._models.items()
if name in eligible_names
and model.is_trained
and self.weights.get(name, 0.0) > 0
)
return trained >= 2

View File

@@ -0,0 +1,3 @@
from src.strategies.cnn_image_driven.cnn_image_strategy import CNNImageDrivenStrategy
__all__ = ['CNNImageDrivenStrategy']

View File

@@ -0,0 +1,258 @@
"""
CNN Image-Driven Strategy — Stratégie pilotée par un CNN à vision par ordinateur.
Contrairement au CNN 1D qui analyse des séquences numériques brutes, cette
stratégie convertit les bougies OHLCV en images de graphiques (rendu mplfinance
128×128 RGB), puis utilise un réseau Conv2D pour reconnaître les patterns
visuels — exactement comme un trader devant TradingView.
Patterns détectés sans feature engineering explicite :
- Bougies marteau, étoile filante, doji
- Double top, double bottom
- Rebonds sur supports/résistances visuels
- Momentum : grande bougie pleine après consolidation
- Divergences prix/volume visibles
Fonctionnement :
1. Le modèle CNN-Image est chargé depuis le disque (entraîné via POST /trading/train-cnn-image)
2. À chaque barre, les seq_len dernières bougies sont rendues en image 128×128
3. Le modèle prédit LONG / SHORT / NEUTRAL avec un score de confiance
4. Si confidence >= min_confidence, un signal est émis avec SL/TP basés sur ATR
Intégration :
- Compatible avec StrategyEngine (même interface que CNNDrivenStrategy)
- Chargé automatiquement si un modèle entraîné existe pour le symbole/timeframe
- Requiert PyTorch + mplfinance + Pillow
"""
import logging
from datetime import datetime, timezone
from typing import Dict, Optional
import numpy as np
import pandas as pd
from src.strategies.base_strategy import BaseStrategy, Signal, StrategyConfig
try:
from src.ml.cnn_image import CNNImageStrategyModel
CNN_IMAGE_AVAILABLE = True
except ImportError:
CNN_IMAGE_AVAILABLE = False
logger = logging.getLogger(__name__)
class CNNImageDrivenStrategy(BaseStrategy):
"""
Stratégie de trading pilotée par un CNN Conv2D pré-entraîné sur images de graphiques.
Le modèle apprend les patterns visuels des chandeliers directement depuis
des images 128×128 RGB sans features pré-calculées :
- Double bottom / double top
- Configurations chandeliers (marteau, doji, étoile filante)
- Rebonds sur zones de support/résistance visibles
- Structures de momentum multi-barres
- Divergences volume/prix visuelles
Args:
config: Dict de configuration (timeframe, risk_per_trade, symbol, etc.)
Config keys supplémentaires (optionnelles) :
min_confidence: Seuil de confiance minimum [0..1] (défaut: 0.55)
tp_atr_mult: Multiplicateur ATR pour TP (défaut: 2.0)
sl_atr_mult: Multiplicateur ATR pour SL (défaut: 1.0)
seq_len: Nombre de bougies par image (défaut: 64)
auto_load: Charger automatiquement le modèle existant (défaut: True)
"""
STRATEGY_NAME = 'cnn_image_driven'
def __init__(self, config: Dict):
super().__init__(config)
self.symbol = config.get('symbol', 'EURUSD')
self.min_confidence = config.get('min_confidence', 0.55)
self.tp_atr_mult = config.get('tp_atr_mult', 2.0)
self.sl_atr_mult = config.get('sl_atr_mult', 1.0)
self.seq_len = config.get('seq_len', 64)
self.cnn_image_model: Optional['CNNImageStrategyModel'] = None
if not CNN_IMAGE_AVAILABLE:
logger.warning("CNN Image non disponible (PyTorch/mplfinance requis)")
return
# Tentative de chargement automatique du modèle existant
if config.get('auto_load', True):
self._try_load_model()
# -------------------------------------------------------------------------
# Interface BaseStrategy
# -------------------------------------------------------------------------
def analyze(self, market_data: pd.DataFrame) -> Optional[Signal]:
"""
Génère un signal de trading via le modèle CNN Image.
Args:
market_data: DataFrame OHLCV (minimum seq_len barres)
Returns:
Signal si le modèle est confiant, None sinon
"""
if not CNN_IMAGE_AVAILABLE:
logger.debug("CNN Image Strategy : PyTorch/mplfinance non disponible, aucun signal")
return None
if self.cnn_image_model is None or not self.cnn_image_model.is_trained:
logger.debug("CNN Image Strategy : modèle non chargé, aucun signal")
return None
if len(market_data) < self.seq_len:
return None
try:
result = self.cnn_image_model.predict(market_data)
except Exception as e:
logger.warning(f"CNN Image Strategy predict error : {e}")
return None
if not result.get('tradeable', False):
return None
signal_dir = result['signal'] # 1 = LONG, -1 = SHORT
confidence = result['confidence']
# Prix et ATR pour SL/TP
last_close = float(market_data['close'].iloc[-1])
atr = self._compute_atr(market_data)
if atr <= 0:
return None
if signal_dir == 1:
direction = 'LONG'
stop_loss = last_close - self.sl_atr_mult * atr
take_profit = last_close + self.tp_atr_mult * atr
elif signal_dir == -1:
direction = 'SHORT'
stop_loss = last_close + self.sl_atr_mult * atr
take_profit = last_close - self.tp_atr_mult * atr
else:
return None
signal = Signal(
symbol = self.symbol,
direction = direction,
entry_price = last_close,
stop_loss = stop_loss,
take_profit = take_profit,
confidence = confidence,
timestamp = datetime.now(timezone.utc),
strategy = self.STRATEGY_NAME,
metadata = {
'probas': result.get('probas', {}),
'seq_len': self.seq_len,
'atr': atr,
'tp_atr_mult': self.tp_atr_mult,
'sl_atr_mult': self.sl_atr_mult,
},
)
logger.info(
f"CNN Image Signal : {direction} {self.symbol} | "
f"entry={last_close:.5f} SL={stop_loss:.5f} TP={take_profit:.5f} | "
f"confidence={confidence:.2%}"
)
return signal
def calculate_indicators(self, data: pd.DataFrame) -> pd.DataFrame:
"""Retourne les données telles quelles — le CNN Image travaille sur des rendus d'images."""
return data
def update_params(self, params: Dict) -> None:
"""Mise à jour dynamique des paramètres (depuis API ou Optuna)."""
if 'min_confidence' in params:
self.min_confidence = params['min_confidence']
if self.cnn_image_model:
self.cnn_image_model.min_confidence = params['min_confidence']
if 'tp_atr_mult' in params:
self.tp_atr_mult = params['tp_atr_mult']
if 'sl_atr_mult' in params:
self.sl_atr_mult = params['sl_atr_mult']
if 'seq_len' in params:
self.seq_len = params['seq_len']
logger.info(f"CNN Image Strategy params mis à jour : {params}")
# -------------------------------------------------------------------------
# Gestion du modèle
# -------------------------------------------------------------------------
def load_model(self, symbol: Optional[str] = None, timeframe: Optional[str] = None) -> bool:
"""
Charge un modèle CNN Image depuis le disque.
Args:
symbol: Paire (défaut: self.symbol)
timeframe: Timeframe (défaut: self.config.timeframe)
Returns:
True si chargement réussi
"""
if not CNN_IMAGE_AVAILABLE:
logger.warning("CNN Image non disponible (PyTorch/mplfinance requis)")
return False
sym = symbol or self.symbol
tf = timeframe or self.config.timeframe
try:
self.cnn_image_model = CNNImageStrategyModel.load(sym, tf)
logger.info(f"Modèle CNN Image chargé : {sym}/{tf}")
return True
except FileNotFoundError:
logger.info(f"Aucun modèle CNN Image trouvé pour {sym}/{tf}")
return False
except Exception as e:
logger.error(f"Erreur chargement modèle CNN Image : {e}")
return False
def attach_model(self, model: 'CNNImageStrategyModel') -> None:
"""Attache directement un modèle CNN Image (après entraînement via API)."""
self.cnn_image_model = model
self.symbol = model.symbol
logger.info(f"Modèle CNN Image attaché : {model.symbol}/{model.timeframe}")
def is_ready(self) -> bool:
"""Retourne True si le modèle CNN Image est chargé et entraîné."""
if not CNN_IMAGE_AVAILABLE:
return False
return self.cnn_image_model is not None and self.cnn_image_model.is_trained
def get_model_info(self) -> Dict:
"""Retourne les métadonnées du modèle CNN Image actif."""
if not CNN_IMAGE_AVAILABLE:
return {'status': 'PyTorch/mplfinance non disponible'}
if not self.is_ready():
return {'status': 'non entraîné'}
meta = self.cnn_image_model.metadata.copy()
meta['is_ready'] = True
meta['seq_len'] = self.seq_len
return meta
# -------------------------------------------------------------------------
# Helpers
# -------------------------------------------------------------------------
def _try_load_model(self) -> None:
"""Tente un chargement silencieux du modèle au démarrage."""
try:
self.load_model()
except Exception:
pass
@staticmethod
def _compute_atr(df: pd.DataFrame, period: int = 14) -> float:
"""Calcule l'ATR moyen sur les dernières barres."""
if len(df) < period + 1:
return float(df['high'].iloc[-1] - df['low'].iloc[-1])
h, l, pc = df['high'], df['low'], df['close'].shift(1)
tr = pd.concat([h - l, (h - pc).abs(), (l - pc).abs()], axis=1).max(axis=1)
atr = tr.rolling(period).mean().iloc[-1]
return float(atr) if not np.isnan(atr) else float(df['high'].iloc[-1] - df['low'].iloc[-1])