feat: Phase 4c-bis — CNN image-based (analyse visuelle graphiques chandeliers)
## Nouveaux modules
### src/ml/cnn_image/
- chart_renderer.py : CandlestickImageRenderer — OHLCV → images 128×128 RGB (mplfinance)
Fond #0d1117, bougies vertes/rouges, volume, sans axes, rendu en mémoire
Fallback 2D si mplfinance absent
- cnn_image_model.py : CandlestickCNN — Conv2D 4-blocs (3→32→64→128→256) + AvgPool + Dense(3)
- cnn_image_strategy_model.py : CNNImageStrategyModel — même interface que MLStrategyModel
### src/strategies/cnn_image_driven/
- cnn_image_strategy.py : CNNImageDrivenStrategy(BaseStrategy), SL/TP ATR, seq_len=64
## Modifications
- ensemble_model.py : attach_cnn_image(), poids XGB=0.30/CNN1D=0.30/CNNImage=0.40
- trading.py : POST /train-cnn-image, GET /train-cnn-image/{id}, GET /cnn-image-models
- docker/requirements/api.txt : mplfinance>=0.12.10b0, Pillow>=10.0.0
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -29,3 +29,7 @@ joblib>=1.3.0
|
||||
|
||||
# ML — Deep Learning (CNN pour patterns chandeliers)
|
||||
torch>=2.0.0
|
||||
|
||||
# Visualisation — graphiques financiers et traitement d'images
|
||||
mplfinance>=0.12.10b0
|
||||
Pillow>=10.0.0
|
||||
|
||||
@@ -1230,3 +1230,191 @@ async def get_ensemble_status():
|
||||
status["ensemble_info"] = {"error": "Impossible de récupérer le statut"}
|
||||
|
||||
return status
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CNN IMAGE STRATEGY — Entraînement et gestion des modèles CNN-Image
|
||||
# =============================================================================
|
||||
|
||||
try:
|
||||
from src.ml.cnn_image import CNNImageStrategyModel
|
||||
CNN_IMAGE_AVAILABLE = True
|
||||
except ImportError:
|
||||
CNN_IMAGE_AVAILABLE = False
|
||||
|
||||
# Stockage en mémoire des jobs d'entraînement CNN Image
|
||||
_cnn_image_train_jobs: Dict[str, dict] = {}
|
||||
|
||||
|
||||
class CNNImageTrainRequest(BaseModel):
    """Request payload for training the CNN Image model (Conv2D vision).

    Defaults mirror CNNImageStrategyModel's constructor, so an empty body
    trains the standard EURUSD/1h configuration.
    """
    symbol: str = "EURUSD"          # Traded pair used to fetch historical data
    timeframe: str = "1h"           # Candle timeframe (e.g. '1h', '15m')
    period: str = "2y"              # Lookback window, e.g. '2y', '6m', '30d'
    seq_len: int = 64               # Number of candles rendered per image
    tp_atr_mult: float = 2.0        # ATR multiplier for the take-profit label
    sl_atr_mult: float = 1.0       # ATR multiplier for the stop-loss label
    horizon: int = 30               # Bars ahead used to evaluate TP/SL hits
    min_confidence: float = 0.55    # Confidence threshold for tradeable signals
|
||||
|
||||
|
||||
class CNNImageTrainResponse(BaseModel):
    """Response describing a CNN Image training job.

    Metric fields stay ``None`` until the background job completes;
    ``error`` is only populated when the job fails.
    """
    job_id: str                             # UUID identifying the background job
    status: str                             # 'pending' | 'running' | 'completed' | 'failed'
    symbol: str
    timeframe: str
    wf_accuracy: Optional[float] = None     # Walk-forward average accuracy
    wf_precision: Optional[float] = None    # Walk-forward average precision
    label_dist: Optional[dict] = None       # Distribution of the generated labels
    n_samples: Optional[int] = None         # Number of training samples (images)
    trained_at: Optional[str] = None        # Timestamp of training completion
    error: Optional[str] = None             # Error message when status == 'failed'
|
||||
|
||||
|
||||
async def _run_cnn_image_train_task(job_id: str, request: CNNImageTrainRequest) -> None:
    """Background task that trains the CNN Image model for one job.

    Updates ``_cnn_image_train_jobs[job_id]`` in place: 'running' while
    training, then either 'completed' with walk-forward metrics or
    'failed' with the error message (any exception below is captured).
    """
    _cnn_image_train_jobs[job_id]["status"] = "running"
    try:
        from src.data.data_service import DataService
        from src.utils.config_loader import ConfigLoader
        from datetime import timedelta

        config = ConfigLoader.load_all()
        data_service = DataService(config)

        # Convert the period string ('2y', '6m', '30d') into a date range.
        # A malformed period raises ValueError here, which marks the job failed.
        end_date = datetime.now()
        period_map = {'y': 365, 'm': 30, 'd': 1}
        unit = request.period[-1]
        value = int(request.period[:-1])
        start_date = end_date - timedelta(days=value * period_map.get(unit, 1))

        df = await data_service.get_historical_data(
            symbol=request.symbol,
            timeframe=request.timeframe,
            start_date=start_date,
            end_date=end_date,
        )
        if df is None or len(df) < 200:
            raise ValueError(
                f"Données insuffisantes : {len(df) if df is not None else 0} barres (min 200)"
            )

        # Train in a worker thread: the training loop is CPU/GPU-bound and
        # would otherwise block the event loop. get_running_loop() is the
        # non-deprecated replacement for get_event_loop() inside a coroutine.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, _sync_cnn_image_train, df, request)

        _cnn_image_train_jobs[job_id].update({
            "status": "completed",
            "symbol": request.symbol,
            "timeframe": request.timeframe,
            "n_samples": result.get("n_samples"),
            "wf_accuracy": result.get("wf_metrics", {}).get("avg_accuracy"),
            "wf_precision": result.get("wf_metrics", {}).get("avg_precision"),
            "label_dist": result.get("label_dist"),
            "trained_at": result.get("trained_at"),
        })

        # Auto-attach to a live cnn_image_driven strategy, if one is running.
        _attach_cnn_image_model_to_strategy(request)

    except Exception as exc:
        logger.error(f"Erreur entraînement CNN Image job {job_id} : {exc}", exc_info=True)
        _cnn_image_train_jobs[job_id]["status"] = "failed"
        _cnn_image_train_jobs[job_id]["error"] = str(exc)
|
||||
|
||||
|
||||
def _sync_cnn_image_train(df, request: CNNImageTrainRequest) -> dict:
    """Synchronous wrapper around CNNImageStrategyModel.train(), run in a worker thread."""
    from src.ml.cnn_image import CNNImageStrategyModel

    model_kwargs = {
        'symbol': request.symbol,
        'timeframe': request.timeframe,
        'seq_len': request.seq_len,
        'tp_atr_mult': request.tp_atr_mult,
        'sl_atr_mult': request.sl_atr_mult,
        'horizon': request.horizon,
        'min_confidence': request.min_confidence,
    }
    return CNNImageStrategyModel(**model_kwargs).train(df)
|
||||
|
||||
|
||||
def _attach_cnn_image_model_to_strategy(request: CNNImageTrainRequest) -> None:
    """Attach the freshly trained CNN Image model to the active cnn_image_driven strategy (paper trading)."""
    try:
        from src.ml.cnn_image import CNNImageStrategyModel
        from src.strategies.cnn_image_driven import CNNImageDrivenStrategy

        engine = _paper_state.get("engine")
        if not engine or not hasattr(engine, 'strategy_engine'):
            return
        strat = engine.strategy_engine.strategies.get('cnn_image_driven')
        if not strat or not isinstance(strat, CNNImageDrivenStrategy):
            return

        strat.attach_model(CNNImageStrategyModel.load(request.symbol, request.timeframe))
        logger.info("Modèle CNN Image attaché à la stratégie cnn_image_driven active")
    except Exception as e:
        # Best-effort: a missing strategy or model is not an error here.
        logger.debug(f"Auto-attach modèle CNN Image ignoré : {e}")
|
||||
|
||||
|
||||
@router.post("/train-cnn-image", response_model=CNNImageTrainResponse,
             summary="Entraîner le modèle CNN Image (Conv2D vision)")
