auto-trade skeleton added
All checks were successful
SonarQube Scan / SonarQube Trigger (push) Successful in 4m10s

This commit is contained in:
dongho
2024-12-24 00:26:18 +09:00
parent 82e1b74aee
commit dae646cd2e

447
auto_trade.py Normal file
View File

@ -0,0 +1,447 @@
from typing import Dict, List
import logging
import pandas as pd
import numpy as np
from pymongo import MongoClient
import time
from datetime import datetime, timedelta
import pytz
class CryptoTradingStrategy:
def __init__(
self,
sma_short: int = 20,
sma_long: int = 50,
rsi_base_period: int = 14,
rsi_base_overbought: float = 70,
rsi_base_oversold: float = 30,
volume_threshold: float = 1.5,
stop_loss: float = 0.02,
take_profit: float = 0.035,
):
self.sma_short = sma_short
self.sma_long = sma_long
self.rsi_base_period = rsi_base_period
self.rsi_base_overbought = rsi_base_overbought
self.rsi_base_oversold = rsi_base_oversold
self.volume_threshold = volume_threshold
self.stop_loss = stop_loss
self.take_profit = take_profit
logging.basicConfig(level=logging.INFO)
self.logger = logging.getLogger(__name__)
def calculate_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
"""Calculate technical indicators."""
# Calculate SMAs
df["SMA_short"] = df["close"].rolling(window=self.sma_short).mean()
df["SMA_long"] = df["close"].rolling(window=self.sma_long).mean()
# Calculate ATR (Average True Range) for volatility
df["TR"] = (df["high"] - df["low"]).abs() # True Range = High - Low
df["ATR"] = df["TR"].rolling(window=14).mean() # 14-period ATR
# Adaptive RSI period based on volatility
# Adaptive RSI period based on volatility
df["RSI_period"] = self.rsi_base_period * (
1 + (df["ATR"] / df["close"].rolling(window=20).mean())
)
# Replace NaN or infinite values with a default RSI period
df["RSI_period"].fillna(self.rsi_base_period, inplace=True)
df["RSI_period"].replace([np.inf, -np.inf], self.rsi_base_period, inplace=True)
# Convert to integer
df["RSI_period"] = df["RSI_period"].astype(int)
# Calculate adaptive RSI
df["RSI"] = 0.0
for period in df["RSI_period"].unique():
mask = df["RSI_period"] == period
delta = df.loc[mask, "close"].diff()
gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
rs = gain / loss
df.loc[mask, "RSI"] = 100 - (100 / (1 + rs))
# Adaptive RSI thresholds based on Bollinger Band width
df["BB_width"] = (df["close"].rolling(window=20).std() * 2) / df[
"close"
].rolling(window=20).mean()
df["RSI_overbought"] = self.rsi_base_overbought + (df["BB_width"] * 20)
df["RSI_oversold"] = self.rsi_base_oversold - (df["BB_width"] * 20)
# Volume analysis
df["Volume_MA"] = df["volume"].rolling(window=20).mean()
df["Volume_Ratio"] = df["volume"] / df["Volume_MA"]
return df
def generate_signals(self, df: pd.DataFrame) -> pd.DataFrame:
"""Generate trading signals based on indicators."""
df["signal"] = 0 # 0: hold, 1: buy, -1: sell
# Generate buy signals
buy_conditions = (
(df["SMA_short"] > df["SMA_long"]) # Golden cross
& (df["RSI"] < df["RSI_oversold"]) # Adaptive oversold condition
& (df["Volume_Ratio"] > self.volume_threshold) # High volume
)
df.loc[buy_conditions, "signal"] = 1
# Generate sell signals
sell_conditions = (df["SMA_short"] < df["SMA_long"]) | ( # Death cross
df["RSI"] > df["RSI_overbought"]
) # Adaptive overbought condition
df.loc[sell_conditions, "signal"] = -1
return df
def apply_risk_management(self, df: pd.DataFrame, position: Dict) -> Dict:
"""Apply risk management rules to current position."""
current_price = df["close"].iloc[-1]
if position["in_position"]:
# Check stop loss
if current_price <= position["entry_price"] * (1 - self.stop_loss):
position["should_exit"] = True
position["exit_reason"] = "stop_loss"
# Check take profit
elif current_price >= position["entry_price"] * (1 + self.take_profit):
position["should_exit"] = True
position["exit_reason"] = "take_profit"
return position
def execute_trades(self, df: pd.DataFrame) -> List[Dict]:
"""Execute trading strategy and maintain positions."""
trades = []
position = {
"in_position": False,
"entry_price": 0,
"entry_time": None,
"should_exit": False,
"exit_reason": None,
}
for i in range(len(df)):
current_data = df.iloc[i]
# Update position status
if position["in_position"]:
position = self.apply_risk_management(
df.iloc[max(0, i - 10) : i + 1], position
)
# Exit position if necessary
if position["in_position"] and (
position["should_exit"] or current_data["signal"] == -1
):
trade = {
"exit_time": current_data.name,
"exit_price": current_data["close"],
"exit_reason": position["exit_reason"] or "signal",
"profit_pct": (current_data["close"] - position["entry_price"])
/ position["entry_price"]
* 100,
}
trades.append({**position, **trade})
position = {
"in_position": False,
"entry_price": 0,
"entry_time": None,
"should_exit": False,
"exit_reason": None,
}
# Enter new position
elif not position["in_position"] and current_data["signal"] == 1:
position = {
"in_position": True,
"entry_price": current_data["close"],
"entry_time": current_data.name,
"should_exit": False,
"exit_reason": None,
}
return trades
def run_strategy(self, df: pd.DataFrame) -> tuple:
"""Run the complete trading strategy."""
# Prepare data
df = self.calculate_indicators(df)
df = self.generate_signals(df)
# Execute trades
trades = self.execute_trades(df)
# Calculate performance metrics
total_trades = len(trades)
winning_trades = len([t for t in trades if t["profit_pct"] > 0])
total_return = sum(t["profit_pct"] for t in trades)
metrics = {
"total_trades": total_trades,
"winning_trades": winning_trades,
"win_rate": winning_trades / total_trades if total_trades > 0 else 0,
"total_return": total_return,
"average_return": total_return / total_trades if total_trades > 0 else 0,
}
return trades, metrics
class LiveCryptoTradingStrategy(CryptoTradingStrategy):
def __init__(
self,
mongodb_uri: str,
db_name: str,
initial_balance: float = 10000,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.client = MongoClient(mongodb_uri)
self.db = self.client[db_name]
self.position_collection = self.db["positions"]
self.trade_collection = self.db["trades"]
self.data_collection = self.db["market_data"]
self.asset_collection = self.db["assets"]
# Initialize assets if not exists
if not self.asset_collection.find_one({"asset_type": "USD"}):
self.asset_collection.insert_one(
{
"asset_type": "USD",
"amount": initial_balance,
"last_updated": datetime.now(pytz.UTC),
}
)
def get_balance(self, asset_type: str = "USD") -> float:
"""Get current balance of specified asset"""
asset_doc = self.asset_collection.find_one({"asset_type": asset_type})
return asset_doc["amount"] if asset_doc else 0
def update_balance(self, asset_type: str, amount: float):
"""Update balance of specified asset"""
self.asset_collection.update_one(
{"asset_type": asset_type},
{"$set": {"amount": amount, "last_updated": datetime.now(pytz.UTC)}},
upsert=True,
)
def get_latest_data(self, symbol: str, lookback_minutes: int = 200) -> pd.DataFrame:
"""Fetch latest data from MongoDB"""
end_time = datetime.now(pytz.UTC)
start_time = end_time - timedelta(minutes=lookback_minutes)
cursor = self.data_collection.find(
{"symbol": symbol, "timestamp": {"$gte": start_time, "$lte": end_time}}
).sort("timestamp", 1)
data = list(cursor)
if not data:
return None
df = pd.DataFrame(data)
df.set_index("timestamp", inplace=True)
return df
def execute_trade(self, symbol: str, side: str, amount: float, price: float):
"""Execute trade and update balances"""
base_asset = symbol[:-4] # BTCUSDT -> BTC
quote_asset = symbol[-4:] # BTCUSDT -> USDT
if side == "buy":
# Check if enough quote asset (USDT) available
quote_balance = self.get_balance(quote_asset)
cost = amount * price
if quote_balance >= cost:
# Update quote asset balance (decrease)
self.update_balance(quote_asset, quote_balance - cost)
# Update base asset balance (increase)
base_balance = self.get_balance(base_asset)
self.update_balance(base_asset, base_balance + amount)
return True
return False
elif side == "sell":
# Check if enough base asset available
base_balance = self.get_balance(base_asset)
if base_balance >= amount:
# Update base asset balance (decrease)
self.update_balance(base_asset, base_balance - amount)
# Update quote asset balance (increase)
quote_balance = self.get_balance(quote_asset)
self.update_balance(quote_asset, quote_balance + (amount * price))
return True
return False
def update_position(self, symbol: str, position: Dict):
"""Update position in MongoDB"""
self.position_collection.update_one(
{"symbol": symbol},
{
"$set": {
"in_position": position["in_position"],
"entry_price": position["entry_price"],
"entry_time": position["entry_time"],
"position_type": position.get("position_type", "long"),
"last_updated": datetime.now(pytz.UTC),
}
},
upsert=True,
)
def log_trade(self, symbol: str, trade: Dict):
"""Log completed trade to MongoDB"""
trade_doc = {
"symbol": symbol,
"entry_time": trade["entry_time"],
"exit_time": trade["exit_time"],
"entry_price": trade["entry_price"],
"exit_price": trade["exit_price"],
"position_type": trade.get("position_type", "long"),
"profit_pct": trade["profit_pct"],
"exit_reason": trade["exit_reason"],
}
self.trade_collection.insert_one(trade_doc)
def run_live_strategy(
self, symbol: str, trade_amount: float = 0.1, interval_seconds: int = 60
):
"""Run strategy in live mode"""
self.logger.info(f"Starting live trading for {symbol}")
# Get or initialize position
position_doc = self.position_collection.find_one({"symbol": symbol})
position = {
"in_position": position_doc["in_position"] if position_doc else False,
"entry_price": position_doc.get("entry_price", 0),
"entry_time": position_doc.get("entry_time"),
"should_exit": False,
"exit_reason": None,
"position_type": position_doc.get("position_type", "long"),
}
while True:
try:
# Get latest data
df = self.get_latest_data(symbol)
if df is None or len(df) < self.sma_long:
self.logger.warning("Insufficient data, waiting...")
time.sleep(interval_seconds)
continue
# Calculate indicators and signals
df = self.calculate_indicators(df)
df = self.generate_signals(df)
current_data = df.iloc[-1]
# Update position status
if position["in_position"]:
position = self.apply_risk_management(df.iloc[-10:], position)
# Handle position exit
if position["in_position"] and (
position["should_exit"] or current_data["signal"] == -1
):
if self.execute_trade(
symbol, "sell", trade_amount, current_data["close"]
):
# Log trade and update position as before
trade = {
"exit_time": current_data.name,
"exit_price": current_data["close"],
"exit_reason": position["exit_reason"] or "signal",
"profit_pct": (
current_data["close"] - position["entry_price"]
)
/ position["entry_price"]
* 100,
"position_type": position["position_type"],
"amount": trade_amount,
"usd_value": trade_amount * current_data["close"],
}
self.log_trade(symbol, {**position, **trade})
# Update position
position = {
"in_position": False,
"entry_price": 0,
"entry_time": None,
"should_exit": False,
"exit_reason": None,
}
self.update_position(symbol, position)
# Log balances
self.logger.info(
f"Exited position. New balances: "
f"USD: {self.get_balance('USDT')}, "
f"{symbol[:-4]}: {self.get_balance(symbol[:-4])}"
)
# Handle position entry
elif not position["in_position"] and current_data["signal"] == 1:
if self.execute_trade(
symbol, "buy", trade_amount, current_data["close"]
):
position = {
"in_position": True,
"entry_price": current_data["close"],
"entry_time": current_data.name,
"should_exit": False,
"exit_reason": None,
"position_type": "long",
"amount": trade_amount,
}
self.update_position(symbol, position)
# Log balances
self.logger.info(
f"Entered position. New balances: "
f"USD: {self.get_balance('USDT')}, "
f"{symbol[:-4]}: {self.get_balance(symbol[:-4])}"
)
except Exception as e:
self.logger.error(f"Error in live trading loop: {str(e)}")
time.sleep(interval_seconds)
def run_live_trading(mongodb_uri: str, symbols: List[str]):
"""Run live trading for multiple symbols"""
for symbol in symbols:
strategy = LiveCryptoTradingStrategy(
mongodb_uri=mongodb_uri,
db_name="crypto_trading",
sma_short=20,
sma_long=100,
rsi_base_period=21,
rsi_base_overbought=70,
rsi_base_oversold=30,
volume_threshold=1.0,
stop_loss=0.01,
take_profit=0.020,
)
# Run in separate thread for each symbol
import threading
thread = threading.Thread(
target=strategy.run_live_strategy, args=(symbol,), daemon=True
)
thread.start()
if __name__ == "__main__":
mongodb_uri = "mongodb://localhost:27017/"
symbols = ["BTCUSDT", "ETHUSDT"] # Add your symbols here
run_live_trading(mongodb_uri, symbols)