auto-trade skeleton added
All checks were successful
SonarQube Scan / SonarQube Trigger (push) Successful in 4m10s
All checks were successful
SonarQube Scan / SonarQube Trigger (push) Successful in 4m10s
This commit is contained in:
447
auto_trade.py
Normal file
447
auto_trade.py
Normal file
@ -0,0 +1,447 @@
|
||||
from typing import Dict, List
|
||||
import logging
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from pymongo import MongoClient
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
import pytz
|
||||
|
||||
|
||||
class CryptoTradingStrategy:
|
||||
def __init__(
|
||||
self,
|
||||
sma_short: int = 20,
|
||||
sma_long: int = 50,
|
||||
rsi_base_period: int = 14,
|
||||
rsi_base_overbought: float = 70,
|
||||
rsi_base_oversold: float = 30,
|
||||
volume_threshold: float = 1.5,
|
||||
stop_loss: float = 0.02,
|
||||
take_profit: float = 0.035,
|
||||
):
|
||||
self.sma_short = sma_short
|
||||
self.sma_long = sma_long
|
||||
self.rsi_base_period = rsi_base_period
|
||||
self.rsi_base_overbought = rsi_base_overbought
|
||||
self.rsi_base_oversold = rsi_base_oversold
|
||||
self.volume_threshold = volume_threshold
|
||||
self.stop_loss = stop_loss
|
||||
self.take_profit = take_profit
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def calculate_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Calculate technical indicators."""
|
||||
# Calculate SMAs
|
||||
df["SMA_short"] = df["close"].rolling(window=self.sma_short).mean()
|
||||
df["SMA_long"] = df["close"].rolling(window=self.sma_long).mean()
|
||||
|
||||
# Calculate ATR (Average True Range) for volatility
|
||||
df["TR"] = (df["high"] - df["low"]).abs() # True Range = High - Low
|
||||
df["ATR"] = df["TR"].rolling(window=14).mean() # 14-period ATR
|
||||
|
||||
# Adaptive RSI period based on volatility
|
||||
# Adaptive RSI period based on volatility
|
||||
df["RSI_period"] = self.rsi_base_period * (
|
||||
1 + (df["ATR"] / df["close"].rolling(window=20).mean())
|
||||
)
|
||||
|
||||
# Replace NaN or infinite values with a default RSI period
|
||||
df["RSI_period"].fillna(self.rsi_base_period, inplace=True)
|
||||
df["RSI_period"].replace([np.inf, -np.inf], self.rsi_base_period, inplace=True)
|
||||
|
||||
# Convert to integer
|
||||
df["RSI_period"] = df["RSI_period"].astype(int)
|
||||
# Calculate adaptive RSI
|
||||
df["RSI"] = 0.0
|
||||
for period in df["RSI_period"].unique():
|
||||
mask = df["RSI_period"] == period
|
||||
delta = df.loc[mask, "close"].diff()
|
||||
gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
|
||||
loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
|
||||
rs = gain / loss
|
||||
df.loc[mask, "RSI"] = 100 - (100 / (1 + rs))
|
||||
|
||||
# Adaptive RSI thresholds based on Bollinger Band width
|
||||
df["BB_width"] = (df["close"].rolling(window=20).std() * 2) / df[
|
||||
"close"
|
||||
].rolling(window=20).mean()
|
||||
df["RSI_overbought"] = self.rsi_base_overbought + (df["BB_width"] * 20)
|
||||
df["RSI_oversold"] = self.rsi_base_oversold - (df["BB_width"] * 20)
|
||||
|
||||
# Volume analysis
|
||||
df["Volume_MA"] = df["volume"].rolling(window=20).mean()
|
||||
df["Volume_Ratio"] = df["volume"] / df["Volume_MA"]
|
||||
|
||||
return df
|
||||
|
||||
def generate_signals(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Generate trading signals based on indicators."""
|
||||
df["signal"] = 0 # 0: hold, 1: buy, -1: sell
|
||||
|
||||
# Generate buy signals
|
||||
buy_conditions = (
|
||||
(df["SMA_short"] > df["SMA_long"]) # Golden cross
|
||||
& (df["RSI"] < df["RSI_oversold"]) # Adaptive oversold condition
|
||||
& (df["Volume_Ratio"] > self.volume_threshold) # High volume
|
||||
)
|
||||
df.loc[buy_conditions, "signal"] = 1
|
||||
|
||||
# Generate sell signals
|
||||
sell_conditions = (df["SMA_short"] < df["SMA_long"]) | ( # Death cross
|
||||
df["RSI"] > df["RSI_overbought"]
|
||||
) # Adaptive overbought condition
|
||||
df.loc[sell_conditions, "signal"] = -1
|
||||
|
||||
return df
|
||||
|
||||
def apply_risk_management(self, df: pd.DataFrame, position: Dict) -> Dict:
|
||||
"""Apply risk management rules to current position."""
|
||||
current_price = df["close"].iloc[-1]
|
||||
|
||||
if position["in_position"]:
|
||||
# Check stop loss
|
||||
if current_price <= position["entry_price"] * (1 - self.stop_loss):
|
||||
position["should_exit"] = True
|
||||
position["exit_reason"] = "stop_loss"
|
||||
|
||||
# Check take profit
|
||||
elif current_price >= position["entry_price"] * (1 + self.take_profit):
|
||||
position["should_exit"] = True
|
||||
position["exit_reason"] = "take_profit"
|
||||
|
||||
return position
|
||||
|
||||
def execute_trades(self, df: pd.DataFrame) -> List[Dict]:
|
||||
"""Execute trading strategy and maintain positions."""
|
||||
trades = []
|
||||
position = {
|
||||
"in_position": False,
|
||||
"entry_price": 0,
|
||||
"entry_time": None,
|
||||
"should_exit": False,
|
||||
"exit_reason": None,
|
||||
}
|
||||
|
||||
for i in range(len(df)):
|
||||
current_data = df.iloc[i]
|
||||
|
||||
# Update position status
|
||||
if position["in_position"]:
|
||||
position = self.apply_risk_management(
|
||||
df.iloc[max(0, i - 10) : i + 1], position
|
||||
)
|
||||
|
||||
# Exit position if necessary
|
||||
if position["in_position"] and (
|
||||
position["should_exit"] or current_data["signal"] == -1
|
||||
):
|
||||
trade = {
|
||||
"exit_time": current_data.name,
|
||||
"exit_price": current_data["close"],
|
||||
"exit_reason": position["exit_reason"] or "signal",
|
||||
"profit_pct": (current_data["close"] - position["entry_price"])
|
||||
/ position["entry_price"]
|
||||
* 100,
|
||||
}
|
||||
trades.append({**position, **trade})
|
||||
position = {
|
||||
"in_position": False,
|
||||
"entry_price": 0,
|
||||
"entry_time": None,
|
||||
"should_exit": False,
|
||||
"exit_reason": None,
|
||||
}
|
||||
|
||||
# Enter new position
|
||||
elif not position["in_position"] and current_data["signal"] == 1:
|
||||
position = {
|
||||
"in_position": True,
|
||||
"entry_price": current_data["close"],
|
||||
"entry_time": current_data.name,
|
||||
"should_exit": False,
|
||||
"exit_reason": None,
|
||||
}
|
||||
|
||||
return trades
|
||||
|
||||
def run_strategy(self, df: pd.DataFrame) -> tuple:
|
||||
"""Run the complete trading strategy."""
|
||||
# Prepare data
|
||||
df = self.calculate_indicators(df)
|
||||
df = self.generate_signals(df)
|
||||
|
||||
# Execute trades
|
||||
trades = self.execute_trades(df)
|
||||
|
||||
# Calculate performance metrics
|
||||
total_trades = len(trades)
|
||||
winning_trades = len([t for t in trades if t["profit_pct"] > 0])
|
||||
total_return = sum(t["profit_pct"] for t in trades)
|
||||
|
||||
metrics = {
|
||||
"total_trades": total_trades,
|
||||
"winning_trades": winning_trades,
|
||||
"win_rate": winning_trades / total_trades if total_trades > 0 else 0,
|
||||
"total_return": total_return,
|
||||
"average_return": total_return / total_trades if total_trades > 0 else 0,
|
||||
}
|
||||
|
||||
return trades, metrics
|
||||
|
||||
|
||||
class LiveCryptoTradingStrategy(CryptoTradingStrategy):
|
||||
def __init__(
|
||||
self,
|
||||
mongodb_uri: str,
|
||||
db_name: str,
|
||||
initial_balance: float = 10000,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.client = MongoClient(mongodb_uri)
|
||||
self.db = self.client[db_name]
|
||||
self.position_collection = self.db["positions"]
|
||||
self.trade_collection = self.db["trades"]
|
||||
self.data_collection = self.db["market_data"]
|
||||
self.asset_collection = self.db["assets"]
|
||||
|
||||
# Initialize assets if not exists
|
||||
if not self.asset_collection.find_one({"asset_type": "USD"}):
|
||||
self.asset_collection.insert_one(
|
||||
{
|
||||
"asset_type": "USD",
|
||||
"amount": initial_balance,
|
||||
"last_updated": datetime.now(pytz.UTC),
|
||||
}
|
||||
)
|
||||
|
||||
def get_balance(self, asset_type: str = "USD") -> float:
|
||||
"""Get current balance of specified asset"""
|
||||
asset_doc = self.asset_collection.find_one({"asset_type": asset_type})
|
||||
return asset_doc["amount"] if asset_doc else 0
|
||||
|
||||
def update_balance(self, asset_type: str, amount: float):
|
||||
"""Update balance of specified asset"""
|
||||
self.asset_collection.update_one(
|
||||
{"asset_type": asset_type},
|
||||
{"$set": {"amount": amount, "last_updated": datetime.now(pytz.UTC)}},
|
||||
upsert=True,
|
||||
)
|
||||
|
||||
def get_latest_data(self, symbol: str, lookback_minutes: int = 200) -> pd.DataFrame:
|
||||
"""Fetch latest data from MongoDB"""
|
||||
end_time = datetime.now(pytz.UTC)
|
||||
start_time = end_time - timedelta(minutes=lookback_minutes)
|
||||
|
||||
cursor = self.data_collection.find(
|
||||
{"symbol": symbol, "timestamp": {"$gte": start_time, "$lte": end_time}}
|
||||
).sort("timestamp", 1)
|
||||
|
||||
data = list(cursor)
|
||||
if not data:
|
||||
return None
|
||||
|
||||
df = pd.DataFrame(data)
|
||||
df.set_index("timestamp", inplace=True)
|
||||
return df
|
||||
|
||||
def execute_trade(self, symbol: str, side: str, amount: float, price: float):
|
||||
"""Execute trade and update balances"""
|
||||
base_asset = symbol[:-4] # BTCUSDT -> BTC
|
||||
quote_asset = symbol[-4:] # BTCUSDT -> USDT
|
||||
|
||||
if side == "buy":
|
||||
# Check if enough quote asset (USDT) available
|
||||
quote_balance = self.get_balance(quote_asset)
|
||||
cost = amount * price
|
||||
|
||||
if quote_balance >= cost:
|
||||
# Update quote asset balance (decrease)
|
||||
self.update_balance(quote_asset, quote_balance - cost)
|
||||
# Update base asset balance (increase)
|
||||
base_balance = self.get_balance(base_asset)
|
||||
self.update_balance(base_asset, base_balance + amount)
|
||||
return True
|
||||
return False
|
||||
|
||||
elif side == "sell":
|
||||
# Check if enough base asset available
|
||||
base_balance = self.get_balance(base_asset)
|
||||
|
||||
if base_balance >= amount:
|
||||
# Update base asset balance (decrease)
|
||||
self.update_balance(base_asset, base_balance - amount)
|
||||
# Update quote asset balance (increase)
|
||||
quote_balance = self.get_balance(quote_asset)
|
||||
self.update_balance(quote_asset, quote_balance + (amount * price))
|
||||
return True
|
||||
return False
|
||||
|
||||
def update_position(self, symbol: str, position: Dict):
|
||||
"""Update position in MongoDB"""
|
||||
self.position_collection.update_one(
|
||||
{"symbol": symbol},
|
||||
{
|
||||
"$set": {
|
||||
"in_position": position["in_position"],
|
||||
"entry_price": position["entry_price"],
|
||||
"entry_time": position["entry_time"],
|
||||
"position_type": position.get("position_type", "long"),
|
||||
"last_updated": datetime.now(pytz.UTC),
|
||||
}
|
||||
},
|
||||
upsert=True,
|
||||
)
|
||||
|
||||
def log_trade(self, symbol: str, trade: Dict):
|
||||
"""Log completed trade to MongoDB"""
|
||||
trade_doc = {
|
||||
"symbol": symbol,
|
||||
"entry_time": trade["entry_time"],
|
||||
"exit_time": trade["exit_time"],
|
||||
"entry_price": trade["entry_price"],
|
||||
"exit_price": trade["exit_price"],
|
||||
"position_type": trade.get("position_type", "long"),
|
||||
"profit_pct": trade["profit_pct"],
|
||||
"exit_reason": trade["exit_reason"],
|
||||
}
|
||||
self.trade_collection.insert_one(trade_doc)
|
||||
|
||||
def run_live_strategy(
|
||||
self, symbol: str, trade_amount: float = 0.1, interval_seconds: int = 60
|
||||
):
|
||||
"""Run strategy in live mode"""
|
||||
self.logger.info(f"Starting live trading for {symbol}")
|
||||
|
||||
# Get or initialize position
|
||||
position_doc = self.position_collection.find_one({"symbol": symbol})
|
||||
position = {
|
||||
"in_position": position_doc["in_position"] if position_doc else False,
|
||||
"entry_price": position_doc.get("entry_price", 0),
|
||||
"entry_time": position_doc.get("entry_time"),
|
||||
"should_exit": False,
|
||||
"exit_reason": None,
|
||||
"position_type": position_doc.get("position_type", "long"),
|
||||
}
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Get latest data
|
||||
df = self.get_latest_data(symbol)
|
||||
if df is None or len(df) < self.sma_long:
|
||||
self.logger.warning("Insufficient data, waiting...")
|
||||
time.sleep(interval_seconds)
|
||||
continue
|
||||
|
||||
# Calculate indicators and signals
|
||||
df = self.calculate_indicators(df)
|
||||
df = self.generate_signals(df)
|
||||
|
||||
current_data = df.iloc[-1]
|
||||
|
||||
# Update position status
|
||||
if position["in_position"]:
|
||||
position = self.apply_risk_management(df.iloc[-10:], position)
|
||||
|
||||
# Handle position exit
|
||||
if position["in_position"] and (
|
||||
position["should_exit"] or current_data["signal"] == -1
|
||||
):
|
||||
if self.execute_trade(
|
||||
symbol, "sell", trade_amount, current_data["close"]
|
||||
):
|
||||
# Log trade and update position as before
|
||||
trade = {
|
||||
"exit_time": current_data.name,
|
||||
"exit_price": current_data["close"],
|
||||
"exit_reason": position["exit_reason"] or "signal",
|
||||
"profit_pct": (
|
||||
current_data["close"] - position["entry_price"]
|
||||
)
|
||||
/ position["entry_price"]
|
||||
* 100,
|
||||
"position_type": position["position_type"],
|
||||
"amount": trade_amount,
|
||||
"usd_value": trade_amount * current_data["close"],
|
||||
}
|
||||
self.log_trade(symbol, {**position, **trade})
|
||||
|
||||
# Update position
|
||||
position = {
|
||||
"in_position": False,
|
||||
"entry_price": 0,
|
||||
"entry_time": None,
|
||||
"should_exit": False,
|
||||
"exit_reason": None,
|
||||
}
|
||||
self.update_position(symbol, position)
|
||||
|
||||
# Log balances
|
||||
self.logger.info(
|
||||
f"Exited position. New balances: "
|
||||
f"USD: {self.get_balance('USDT')}, "
|
||||
f"{symbol[:-4]}: {self.get_balance(symbol[:-4])}"
|
||||
)
|
||||
|
||||
# Handle position entry
|
||||
elif not position["in_position"] and current_data["signal"] == 1:
|
||||
if self.execute_trade(
|
||||
symbol, "buy", trade_amount, current_data["close"]
|
||||
):
|
||||
position = {
|
||||
"in_position": True,
|
||||
"entry_price": current_data["close"],
|
||||
"entry_time": current_data.name,
|
||||
"should_exit": False,
|
||||
"exit_reason": None,
|
||||
"position_type": "long",
|
||||
"amount": trade_amount,
|
||||
}
|
||||
self.update_position(symbol, position)
|
||||
|
||||
# Log balances
|
||||
self.logger.info(
|
||||
f"Entered position. New balances: "
|
||||
f"USD: {self.get_balance('USDT')}, "
|
||||
f"{symbol[:-4]}: {self.get_balance(symbol[:-4])}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in live trading loop: {str(e)}")
|
||||
|
||||
time.sleep(interval_seconds)
|
||||
|
||||
|
||||
def run_live_trading(mongodb_uri: str, symbols: List[str]):
|
||||
"""Run live trading for multiple symbols"""
|
||||
for symbol in symbols:
|
||||
strategy = LiveCryptoTradingStrategy(
|
||||
mongodb_uri=mongodb_uri,
|
||||
db_name="crypto_trading",
|
||||
sma_short=20,
|
||||
sma_long=100,
|
||||
rsi_base_period=21,
|
||||
rsi_base_overbought=70,
|
||||
rsi_base_oversold=30,
|
||||
volume_threshold=1.0,
|
||||
stop_loss=0.01,
|
||||
take_profit=0.020,
|
||||
)
|
||||
|
||||
# Run in separate thread for each symbol
|
||||
import threading
|
||||
|
||||
thread = threading.Thread(
|
||||
target=strategy.run_live_strategy, args=(symbol,), daemon=True
|
||||
)
|
||||
thread.start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mongodb_uri = "mongodb://localhost:27017/"
|
||||
symbols = ["BTCUSDT", "ETHUSDT"] # Add your symbols here
|
||||
run_live_trading(mongodb_uri, symbols)
|
Reference in New Issue
Block a user