async def train_cnn_image(request: CNNImageTrainRequest, background_tasks: BackgroundTasks):
    """
    Lance l'entraînement du modèle CNN Image en arrière-plan.

    Le CNN Image Conv2D apprend les patterns visuels des graphiques de chandeliers
    (bougies marteau, double top/bottom, rebonds S/R, momentum...) depuis des
    images 128×128 RGB générées par mplfinance — aucune feature pré-calculée.

    - Retourne un `job_id` à interroger via `GET /trading/train-cnn-image/{job_id}`
    - Le modèle est sauvegardé sur disque après entraînement
    - Si un paper trading cnn_image_driven est actif, le modèle lui est automatiquement attaché
    """
    if not CNN_IMAGE_AVAILABLE:
        raise HTTPException(
            503,
            detail="PyTorch + mplfinance requis — rebuilder le container trading-api"
        )

    # Register the job before scheduling it so an immediate poll cannot miss it.
    new_job_id = str(uuid.uuid4())
    _cnn_image_train_jobs[new_job_id] = {
        "status": "pending",
        "symbol": request.symbol,
        "timeframe": request.timeframe,
    }
    background_tasks.add_task(_run_cnn_image_train_task, new_job_id, request)

    return CNNImageTrainResponse(job_id=new_job_id, status="pending",
                                 symbol=request.symbol, timeframe=request.timeframe)
|
||||
|
||||
|
||||
@router.get("/train-cnn-image/{job_id}", response_model=CNNImageTrainResponse,
            summary="Résultat entraînement CNN Image")
def get_cnn_image_train_status(job_id: str):
    """Retourne l'état d'un job d'entraînement CNN Image."""
    try:
        job_state = _cnn_image_train_jobs[job_id]
    except KeyError:
        raise HTTPException(404, detail=f"Job {job_id} introuvable")
    return CNNImageTrainResponse(job_id=job_id, **job_state)
|
||||
|
||||
|
||||
@router.get("/cnn-image-models", summary="Liste des modèles CNN Image entraînés")
def list_cnn_image_models():
    """
    Retourne la liste de tous les modèles CNN Image disponibles sur disque,
    avec leurs métriques (accuracy, date d'entraînement, nombre de samples...).
    """
    if not CNN_IMAGE_AVAILABLE:
        return {
            "error": "PyTorch + mplfinance requis — rebuilder le container trading-api",
            "models": [],
            "count": 0,
        }

    from src.ml.cnn_image import CNNImageStrategyModel

    available = CNNImageStrategyModel.list_trained_models()
    return {"models": available, "count": len(available)}
|
||||
|
||||
11
src/ml/cnn_image/__init__.py
Normal file
11
src/ml/cnn_image/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
Module CNN Image-Based — Analyse visuelle de graphiques en chandeliers.
|
||||
|
||||
Ce module implémente un CNN Conv2D qui analyse des graphiques de chandeliers
|
||||
rendus comme de vraies images (128×128 RGB), pour reconnaître les patterns
|
||||
visuels exactement comme un trader humain devant TradingView.
|
||||
"""
|
||||
|
||||
from src.ml.cnn_image.cnn_image_strategy_model import CNNImageStrategyModel
|
||||
|
||||
__all__ = ['CNNImageStrategyModel']
|
||||
322
src/ml/cnn_image/chart_renderer.py
Normal file
322
src/ml/cnn_image/chart_renderer.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""
|
||||
CandlestickImageRenderer — Convertit des données OHLCV en images de graphiques.
|
||||
|
||||
Ce module transforme des séquences de bougies OHLCV en images 128×128 RGB
|
||||
qui peuvent être passées à un CNN Conv2D pour l'analyse visuelle des patterns.
|
||||
|
||||
Rendu :
|
||||
- Fond noir (#0d1117), style TradingView
|
||||
- Bougies vertes (#26a69a) pour la hausse, rouges (#ef5350) pour la baisse
|
||||
- Volume en bas de l'image (via mplfinance)
|
||||
- Pas d'axes, pas de labels, pas de titre
|
||||
- Taille fixe : 128×128 pixels, 3 canaux RGB
|
||||
|
||||
Si mplfinance ou PIL ne sont pas disponibles, un rendu de fallback basique
|
||||
encode les données OHLCV numériquement sous forme d'image 2D.
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- Détection optionnelle de mplfinance et PIL ---
|
||||
try:
|
||||
import mplfinance as mpf
|
||||
import matplotlib
|
||||
matplotlib.use('Agg') # Backend non-interactif (pas d'écran requis)
|
||||
import matplotlib.pyplot as plt
|
||||
MPLFINANCE_AVAILABLE = True
|
||||
logger.debug("mplfinance disponible — rendu haute qualité activé")
|
||||
except ImportError:
|
||||
MPLFINANCE_AVAILABLE = False
|
||||
logger.warning("mplfinance non disponible — utilisation du rendu fallback")
|
||||
|
||||
try:
|
||||
from PIL import Image
|
||||
PIL_AVAILABLE = True
|
||||
except ImportError:
|
||||
PIL_AVAILABLE = False
|
||||
logger.warning("Pillow (PIL) non disponible — rendu fallback uniquement")
|
||||
|
||||
|
||||
# Taille cible des images en pixels
|
||||
IMAGE_SIZE = 128
|
||||
|
||||
|
||||
class CandlestickImageRenderer:
    """
    Converts OHLCV data into candlestick-chart images for CNN input.

    Each image is a visual snapshot of ``seq_len`` consecutive candles,
    rendered with mplfinance (dark background, coloured candles, volume
    panel at the bottom). Output arrays are float32 in [0, 1], channel-first
    (3, H, W).

    Args:
        image_size: Square image size in pixels (default 128)
    """

    def __init__(self, image_size: int = 128):
        # Literal default intentionally equals the module-level IMAGE_SIZE
        # constant, so the class is usable in isolation.
        self.image_size = image_size

    # -------------------------------------------------------------------------
    # Public interface
    # -------------------------------------------------------------------------

    def encode(self, df: pd.DataFrame, seq_len: int = 64) -> np.ndarray:
        """
        Encode every sliding window of the DataFrame as an image.

        Produces N = len(df) - seq_len windows, each rendered as an
        image_size×image_size RGB image normalised to [0, 1].

        Args:
            df: OHLCV DataFrame with open/high/low/close/volume columns
            seq_len: Number of candles per window (default 64)

        Returns:
            np.ndarray of shape (N, 3, image_size, image_size), float32, [0, 1].
            Returns an empty (0, 3, image_size, image_size) array when df is
            too short.
        """
        df = self._prepare_df(df)
        n_windows = len(df) - seq_len

        if n_windows <= 0:
            logger.warning(
                f"DataFrame trop court ({len(df)} barres) pour seq_len={seq_len}"
            )
            return np.zeros((0, 3, self.image_size, self.image_size), dtype=np.float32)

        images = np.zeros(
            (n_windows, 3, self.image_size, self.image_size), dtype=np.float32
        )

        for i in range(n_windows):
            window = df.iloc[i: i + seq_len]
            try:
                images[i] = self._render_single(window)
            except Exception as e:
                # Problematic window → leave zeros at this position rather
                # than aborting the whole batch.
                logger.debug(f"Erreur rendu fenêtre {i} : {e}")

        return images

    def encode_last(self, df: pd.DataFrame, seq_len: int = 64) -> np.ndarray:
        """
        Encode only the most recent window of the DataFrame.

        Used for live prediction: returns a single-image batch.

        Args:
            df: OHLCV DataFrame
            seq_len: Number of candles in the last window (default 64)

        Returns:
            np.ndarray of shape (1, 3, image_size, image_size), float32, [0, 1]

        Raises:
            ValueError: if df holds fewer than seq_len candles
        """
        df = self._prepare_df(df)

        if len(df) < seq_len:
            raise ValueError(
                f"DataFrame insuffisant : {len(df)} barres < seq_len={seq_len}"
            )

        window = df.iloc[-seq_len:]
        img = self._render_single(window)
        return img[np.newaxis, ...]  # add batch dimension → (1, 3, H, W)

    # -------------------------------------------------------------------------
    # Window rendering
    # -------------------------------------------------------------------------

    def _render_single(self, df_window: pd.DataFrame) -> np.ndarray:
        """
        Render one OHLCV window as a numpy image (3, image_size, image_size).

        Uses mplfinance when available, otherwise the numeric fallback.
        """
        if MPLFINANCE_AVAILABLE and PIL_AVAILABLE:
            return self._render_with_mplfinance(df_window)
        return self._render_fallback(df_window)

    def _render_with_mplfinance(self, df_window: pd.DataFrame) -> np.ndarray:
        """
        High-quality rendering via mplfinance.

        Style: dark background (#0d1117), green/red candles, volume panel,
        no axes, labels or title. Returns a (3, image_size, image_size) array.
        """
        # Custom style: dark background, TradingView-coloured candles
        style = mpf.make_mpf_style(
            base_mpf_style='nightclouds',
            marketcolors=mpf.make_marketcolors(
                up='#26a69a',    # TradingView green (bullish)
                down='#ef5350',  # TradingView red (bearish)
                wick={'up': '#26a69a', 'down': '#ef5350'},
                edge={'up': '#26a69a', 'down': '#ef5350'},
                volume={'up': '#26a69a', 'down': '#ef5350'},
            ),
            facecolor='#0d1117',  # GitHub-style dark background
            figcolor='#0d1117',
            gridcolor='#0d1117',
        )

        # Render in memory through BytesIO — no file ever touches disk
        buf = io.BytesIO()

        fig, axes = mpf.plot(
            df_window,
            type='candle',
            style=style,
            volume=True,
            axisoff=True,       # no axes
            tight_layout=True,
            returnfig=True,
            figsize=(1.28, 1.28),  # 1.28 in × 100 dpi (savefig below) ≈ 128 px
        )

        # Strip every remaining decorative element
        for ax in axes:
            ax.set_axis_off()
            ax.set_facecolor('#0d1117')
            for spine in ax.spines.values():
                spine.set_visible(False)

        fig.savefig(
            buf,
            format='png',
            dpi=100,
            bbox_inches='tight',
            pad_inches=0,
            facecolor='#0d1117',
        )
        plt.close(fig)  # free the figure — avoids matplotlib memory growth

        buf.seek(0)
        pil_img = Image.open(buf).convert('RGB')
        # Image.Resampling.LANCZOS: the module-level Image.LANCZOS alias is
        # deprecated since Pillow 9.1 (requirements pin Pillow>=10).
        pil_img = pil_img.resize(
            (self.image_size, self.image_size), Image.Resampling.LANCZOS
        )

        # numpy array (H, W, 3) → (3, H, W), float32 in [0, 1]
        arr = np.array(pil_img, dtype=np.float32) / 255.0
        arr = arr.transpose(2, 0, 1)  # HWC → CHW

        return arr

    def _render_fallback(self, df_window: pd.DataFrame) -> np.ndarray:
        """
        Fallback rendering without mplfinance: encodes the OHLCV data
        directly as a normalised 2D image.

        Each image column corresponds to one candle:
          - Channel 0 (R): relative position of close within [low, high]
          - Channel 1 (G): candle body size (|close - open| / range)
          - Channel 2 (B): normalised volume

        The image is then stretched to image_size × image_size.
        """
        cols = df_window[['open', 'high', 'low', 'close', 'volume']].copy()
        n = len(cols)

        if n == 0:
            return np.zeros((3, self.image_size, self.image_size), dtype=np.float32)

        highs = cols['high'].values.astype(np.float32)
        lows = cols['low'].values.astype(np.float32)
        closes = cols['close'].values.astype(np.float32)
        opens = cols['open'].values.astype(np.float32)
        vols = cols['volume'].values.astype(np.float32)

        # Guard against zero-range candles before dividing
        price_range = highs - lows
        price_range = np.where(price_range == 0, 1e-8, price_range)

        # Channel R: close position within the candle range, in [0, 1]
        close_pos = (closes - lows) / price_range

        # Channel G: candle body (|close - open| / range)
        body = np.abs(closes - opens) / price_range

        # Channel B: volume normalised to [0, 1]
        vol_max = vols.max()
        vol_norm = vols / (vol_max if vol_max > 0 else 1.0)

        # Build a 3 × 1 × n strip, then stretch it to the target size
        img = np.zeros((3, 1, n), dtype=np.float32)
        img[0, 0, :] = close_pos
        img[1, 0, :] = body
        img[2, 0, :] = vol_norm

        img_resized = np.zeros(
            (3, self.image_size, self.image_size), dtype=np.float32
        )
        for c in range(3):
            channel = img[c, 0, :]  # (n,)
            # 1D interpolation to image_size along the width
            x_orig = np.linspace(0, 1, n)
            x_new = np.linspace(0, 1, self.image_size)
            channel_resized = np.interp(x_new, x_orig, channel).astype(np.float32)
            # Stretch across the full height
            img_resized[c] = np.tile(channel_resized, (self.image_size, 1))

        return img_resized

    # -------------------------------------------------------------------------
    # Utilities
    # -------------------------------------------------------------------------

    @staticmethod
    def _prepare_df(df: pd.DataFrame) -> pd.DataFrame:
        """
        Normalise the DataFrame: lower-case columns, DatetimeIndex.

        Args:
            df: Raw OHLCV DataFrame

        Returns:
            Cleaned DataFrame with open/high/low/close/volume columns
        """
        df = df.copy()
        df.columns = [c.lower() for c in df.columns]

        # mplfinance requires a DatetimeIndex
        if not isinstance(df.index, pd.DatetimeIndex):
            try:
                df.index = pd.to_datetime(df.index)
            except Exception:
                # Fall back to a synthetic hourly index when conversion fails
                df.index = pd.date_range(
                    start='2020-01-01', periods=len(df), freq='1h'
                )

        # Keep only the OHLCV columns, zero-filling any that are missing
        required = ['open', 'high', 'low', 'close', 'volume']
        for col in required:
            if col not in df.columns:
                df[col] = 0.0

        df = df[required].dropna(subset=['open', 'high', 'low', 'close'])
        df = df.ffill().bfill()

        return df
|
||||
163
src/ml/cnn_image/cnn_image_model.py
Normal file
163
src/ml/cnn_image/cnn_image_model.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
CandlestickCNN — Réseau de neurones convolutif pour l'analyse visuelle de graphiques.
|
||||
|
||||
Architecture Conv2D 4-blocs conçue pour reconnaître les patterns visuels dans
|
||||
des images de chandeliers 128×128 RGB (bougies, volumes, supports/résistances).
|
||||
|
||||
Input : (batch, 3, 128, 128) — images RGB normalisées [0, 1]
|
||||
Output : (batch, 3) — logits pour 3 classes : SHORT(0), NEUTRAL(1), LONG(2)
|
||||
|
||||
Classes encodées :
|
||||
0 → SHORT (-1)
|
||||
1 → NEUTRAL ( 0)
|
||||
2 → LONG (+1)
|
||||
|
||||
L'encodage +1 est cohérent avec MLStrategyModel et CNNStrategyModel.
|
||||
"""
|
||||
|
||||
import logging
from typing import Optional

import numpy as np

logger = logging.getLogger(__name__)

# --- Optional PyTorch detection ---
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    TORCH_AVAILABLE = True
    logger.debug("PyTorch disponible — CandlestickCNN activé")
except ImportError:
    TORCH_AVAILABLE = False
    # Stubs so the module can still be imported without PyTorch.
    # F is stubbed too, for consistency with nn/torch.
    nn = None
    torch = None
    F = None
    logger.warning("PyTorch non disponible — CandlestickCNN désactivé")


# ─────────────────────────────────────────────────────────────────────────────
# Conditional model definition
# ─────────────────────────────────────────────────────────────────────────────

if TORCH_AVAILABLE:

    class CandlestickCNN(nn.Module):
        """
        Vision CNN for 128×128 RGB candlestick charts.

        Architecture:
            Block 1: Conv2d(3→32, 3×3) + BatchNorm2d(32) + ReLU + MaxPool2d(2) → (32, 64, 64)
            Block 2: Conv2d(32→64, 3×3) + BatchNorm2d(64) + ReLU + MaxPool2d(2) → (64, 32, 32)
            Block 3: Conv2d(64→128,3×3) + BatchNorm2d(128) + ReLU + MaxPool2d(2) → (128, 16, 16)
            Block 4: Conv2d(128→256,3×3)+ BatchNorm2d(256) + ReLU + AdaptiveAvgPool2d(1) → (256,)

        Classifier head:
            Linear(256→128) + Dropout(0.4) + ReLU
            Linear(128→3)

        Output: raw logits (3 classes) — apply softmax for probabilities.

        Args:
            n_classes: Number of output classes (default 3: SHORT/NEUTRAL/LONG)
            dropout: Dropout rate in the classifier head (default 0.4)
        """

        def __init__(self, n_classes: int = 3, dropout: float = 0.4):
            # Modern zero-argument super() — equivalent to the legacy
            # super(CandlestickCNN, self) form.
            super().__init__()

            # NOTE: attribute names (bloc1..bloc4, classifieur) are part of
            # the saved state_dict — do not rename without a migration.

            # --- Conv block 1: (3, 128, 128) → (32, 64, 64) ---
            self.bloc1 = nn.Sequential(
                nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
                nn.BatchNorm2d(32),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
            )

            # --- Conv block 2: (32, 64, 64) → (64, 32, 32) ---
            self.bloc2 = nn.Sequential(
                nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
            )

            # --- Conv block 3: (64, 32, 32) → (128, 16, 16) ---
            self.bloc3 = nn.Sequential(
                nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
            )

            # --- Conv block 4: (128, 16, 16) → (256, 1, 1) → (256,) ---
            self.bloc4 = nn.Sequential(
                nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d(output_size=1),  # global average pooling → (256, 1, 1)
            )

            # --- Dense classifier head ---
            self.classifieur = nn.Sequential(
                nn.Linear(256, 128),
                nn.Dropout(p=dropout),
                nn.ReLU(inplace=True),
                nn.Linear(128, n_classes),
            )

        def forward(self, x: 'torch.Tensor') -> 'torch.Tensor':
            """
            Forward pass.

            Args:
                x: Tensor (batch, 3, 128, 128) — normalised RGB images

            Returns:
                Tensor (batch, 3) — raw logits
            """
            x = self.bloc1(x)        # (batch, 32, 64, 64)
            x = self.bloc2(x)        # (batch, 64, 32, 32)
            x = self.bloc3(x)        # (batch, 128, 16, 16)
            x = self.bloc4(x)        # (batch, 256, 1, 1)
            x = x.flatten(1)         # (batch, 256)
            x = self.classifieur(x)  # (batch, 3)
            return x

        def predict_proba(self, x: 'torch.Tensor') -> np.ndarray:
            """
            Compute softmax probabilities for each class.

            Side effect: switches the module to eval() mode.

            Args:
                x: Tensor (batch, 3, 128, 128) — normalised RGB images

            Returns:
                np.ndarray of shape (batch, 3), values in [0, 1], rows sum to 1.
                Column order: [P_SHORT, P_NEUTRAL, P_LONG]
            """
            self.eval()
            with torch.no_grad():
                logits = self.forward(x)           # (batch, 3)
                probas = F.softmax(logits, dim=1)  # (batch, 3)
            return probas.cpu().numpy().astype(np.float32)

else:
    # Stub when PyTorch is unavailable
    class CandlestickCNN:
        """
        CandlestickCNN stub — PyTorch not available.

        Every method raises an explicit RuntimeError.
        """

        def __init__(self, n_classes: int = 3, dropout: float = 0.4):
            raise RuntimeError(
                "PyTorch non disponible. Installer torch>=2.0.0 pour utiliser CandlestickCNN."
            )

        def forward(self, x):
            raise RuntimeError("PyTorch non disponible.")

        def predict_proba(self, x):
            raise RuntimeError("PyTorch non disponible.")
|
||||
654
src/ml/cnn_image/cnn_image_strategy_model.py
Normal file
654
src/ml/cnn_image/cnn_image_strategy_model.py
Normal file
@@ -0,0 +1,654 @@
|
||||
"""
|
||||
CNNImageStrategyModel — Modèle CNN image-based pour la prédiction de signaux de trading.
|
||||
|
||||
Ce module entraîne un CNN Conv2D (CandlestickCNN) sur des images de graphiques
|
||||
en chandeliers pour prédire LONG (+1), NEUTRAL (0) ou SHORT (-1).
|
||||
|
||||
Pipeline :
|
||||
1. Génération de labels ATR-based (LabelGenerator partagé)
|
||||
2. Rendu des fenêtres OHLCV en images 128×128 RGB (CandlestickImageRenderer)
|
||||
3. Walk-forward (2 folds temporels) pour évaluation out-of-sample
|
||||
4. Entraînement sur toutes les données (Adam, CrossEntropyLoss weighted)
|
||||
5. Sauvegarde : EURUSD_1h.pt (state_dict) + EURUSD_1h_meta.json
|
||||
|
||||
Interface identique à MLStrategyModel et CNNStrategyModel :
|
||||
model = CNNImageStrategyModel(symbol='EURUSD', timeframe='1h')
|
||||
result = model.train(df_ohlcv)
|
||||
signal = model.predict(df) # {signal, confidence, probas, tradeable}
|
||||
model.save()
|
||||
model.load(symbol, timeframe)
|
||||
model.list_trained_models()
|
||||
model.get_feature_importance() # retourne [] (CNN = boîte noire)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from src.ml.features.label_generator import LabelGenerator
|
||||
from src.ml.cnn_image.chart_renderer import CandlestickImageRenderer, MPLFINANCE_AVAILABLE
|
||||
from src.ml.cnn_image.cnn_image_model import CandlestickCNN, TORCH_AVAILABLE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Répertoire de sauvegarde des modèles CNN image
|
||||
MODELS_DIR = Path(__file__).parent.parent.parent.parent / "models" / "cnn_image_strategy"
|
||||
|
||||
# Décodage des classes encodées vers les signaux de trading
|
||||
# Classe 0 → SHORT (-1), Classe 1 → NEUTRAL (0), Classe 2 → LONG (+1)
|
||||
CLASS_MAP = {0: -1, 1: 0, 2: 1}
|
||||
|
||||
# Conditional PyTorch imports (torch is optional at runtime)
if TORCH_AVAILABLE:
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
else:
    torch = None
    nn = None
    DataLoader = None
    TensorDataset = None

# scikit-learn availability is independent of PyTorch: import it in its own
# guarded block instead of inside the TORCH_AVAILABLE branch, so a missing
# sklearn no longer crashes module import when torch is installed.
try:
    from sklearn.model_selection import TimeSeriesSplit
    from sklearn.metrics import precision_recall_fscore_support
    SKLEARN_AVAILABLE = True
except ImportError:
    TimeSeriesSplit = None
    precision_recall_fscore_support = None
    SKLEARN_AVAILABLE = False
|
||||
|
||||
|
||||
class CNNImageStrategyModel:
|
||||
"""
|
||||
Modèle CNN image-based qui reconnaît les patterns visuels de trading.
|
||||
|
||||
Le modèle :
|
||||
- Convertit des bougies OHLCV en images 128×128 RGB (mplfinance)
|
||||
- Reconnaît visuellement : marteaux, doji, double top/bottom, supports/résistances
|
||||
- Prédit LONG (1) / SHORT (-1) / NEUTRAL (0)
|
||||
- Donne un score de confiance [0..1] par prédiction
|
||||
- Se sauvegarde en format PyTorch (.pt) pour rechargement sans ré-entraînement
|
||||
|
||||
Args:
|
||||
symbol: Paire tradée (ex: 'EURUSD')
|
||||
timeframe: Timeframe (ex: '1h', '15m')
|
||||
seq_len: Nombre de bougies par image (défaut 64)
|
||||
tp_atr_mult: Multiplicateur ATR pour le TP (génération labels)
|
||||
sl_atr_mult: Multiplicateur ATR pour le SL (génération labels)
|
||||
horizon: Nombre de barres pour évaluer TP/SL
|
||||
min_confidence: Seuil de confiance pour le signal tradeable
|
||||
epochs: Nombre maximum d'époques d'entraînement
|
||||
batch_size: Taille du batch (défaut 32)
|
||||
lr: Taux d'apprentissage Adam (défaut 1e-3)
|
||||
patience: Patience pour l'early stopping (défaut 7)
|
||||
"""
|
||||
|
||||
MODELS_DIR = MODELS_DIR
|
||||
CLASS_MAP = CLASS_MAP
|
||||
|
||||
    def __init__(
        self,
        symbol: str = 'EURUSD',
        timeframe: str = '1h',
        seq_len: int = 64,
        tp_atr_mult: float = 2.0,
        sl_atr_mult: float = 1.0,
        horizon: int = 30,
        min_confidence: float = 0.55,
        epochs: int = 50,
        batch_size: int = 32,
        lr: float = 1e-3,
        patience: int = 7,
    ):
        """Store the training/prediction configuration (arguments are described in the class docstring)."""
        self.symbol = symbol
        self.timeframe = timeframe
        self.seq_len = seq_len
        self.tp_atr_mult = tp_atr_mult
        self.sl_atr_mult = sl_atr_mult
        self.horizon = horizon
        self.min_confidence = min_confidence
        self.epochs = epochs
        self.batch_size = batch_size
        self.lr = lr
        self.patience = patience

        # Populated later by train() or load(); None until then
        self.model: Optional['CandlestickCNN'] = None
        self.is_trained = False
        self.metadata: Dict = {}

        # Ensure the save directory exists up front so saving cannot fail on it
        self.MODELS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Entraînement
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def train(self, data: pd.DataFrame) -> Dict:
    """Train the image-based CNN on the supplied OHLCV data.

    Steps:
        1. Generate ATR-based labels (LabelGenerator)
        2. Render sliding windows as chart images (CandlestickImageRenderer)
        3. Walk-forward evaluation (2 temporal folds)
        4. Final training on all data
        5. Automatic save to disk

    Args:
        data: OHLCV DataFrame (columns open/high/low/close/volume;
            at least ~200 bars recommended).

    Returns:
        Dict of metrics: wf_metrics, label_dist, n_samples, trained_at…
        or a dict with an 'error' key on failure (never raises here).
    """
    # Dependency checks — fail soft with an 'error' dict, never raise.
    if not TORCH_AVAILABLE:
        return {'error': 'PyTorch non disponible — installer torch>=2.0.0'}
    if not SKLEARN_AVAILABLE:
        return {'error': 'scikit-learn non disponible'}

    logger.info(
        f"Début entraînement CNNImageStrategyModel pour "
        f"{self.symbol}/{self.timeframe} — seq_len={self.seq_len}"
    )

    # 1. Label generation (ATR-based, consistent with MLStrategyModel).
    gen = LabelGenerator(horizon=self.horizon)
    labels = gen.generate_atr_based(
        data,
        atr_tp_mult=self.tp_atr_mult,
        atr_sl_mult=self.sl_atr_mult,
    )

    # Re-encode {-1, 0, 1} → {0, 1, 2} as required by CrossEntropyLoss.
    y_enc = (labels + 1).values  # np.ndarray

    # 2. Image rendering (sliding windows).
    logger.info(f" Rendu des images (seq_len={self.seq_len})...")
    renderer = CandlestickImageRenderer()
    X = renderer.encode(data, seq_len=self.seq_len)  # (N, 3, 128, 128)

    # 3. Align X and y.
    # encode() yields N = len(data) - seq_len images; image i covers
    # data.iloc[i : i + seq_len], so its label is y_enc[i + seq_len - 1]
    # (last bar of the window).
    n_images = len(X)
    if n_images == 0:
        return {
            'error': f'Pas assez de données : {len(data)} barres < seq_len={self.seq_len}'
        }

    # Label index for each image: window i → index i + seq_len - 1.
    label_indices = np.arange(self.seq_len - 1, self.seq_len - 1 + n_images)
    # Drop the trailing bars whose labels are unreliable (truncated horizon).
    valid_mask = label_indices < (len(y_enc) - self.horizon)

    X = X[valid_mask]
    y = y_enc[label_indices[valid_mask]]

    n_samples = len(X)
    if n_samples < 50:
        return {'error': f'Trop peu d\'échantillons valides : {n_samples}'}

    logger.info(
        f" {n_samples} images valides — "
        f"LONG={(y==2).sum()}, SHORT={(y==0).sum()}, NEUTRAL={(y==1).sum()}"
    )

    # 4. Class weights, inversely proportional to class frequency.
    class_weights = self._compute_class_weights(y)

    # 5. Walk-forward evaluation (2 folds).
    wf_metrics = self._walk_forward_eval(X, y, class_weights, n_splits=2)

    # 6. Final fit on the full dataset.
    logger.info(" Entraînement final sur toutes les données...")
    self.model = CandlestickCNN(n_classes=3)
    self._train_model(self.model, X, y, class_weights)
    self.is_trained = True

    # Label distribution (decoded for readability).
    label_dist = {
        'long': int((y == 2).sum()),
        'short': int((y == 0).sum()),
        'neutral': int((y == 1).sum()),
    }

    self.metadata = {
        'symbol': self.symbol,
        'timeframe': self.timeframe,
        'seq_len': self.seq_len,
        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
        # consider datetime.now(timezone.utc) as used elsewhere in the project.
        'trained_at': datetime.utcnow().isoformat(),
        'n_samples': n_samples,
        'tp_atr_mult': self.tp_atr_mult,
        'sl_atr_mult': self.sl_atr_mult,
        'horizon': self.horizon,
        'label_dist': label_dist,
        'wf_metrics': wf_metrics,
        'mplfinance': MPLFINANCE_AVAILABLE,
    }

    # 7. Persist model + metadata.
    self.save()

    logger.info(
        f"Entraînement terminé. WF accuracy={wf_metrics.get('avg_accuracy', 0):.2%}"
    )
    return self.metadata
|
||||
# -------------------------------------------------------------------------
|
||||
# Prédiction
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def predict(self, data: pd.DataFrame) -> Dict:
    """Predict the trading signal for the most recent window.

    Args:
        data: Recent OHLCV DataFrame (at least seq_len bars).

    Returns:
        Dict: {
            'signal': 1 (LONG) / -1 (SHORT) / 0 (NEUTRAL),
            'confidence': float [0..1] — max(P_LONG, P_SHORT),
            'probas': {'long', 'short', 'neutral'},
            'tradeable': bool — confidence >= min_confidence and signal != 0,
        }
        On failure, 'error' is added and signal/confidence are zeroed.
    """
    no_signal = {'signal': 0, 'confidence': 0.0, 'tradeable': False}

    if not TORCH_AVAILABLE:
        return {**no_signal, 'error': 'PyTorch non disponible'}
    if not self.is_trained or self.model is None:
        return {**no_signal, 'error': 'Modèle non entraîné'}

    # Render the last seq_len candles as a single (1, 3, 128, 128) image.
    try:
        last_image = CandlestickImageRenderer().encode_last(data, seq_len=self.seq_len)
    except ValueError as e:
        return {**no_signal, 'error': str(e)}

    # Single forward pass; probabilities ordered [P_SHORT, P_NEUTRAL, P_LONG].
    image_tensor = torch.from_numpy(last_image).float()
    class_probs = self.model.predict_proba(image_tensor)[0]

    predicted_class = int(np.argmax(class_probs))
    signal = self.CLASS_MAP[predicted_class]  # decode: 0→-1, 1→0, 2→1

    probas = {
        'short': float(class_probs[0]),
        'neutral': float(class_probs[1]),
        'long': float(class_probs[2]),
    }
    confidence = float(max(probas['long'], probas['short']))

    return {
        'signal': signal,
        'confidence': confidence,
        'probas': probas,
        'tradeable': confidence >= self.min_confidence and signal != 0,
    }
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Sauvegarde / Chargement
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def save(self) -> Path:
    """Persist the trained model and its metadata to disk.

    Layout:
        {symbol}_{timeframe}.pt        — PyTorch checkpoint
                                         (state_dict + config + metadata)
        {symbol}_{timeframe}_meta.json — human-readable JSON metadata

    Returns:
        Path to the saved .pt file.

    Raises:
        RuntimeError: if the model is not trained or PyTorch is unavailable.
    """
    if not TORCH_AVAILABLE:
        raise RuntimeError("PyTorch non disponible")
    if not self.is_trained or self.model is None:
        raise RuntimeError("Modèle non entraîné — appeler train() avant save()")

    model_id = f"{self.symbol}_{self.timeframe}"
    model_path = self.MODELS_DIR / f"{model_id}.pt"
    meta_path = self.MODELS_DIR / f"{model_id}_meta.json"

    # Save state_dict + full config so load() can rebuild the model
    # without re-training.
    torch.save(
        {
            'state_dict': self.model.state_dict(),
            'config': {
                'symbol': self.symbol,
                'timeframe': self.timeframe,
                'seq_len': self.seq_len,
                'tp_atr_mult': self.tp_atr_mult,
                'sl_atr_mult': self.sl_atr_mult,
                'horizon': self.horizon,
                'min_confidence': self.min_confidence,
                'epochs': self.epochs,
                'batch_size': self.batch_size,
                'lr': self.lr,
                'patience': self.patience,
            },
            'metadata': self.metadata,
        },
        model_path
    )

    # Sidecar JSON makes metadata inspectable without loading torch.
    with open(meta_path, 'w') as f:
        json.dump(self.metadata, f, indent=2, default=str)

    logger.info(f"Modèle CNN image sauvegardé : {model_path}")
    return model_path
|
||||
|
||||
@classmethod
def load(cls, symbol: str, timeframe: str) -> 'CNNImageStrategyModel':
    """Load a previously trained model from disk.

    Rebuilds the CandlestickCNN from the saved config, loads the
    state_dict, and returns an eval-mode instance ready for predict().

    Args:
        symbol: Pair (e.g. 'EURUSD').
        timeframe: Timeframe (e.g. '1h').

    Returns:
        CNNImageStrategyModel instance ready to predict.

    Raises:
        RuntimeError: if PyTorch is unavailable.
        FileNotFoundError: if no saved model exists for this pair/timeframe.
    """
    if not TORCH_AVAILABLE:
        raise RuntimeError("PyTorch non disponible")

    model_path = MODELS_DIR / f"{symbol}_{timeframe}.pt"

    if not model_path.exists():
        raise FileNotFoundError(f"Modèle non trouvé : {model_path}")

    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints produced by this application from a trusted path.
    checkpoint = torch.load(model_path, map_location='cpu')
    cfg = checkpoint.get('config', {})

    # Rebuild the instance from the saved hyperparameters, falling back
    # to the caller's arguments / library defaults when a key is missing.
    instance = cls(
        symbol = cfg.get('symbol', symbol),
        timeframe = cfg.get('timeframe', timeframe),
        seq_len = cfg.get('seq_len', 64),
        tp_atr_mult = cfg.get('tp_atr_mult', 2.0),
        sl_atr_mult = cfg.get('sl_atr_mult', 1.0),
        horizon = cfg.get('horizon', 30),
        min_confidence = cfg.get('min_confidence', 0.55),
        epochs = cfg.get('epochs', 50),
        batch_size = cfg.get('batch_size', 32),
        lr = cfg.get('lr', 1e-3),
        patience = cfg.get('patience', 7),
    )

    # Rebuild the network and restore the trained weights.
    instance.model = CandlestickCNN(n_classes=3)
    instance.model.load_state_dict(checkpoint['state_dict'])
    instance.model.eval()
    instance.is_trained = True
    instance.metadata = checkpoint.get('metadata', {})

    logger.info(f"Modèle CNN image chargé depuis {model_path}")
    return instance
|
||||
|
||||
@staticmethod
def list_trained_models() -> List[Dict]:
    """Return the list of trained CNN-image models available on disk.

    Scans MODELS_DIR for ``*_meta.json`` sidecar files and summarises
    each one. A corrupt or unreadable metadata file is skipped (with a
    debug log) rather than aborting the whole listing.

    Returns:
        List of dicts with symbol, timeframe, trained_at, n_samples,
        seq_len, wf_accuracy, and mplfinance availability flag.
    """
    if not MODELS_DIR.exists():
        return []

    models = []
    for f in MODELS_DIR.glob("*_meta.json"):
        try:
            with open(f) as fp:
                meta = json.load(fp)
            # File name: EURUSD_1h_meta.json → symbol=EURUSD, timeframe=1h.
            # rsplit keeps symbols containing underscores intact: the
            # timeframe is always the last '_'-separated token.
            stem = f.stem.replace('_meta', '')  # e.g. EURUSD_1h
            parts = stem.rsplit('_', 1)
            models.append({
                'symbol': meta.get('symbol', parts[0] if parts else '?'),
                'timeframe': meta.get('timeframe', parts[1] if len(parts) > 1 else '?'),
                'trained_at': meta.get('trained_at', '?'),
                'n_samples': meta.get('n_samples', 0),
                'seq_len': meta.get('seq_len', 64),
                'wf_accuracy': meta.get('wf_metrics', {}).get('avg_accuracy', 0),
                'mplfinance': meta.get('mplfinance', False),
            })
        except Exception as e:
            # Best-effort listing: skip the broken file but leave a trace
            # instead of silently swallowing the error.
            logger.debug(f"Métadonnées illisibles, fichier ignoré : {f} ({e})")

    return models
|
||||
|
||||
def get_feature_importance(self) -> List[Dict]:
    """Return per-feature importances for this model.

    A convolutional network operating on raw pixels has no per-feature
    importance comparable to XGBoost's; an empty list is returned so the
    interface stays uniform across strategy models.

    Returns:
        Always [] — the CNN is a visual black box.
    """
    return []
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Entraînement interne
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def _train_model(
    self,
    model: 'CandlestickCNN',
    X: np.ndarray,
    y: np.ndarray,
    class_weights: 'torch.Tensor',
) -> None:
    """Train the CNN in place with early stopping on validation loss.

    Uses Adam + weighted CrossEntropyLoss over a temporal 80/20
    train/validation split, restores the best-validation weights, and
    leaves the model in eval() mode on CPU.

    Args:
        model: CandlestickCNN instance to train (mutated in place).
        X: Images (N, 3, 128, 128) float32.
        y: Encoded labels (N,) int, values in {0, 1, 2}.
        class_weights: (3,) weights for CrossEntropyLoss.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    class_weights = class_weights.to(device)

    # Tensor conversion (batches are moved to `device` inside the loops).
    X_t = torch.from_numpy(X).float()
    y_t = torch.from_numpy(y.astype(np.int64))

    # Temporal 80/20 train/validation split — no shuffling across the cut.
    n_train = int(len(X_t) * 0.8)
    X_tr, X_val = X_t[:n_train], X_t[n_train:]
    y_tr, y_val = y_t[:n_train], y_t[n_train:]

    # DataLoaders. NOTE(review): shuffle=False also on the *training*
    # loader — presumably deliberate to keep batches temporally ordered;
    # confirm, since shuffling usually helps SGD convergence.
    train_ds = TensorDataset(X_tr, y_tr)
    val_ds = TensorDataset(X_val, y_val)
    train_dl = DataLoader(train_ds, batch_size=self.batch_size, shuffle=False)
    val_dl = DataLoader(val_ds, batch_size=self.batch_size, shuffle=False)

    # Optimizer and class-weighted loss.
    optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    best_val_loss = float('inf')
    patience_count = 0
    best_state = None  # CPU copy of the best state_dict seen so far

    for epoch in range(self.epochs):
        # --- Training phase ---
        model.train()
        train_loss = 0.0
        for X_batch, y_batch in train_dl:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)
            optimizer.zero_grad()
            logits = model(X_batch)
            loss = criterion(logits, y_batch)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * len(X_batch)
        train_loss /= max(len(train_ds), 1)  # guard against empty dataset

        # --- Validation phase ---
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for X_batch, y_batch in val_dl:
                X_batch = X_batch.to(device)
                y_batch = y_batch.to(device)
                logits = model(X_batch)
                loss = criterion(logits, y_batch)
                val_loss += loss.item() * len(X_batch)
        val_loss /= max(len(val_ds), 1)

        logger.debug(
            f" Epoch {epoch+1}/{self.epochs} — "
            f"train_loss={train_loss:.4f}, val_loss={val_loss:.4f}"
        )

        # Early stopping: keep a CPU snapshot of the best weights so far.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_count = 0
            best_state = {
                k: v.cpu().clone() for k, v in model.state_dict().items()
            }
        else:
            patience_count += 1
            if patience_count >= self.patience:
                logger.info(
                    f" Early stopping à l'époque {epoch+1} "
                    f"(patience={self.patience}, best_val_loss={best_val_loss:.4f})"
                )
                break

    # Restore the best weights found during training.
    if best_state is not None:
        model.load_state_dict(best_state)

    model.eval()
    model.to('cpu')
|
||||
|
||||
def _walk_forward_eval(
    self,
    X: np.ndarray,
    y: np.ndarray,
    class_weights: 'torch.Tensor',
    n_splits: int = 2,
) -> Dict:
    """Evaluate the model with temporal cross-validation (walk-forward).

    Each fold trains a fresh throwaway CandlestickCNN on the past
    segment and scores it on the following one (sklearn TimeSeriesSplit).

    Args:
        X: Images (N, 3, 128, 128).
        y: Encoded labels (N,).
        class_weights: Class weights for CrossEntropyLoss.
        n_splits: Number of temporal folds.

    Returns:
        Dict: avg_accuracy, avg_precision, fold_accuracies.
    """
    tscv = TimeSeriesSplit(n_splits=n_splits)
    accuracies = []
    precisions = []

    for fold, (train_idx, test_idx) in enumerate(tscv.split(X)):
        X_tr, X_te = X[train_idx], X[test_idx]
        y_tr, y_te = y[train_idx], y[test_idx]

        # Fresh model per fold — no weight leakage between folds.
        fold_model = CandlestickCNN(n_classes=3)
        self._train_model(fold_model, X_tr, y_tr, class_weights)

        # Score on the held-out future segment.
        x_tensor = torch.from_numpy(X_te).float()
        proba_arr = fold_model.predict_proba(x_tensor)  # (N_te, 3)
        y_pred = np.argmax(proba_arr, axis=1)

        acc = float((y_pred == y_te).mean())
        accuracies.append(acc)

        # Macro precision restricted to directional rows.
        # NOTE(review): the OR keeps rows where *either* the truth or the
        # prediction is directional (LONG=2 / SHORT=0); an AND would drop
        # NEUTRAL=1 on both sides — confirm which behavior was intended.
        mask = (y_te != 1) | (y_pred != 1)
        if mask.sum() > 0 and SKLEARN_AVAILABLE:
            prec, _, _, _ = precision_recall_fscore_support(
                y_te[mask], y_pred[mask],
                average='macro', zero_division=0
            )
            precisions.append(float(prec))
        else:
            precisions.append(0.0)

        logger.info(
            f" Fold {fold+1}/{n_splits} : acc={acc:.2%}, "
            f"prec={precisions[-1]:.2%}"
        )

    return {
        'avg_accuracy': float(np.mean(accuracies)) if accuracies else 0.0,
        'avg_precision': float(np.mean(precisions)) if precisions else 0.0,
        'fold_accuracies': [float(a) for a in accuracies],
    }
|
||||
|
||||
@staticmethod
|
||||
def _compute_class_weights(y: np.ndarray) -> 'torch.Tensor':
|
||||
"""
|
||||
Calcule les poids de classes inversement proportionnels à leur fréquence.
|
||||
|
||||
Permet de compenser les déséquilibres (ex: beaucoup de NEUTRAL, peu de trades).
|
||||
|
||||
Args:
|
||||
y: Labels encodés {0, 1, 2}
|
||||
|
||||
Returns:
|
||||
torch.Tensor de forme (3,) — poids pour CrossEntropyLoss
|
||||
"""
|
||||
weights = np.ones(3, dtype=np.float32)
|
||||
n_total = len(y)
|
||||
|
||||
if n_total == 0:
|
||||
return torch.from_numpy(weights)
|
||||
|
||||
for cls in range(3):
|
||||
count = (y == cls).sum()
|
||||
if count > 0:
|
||||
# Poids inversement proportionnel : classes rares → poids plus élevé
|
||||
weights[cls] = n_total / (3.0 * count)
|
||||
|
||||
# Normalisation pour que la somme des poids = 3
|
||||
weights = weights / weights.mean()
|
||||
|
||||
return torch.from_numpy(weights)
|
||||
@@ -34,9 +34,10 @@ class EnsembleModel:
|
||||
"""
|
||||
|
||||
DEFAULT_WEIGHTS = {
|
||||
'xgboost': 0.40,
|
||||
'cnn': 0.60,
|
||||
'rl': 0.00, # Réservé Phase 4d
|
||||
'xgboost': 0.30,
|
||||
'cnn': 0.30,
|
||||
'cnn_image': 0.40, # CNN Vision — favorisé (patterns visuels bruts)
|
||||
'rl': 0.00, # Réservé Phase 4d
|
||||
}
|
||||
|
||||
def __init__(
|
||||
@@ -52,6 +53,9 @@ class EnsembleModel:
|
||||
# Modèles attachés (duck typing : .predict(df), .is_trained)
|
||||
self._models: Dict[str, Any] = {}
|
||||
|
||||
# Référence directe au modèle CNN Image (accès rapide pour auto-attach)
|
||||
self._cnn_image_model = None
|
||||
|
||||
logger.info(
|
||||
f"EnsembleModel initialisé — poids={self.weights}, "
|
||||
f"seuil={self.min_confidence}, accord_requis={self.require_agreement}"
|
||||
@@ -68,6 +72,11 @@ class EnsembleModel:
|
||||
"""Attache un CNNStrategyModel entraîné."""
|
||||
self._attach('cnn', model)
|
||||
|
||||
def attach_cnn_image(self, model) -> None:
|
||||
"""Attache un CNNImageStrategyModel entraîné (Phase 4c-bis)."""
|
||||
self._cnn_image_model = model
|
||||
self._attach('cnn_image', model)
|
||||
|
||||
def attach_rl(self, model) -> None:
|
||||
"""Attache un agent RL (Phase 4d)."""
|
||||
self._attach('rl', model)
|
||||
@@ -198,12 +207,13 @@ class EnsembleModel:
|
||||
# Statut et configuration
|
||||
# ------------------------------------------------------------------
|
||||
def is_ready(self) -> bool:
|
||||
"""True si au moins 2 modèles sont attachés et entraînés."""
|
||||
"""True si au moins 2 modèles parmi {xgboost, cnn, cnn_image} sont attachés et entraînés."""
|
||||
eligible_names = {'xgboost', 'cnn', 'cnn_image'}
|
||||
trained = sum(
|
||||
1 for m in self._models.values()
|
||||
if m.is_trained and self.weights.get(
|
||||
next(k for k, v in self._models.items() if v is m), 0
|
||||
) > 0
|
||||
1 for name, model in self._models.items()
|
||||
if name in eligible_names
|
||||
and model.is_trained
|
||||
and self.weights.get(name, 0.0) > 0
|
||||
)
|
||||
return trained >= 2
|
||||
|
||||
|
||||
3
src/strategies/cnn_image_driven/__init__.py
Normal file
3
src/strategies/cnn_image_driven/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.strategies.cnn_image_driven.cnn_image_strategy import CNNImageDrivenStrategy
|
||||
|
||||
__all__ = ['CNNImageDrivenStrategy']
|
||||
258
src/strategies/cnn_image_driven/cnn_image_strategy.py
Normal file
258
src/strategies/cnn_image_driven/cnn_image_strategy.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""
|
||||
CNN Image-Driven Strategy — Stratégie pilotée par un CNN à vision par ordinateur.
|
||||
|
||||
Contrairement au CNN 1D qui analyse des séquences numériques brutes, cette
|
||||
stratégie convertit les bougies OHLCV en images de graphiques (rendu mplfinance
|
||||
128×128 RGB), puis utilise un réseau Conv2D pour reconnaître les patterns
|
||||
visuels — exactement comme un trader devant TradingView.
|
||||
|
||||
Patterns détectés sans feature engineering explicite :
|
||||
- Bougies marteau, étoile filante, doji
|
||||
- Double top, double bottom
|
||||
- Rebonds sur supports/résistances visuels
|
||||
- Momentum : grande bougie pleine après consolidation
|
||||
- Divergences prix/volume visibles
|
||||
|
||||
Fonctionnement :
|
||||
1. Le modèle CNN-Image est chargé depuis le disque (entraîné via POST /trading/train-cnn-image)
|
||||
2. À chaque barre, les seq_len dernières bougies sont rendues en image 128×128
|
||||
3. Le modèle prédit LONG / SHORT / NEUTRAL avec un score de confiance
|
||||
4. Si confidence >= min_confidence, un signal est émis avec SL/TP basés sur ATR
|
||||
|
||||
Intégration :
|
||||
- Compatible avec StrategyEngine (même interface que CNNDrivenStrategy)
|
||||
- Chargé automatiquement si un modèle entraîné existe pour le symbole/timeframe
|
||||
- Requiert PyTorch + mplfinance + Pillow
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from src.strategies.base_strategy import BaseStrategy, Signal, StrategyConfig
|
||||
|
||||
try:
|
||||
from src.ml.cnn_image import CNNImageStrategyModel
|
||||
CNN_IMAGE_AVAILABLE = True
|
||||
except ImportError:
|
||||
CNN_IMAGE_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CNNImageDrivenStrategy(BaseStrategy):
    """
    Trading strategy driven by a pre-trained Conv2D CNN over chart images.

    The model learns candlestick patterns directly from 128×128 RGB
    renderings, with no pre-computed features:
    - Double bottom / double top
    - Candlestick configurations (hammer, doji, shooting star)
    - Bounces on visible support/resistance zones
    - Multi-bar momentum structures
    - Visual volume/price divergences

    Args:
        config: Configuration dict (timeframe, risk_per_trade, symbol, etc.)

    Additional (optional) config keys:
        min_confidence: Minimum confidence threshold [0..1] (default: 0.55)
        tp_atr_mult: ATR multiplier for TP (default: 2.0)
        sl_atr_mult: ATR multiplier for SL (default: 1.0)
        seq_len: Number of candles per image (default: 64)
        auto_load: Automatically load an existing model (default: True)
    """

    STRATEGY_NAME = 'cnn_image_driven'

    def __init__(self, config: Dict):
        super().__init__(config)

        self.symbol = config.get('symbol', 'EURUSD')
        self.min_confidence = config.get('min_confidence', 0.55)
        self.tp_atr_mult = config.get('tp_atr_mult', 2.0)
        self.sl_atr_mult = config.get('sl_atr_mult', 1.0)
        self.seq_len = config.get('seq_len', 64)

        # Set by load_model() / attach_model(); None means "no signal".
        self.cnn_image_model: Optional['CNNImageStrategyModel'] = None

        if not CNN_IMAGE_AVAILABLE:
            logger.warning("CNN Image non disponible (PyTorch/mplfinance requis)")
            return

        # Attempt a silent automatic load of an existing trained model.
        if config.get('auto_load', True):
            self._try_load_model()

    # -------------------------------------------------------------------------
    # BaseStrategy interface
    # -------------------------------------------------------------------------
    def analyze(self, market_data: pd.DataFrame) -> Optional[Signal]:
        """
        Generate a trading signal via the CNN Image model.

        Args:
            market_data: OHLCV DataFrame (at least seq_len bars)

        Returns:
            A Signal when the model is confident, None otherwise.
        """
        if not CNN_IMAGE_AVAILABLE:
            logger.debug("CNN Image Strategy : PyTorch/mplfinance non disponible, aucun signal")
            return None

        if self.cnn_image_model is None or not self.cnn_image_model.is_trained:
            logger.debug("CNN Image Strategy : modèle non chargé, aucun signal")
            return None

        if len(market_data) < self.seq_len:
            return None

        try:
            result = self.cnn_image_model.predict(market_data)
        except Exception as e:
            # Best effort: a failed prediction must never crash the engine.
            logger.warning(f"CNN Image Strategy predict error : {e}")
            return None

        if not result.get('tradeable', False):
            return None

        signal_dir = result['signal']  # 1 = LONG, -1 = SHORT
        confidence = result['confidence']

        # Price and ATR for SL/TP placement.
        last_close = float(market_data['close'].iloc[-1])
        atr = self._compute_atr(market_data)
        if atr <= 0:
            return None

        if signal_dir == 1:
            direction = 'LONG'
            stop_loss = last_close - self.sl_atr_mult * atr
            take_profit = last_close + self.tp_atr_mult * atr
        elif signal_dir == -1:
            direction = 'SHORT'
            stop_loss = last_close + self.sl_atr_mult * atr
            take_profit = last_close - self.tp_atr_mult * atr
        else:
            # Defensive: 'tradeable' should already exclude NEUTRAL (0).
            return None

        signal = Signal(
            symbol = self.symbol,
            direction = direction,
            entry_price = last_close,
            stop_loss = stop_loss,
            take_profit = take_profit,
            confidence = confidence,
            timestamp = datetime.now(timezone.utc),
            strategy = self.STRATEGY_NAME,
            metadata = {
                'probas': result.get('probas', {}),
                'seq_len': self.seq_len,
                'atr': atr,
                'tp_atr_mult': self.tp_atr_mult,
                'sl_atr_mult': self.sl_atr_mult,
            },
        )

        logger.info(
            f"CNN Image Signal : {direction} {self.symbol} | "
            f"entry={last_close:.5f} SL={stop_loss:.5f} TP={take_profit:.5f} | "
            f"confidence={confidence:.2%}"
        )
        return signal

    def calculate_indicators(self, data: pd.DataFrame) -> pd.DataFrame:
        """Return the data unchanged — the CNN Image model works on rendered images."""
        return data

    def update_params(self, params: Dict) -> None:
        """Dynamically update parameters (from the API or Optuna)."""
        if 'min_confidence' in params:
            self.min_confidence = params['min_confidence']
            # Propagate to the model so predict()'s 'tradeable' flag agrees.
            if self.cnn_image_model:
                self.cnn_image_model.min_confidence = params['min_confidence']
        if 'tp_atr_mult' in params:
            self.tp_atr_mult = params['tp_atr_mult']
        if 'sl_atr_mult' in params:
            self.sl_atr_mult = params['sl_atr_mult']
        if 'seq_len' in params:
            self.seq_len = params['seq_len']
        logger.info(f"CNN Image Strategy params mis à jour : {params}")

    # -------------------------------------------------------------------------
    # Model management
    # -------------------------------------------------------------------------
    def load_model(self, symbol: Optional[str] = None, timeframe: Optional[str] = None) -> bool:
        """
        Load a CNN Image model from disk.

        Args:
            symbol: Pair (default: self.symbol)
            timeframe: Timeframe (default: self.config.timeframe)

        Returns:
            True on successful load, False otherwise (never raises).
        """
        if not CNN_IMAGE_AVAILABLE:
            logger.warning("CNN Image non disponible (PyTorch/mplfinance requis)")
            return False

        sym = symbol or self.symbol
        tf = timeframe or self.config.timeframe
        try:
            self.cnn_image_model = CNNImageStrategyModel.load(sym, tf)
            logger.info(f"Modèle CNN Image chargé : {sym}/{tf}")
            return True
        except FileNotFoundError:
            # Expected when no model has been trained yet — not an error.
            logger.info(f"Aucun modèle CNN Image trouvé pour {sym}/{tf}")
            return False
        except Exception as e:
            logger.error(f"Erreur chargement modèle CNN Image : {e}")
            return False

    def attach_model(self, model: 'CNNImageStrategyModel') -> None:
        """Directly attach a CNN Image model (after training via the API)."""
        self.cnn_image_model = model
        # Keep the strategy's symbol in sync with the attached model.
        self.symbol = model.symbol
        logger.info(f"Modèle CNN Image attaché : {model.symbol}/{model.timeframe}")

    def is_ready(self) -> bool:
        """Return True when the CNN Image model is loaded and trained."""
        if not CNN_IMAGE_AVAILABLE:
            return False
        return self.cnn_image_model is not None and self.cnn_image_model.is_trained

    def get_model_info(self) -> Dict:
        """Return the metadata of the active CNN Image model."""
        if not CNN_IMAGE_AVAILABLE:
            return {'status': 'PyTorch/mplfinance non disponible'}
        if not self.is_ready():
            return {'status': 'non entraîné'}
        meta = self.cnn_image_model.metadata.copy()
        meta['is_ready'] = True
        meta['seq_len'] = self.seq_len
        return meta

    # -------------------------------------------------------------------------
    # Helpers
    # -------------------------------------------------------------------------
    def _try_load_model(self) -> None:
        """Attempt a silent model load at startup (failures are ignored)."""
        try:
            self.load_model()
        except Exception:
            pass

    @staticmethod
    def _compute_atr(df: pd.DataFrame, period: int = 14) -> float:
        """Compute the average ATR over the last `period` bars.

        Falls back to the last bar's high-low range when there is not
        enough history or the rolling mean is NaN.
        """
        if len(df) < period + 1:
            return float(df['high'].iloc[-1] - df['low'].iloc[-1])
        h, l, pc = df['high'], df['low'], df['close'].shift(1)
        tr = pd.concat([h - l, (h - pc).abs(), (l - pc).abs()], axis=1).max(axis=1)
        atr = tr.rolling(period).mean().iloc[-1]
        return float(atr) if not np.isnan(atr) else float(df['high'].iloc[-1] - df['low'].iloc[-1])
|
||||
Reference in New Issue
Block a user