From dae646cd2e37572b8fc25e122a1f67f48992b25a Mon Sep 17 00:00:00 2001 From: dongho Date: Tue, 24 Dec 2024 00:26:18 +0900 Subject: [PATCH] auto-trade skeleton added --- auto_trade.py | 447 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 447 insertions(+) create mode 100644 auto_trade.py diff --git a/auto_trade.py b/auto_trade.py new file mode 100644 index 0000000..7a1f8c5 --- /dev/null +++ b/auto_trade.py @@ -0,0 +1,447 @@ +from typing import Dict, List +import logging +import pandas as pd +import numpy as np +from pymongo import MongoClient +import time +from datetime import datetime, timedelta +import pytz + + +class CryptoTradingStrategy: + def __init__( + self, + sma_short: int = 20, + sma_long: int = 50, + rsi_base_period: int = 14, + rsi_base_overbought: float = 70, + rsi_base_oversold: float = 30, + volume_threshold: float = 1.5, + stop_loss: float = 0.02, + take_profit: float = 0.035, + ): + self.sma_short = sma_short + self.sma_long = sma_long + self.rsi_base_period = rsi_base_period + self.rsi_base_overbought = rsi_base_overbought + self.rsi_base_oversold = rsi_base_oversold + self.volume_threshold = volume_threshold + self.stop_loss = stop_loss + self.take_profit = take_profit + + logging.basicConfig(level=logging.INFO) + self.logger = logging.getLogger(__name__) + + def calculate_indicators(self, df: pd.DataFrame) -> pd.DataFrame: + """Calculate technical indicators.""" + # Calculate SMAs + df["SMA_short"] = df["close"].rolling(window=self.sma_short).mean() + df["SMA_long"] = df["close"].rolling(window=self.sma_long).mean() + + # Calculate ATR (Average True Range) for volatility + df["TR"] = (df["high"] - df["low"]).abs() # True Range = High - Low + df["ATR"] = df["TR"].rolling(window=14).mean() # 14-period ATR + + # Adaptive RSI period based on volatility + # Adaptive RSI period based on volatility + df["RSI_period"] = self.rsi_base_period * ( + 1 + (df["ATR"] / df["close"].rolling(window=20).mean()) + ) + + # Replace NaN or infinite values with a default RSI period + df["RSI_period"].fillna(self.rsi_base_period, inplace=True) + df["RSI_period"].replace([np.inf, -np.inf], self.rsi_base_period, inplace=True) + + # Convert to integer + df["RSI_period"] = df["RSI_period"].astype(int) + # Calculate adaptive RSI + df["RSI"] = 0.0 + for period in df["RSI_period"].unique(): + mask = df["RSI_period"] == period + delta = df.loc[mask, "close"].diff() + gain = (delta.where(delta > 0, 0)).rolling(window=period).mean() + loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean() + rs = gain / loss + df.loc[mask, "RSI"] = 100 - (100 / (1 + rs)) + + # Adaptive RSI thresholds based on Bollinger Band width + df["BB_width"] = (df["close"].rolling(window=20).std() * 2) / df[ + "close" + ].rolling(window=20).mean() + df["RSI_overbought"] = self.rsi_base_overbought + (df["BB_width"] * 20) + df["RSI_oversold"] = self.rsi_base_oversold - (df["BB_width"] * 20) + + # Volume analysis + df["Volume_MA"] = df["volume"].rolling(window=20).mean() + df["Volume_Ratio"] = df["volume"] / df["Volume_MA"] + + return df + + def generate_signals(self, df: pd.DataFrame) -> pd.DataFrame: + """Generate trading signals based on indicators.""" + df["signal"] = 0 # 0: hold, 1: buy, -1: sell + + # Generate buy signals + buy_conditions = ( + (df["SMA_short"] > df["SMA_long"]) # Golden cross + & (df["RSI"] < df["RSI_oversold"]) # Adaptive oversold condition + & (df["Volume_Ratio"] > self.volume_threshold) # High volume + ) + df.loc[buy_conditions, "signal"] = 1 + + # Generate sell signals + sell_conditions = (df["SMA_short"] < df["SMA_long"]) | ( # Death cross + df["RSI"] > df["RSI_overbought"] + ) # Adaptive overbought condition + df.loc[sell_conditions, "signal"] = -1 + + return df + + def apply_risk_management(self, df: pd.DataFrame, position: Dict) -> Dict: + """Apply risk management rules to current position.""" + current_price = df["close"].iloc[-1] + + if position["in_position"]: + # Check stop loss + if current_price <= position["entry_price"] * (1 - self.stop_loss): + position["should_exit"] = True + position["exit_reason"] = "stop_loss" + + # Check take profit + elif current_price >= position["entry_price"] * (1 + self.take_profit): + position["should_exit"] = True + position["exit_reason"] = "take_profit" + + return position + + def execute_trades(self, df: pd.DataFrame) -> List[Dict]: + """Execute trading strategy and maintain positions.""" + trades = [] + position = { + "in_position": False, + "entry_price": 0, + "entry_time": None, + "should_exit": False, + "exit_reason": None, + } + + for i in range(len(df)): + current_data = df.iloc[i] + + # Update position status + if position["in_position"]: + position = self.apply_risk_management( + df.iloc[max(0, i - 10) : i + 1], position + ) + + # Exit position if necessary + if position["in_position"] and ( + position["should_exit"] or current_data["signal"] == -1 + ): + trade = { + "exit_time": current_data.name, + "exit_price": current_data["close"], + "exit_reason": position["exit_reason"] or "signal", + "profit_pct": (current_data["close"] - position["entry_price"]) + / position["entry_price"] + * 100, + } + trades.append({**position, **trade}) + position = { + "in_position": False, + "entry_price": 0, + "entry_time": None, + "should_exit": False, + "exit_reason": None, + } + + # Enter new position + elif not position["in_position"] and current_data["signal"] == 1: + position = { + "in_position": True, + "entry_price": current_data["close"], + "entry_time": current_data.name, + "should_exit": False, + "exit_reason": None, + } + + return trades + + def run_strategy(self, df: pd.DataFrame) -> tuple: + """Run the complete trading strategy.""" + # Prepare data + df = self.calculate_indicators(df) + df = self.generate_signals(df) + + # Execute trades + trades = self.execute_trades(df) + + # Calculate performance metrics + total_trades = len(trades) + winning_trades = len([t for t in trades if t["profit_pct"] > 0]) + total_return = sum(t["profit_pct"] for t in trades) + + metrics = { + "total_trades": total_trades, + "winning_trades": winning_trades, + "win_rate": winning_trades / total_trades if total_trades > 0 else 0, + "total_return": total_return, + "average_return": total_return / total_trades if total_trades > 0 else 0, + } + + return trades, metrics + + +class LiveCryptoTradingStrategy(CryptoTradingStrategy): + def __init__( + self, + mongodb_uri: str, + db_name: str, + initial_balance: float = 10000, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.client = MongoClient(mongodb_uri) + self.db = self.client[db_name] + self.position_collection = self.db["positions"] + self.trade_collection = self.db["trades"] + self.data_collection = self.db["market_data"] + self.asset_collection = self.db["assets"] + + # Initialize assets if not exists + if not self.asset_collection.find_one({"asset_type": "USD"}): + self.asset_collection.insert_one( + { + "asset_type": "USD", + "amount": initial_balance, + "last_updated": datetime.now(pytz.UTC), + } + ) + + def get_balance(self, asset_type: str = "USD") -> float: + """Get current balance of specified asset""" + asset_doc = self.asset_collection.find_one({"asset_type": asset_type}) + return asset_doc["amount"] if asset_doc else 0 + + def update_balance(self, asset_type: str, amount: float): + """Update balance of specified asset""" + self.asset_collection.update_one( + {"asset_type": asset_type}, + {"$set": {"amount": amount, "last_updated": datetime.now(pytz.UTC)}}, + upsert=True, + ) + + def get_latest_data(self, symbol: str, lookback_minutes: int = 200) -> pd.DataFrame: + """Fetch latest data from MongoDB""" + end_time = datetime.now(pytz.UTC) + start_time = end_time - timedelta(minutes=lookback_minutes) + + cursor = self.data_collection.find( + {"symbol": symbol, "timestamp": {"$gte": start_time, "$lte": end_time}} + ).sort("timestamp", 1) + + data = list(cursor) + if not data: + return None + + df = pd.DataFrame(data) + df.set_index("timestamp", inplace=True) + return df + + def execute_trade(self, symbol: str, side: str, amount: float, price: float): + """Execute trade and update balances""" + base_asset = symbol[:-4] # BTCUSDT -> BTC + quote_asset = symbol[-4:] # BTCUSDT -> USDT + + if side == "buy": + # Check if enough quote asset (USDT) available + quote_balance = self.get_balance(quote_asset) + cost = amount * price + + if quote_balance >= cost: + # Update quote asset balance (decrease) + self.update_balance(quote_asset, quote_balance - cost) + # Update base asset balance (increase) + base_balance = self.get_balance(base_asset) + self.update_balance(base_asset, base_balance + amount) + return True + return False + + elif side == "sell": + # Check if enough base asset available + base_balance = self.get_balance(base_asset) + + if base_balance >= amount: + # Update base asset balance (decrease) + self.update_balance(base_asset, base_balance - amount) + # Update quote asset balance (increase) + quote_balance = self.get_balance(quote_asset) + self.update_balance(quote_asset, quote_balance + (amount * price)) + return True + return False + + def update_position(self, symbol: str, position: Dict): + """Update position in MongoDB""" + self.position_collection.update_one( + {"symbol": symbol}, + { + "$set": { + "in_position": position["in_position"], + "entry_price": position["entry_price"], + "entry_time": position["entry_time"], + "position_type": position.get("position_type", "long"), + "last_updated": datetime.now(pytz.UTC), + } + }, + upsert=True, + ) + + def log_trade(self, symbol: str, trade: Dict): + """Log completed trade to MongoDB""" + trade_doc = { + "symbol": symbol, + "entry_time": trade["entry_time"], + "exit_time": trade["exit_time"], + "entry_price": trade["entry_price"], + "exit_price": trade["exit_price"], + "position_type": trade.get("position_type", "long"), + "profit_pct": trade["profit_pct"], + "exit_reason": trade["exit_reason"], + } + self.trade_collection.insert_one(trade_doc) + + def run_live_strategy( + self, symbol: str, trade_amount: float = 0.1, interval_seconds: int = 60 + ): + """Run strategy in live mode""" + self.logger.info(f"Starting live trading for {symbol}") + + # Get or initialize position + position_doc = self.position_collection.find_one({"symbol": symbol}) + position = { + "in_position": position_doc["in_position"] if position_doc else False, + "entry_price": position_doc.get("entry_price", 0), + "entry_time": position_doc.get("entry_time"), + "should_exit": False, + "exit_reason": None, + "position_type": position_doc.get("position_type", "long"), + } + + while True: + try: + # Get latest data + df = self.get_latest_data(symbol) + if df is None or len(df) < self.sma_long: + self.logger.warning("Insufficient data, waiting...") + time.sleep(interval_seconds) + continue + + # Calculate indicators and signals + df = self.calculate_indicators(df) + df = self.generate_signals(df) + + current_data = df.iloc[-1] + + # Update position status + if position["in_position"]: + position = self.apply_risk_management(df.iloc[-10:], position) + + # Handle position exit + if position["in_position"] and ( + position["should_exit"] or current_data["signal"] == -1 + ): + if self.execute_trade( + symbol, "sell", trade_amount, current_data["close"] + ): + # Log trade and update position as before + trade = { + "exit_time": current_data.name, + "exit_price": current_data["close"], + "exit_reason": position["exit_reason"] or "signal", + "profit_pct": ( + current_data["close"] - position["entry_price"] + ) + / position["entry_price"] + * 100, + "position_type": position["position_type"], + "amount": trade_amount, + "usd_value": trade_amount * current_data["close"], + } + self.log_trade(symbol, {**position, **trade}) + + # Update position + position = { + "in_position": False, + "entry_price": 0, + "entry_time": None, + "should_exit": False, + "exit_reason": None, + } + self.update_position(symbol, position) + + # Log balances + self.logger.info( + f"Exited position. New balances: " + f"USD: {self.get_balance('USDT')}, " + f"{symbol[:-4]}: {self.get_balance(symbol[:-4])}" + ) + + # Handle position entry + elif not position["in_position"] and current_data["signal"] == 1: + if self.execute_trade( + symbol, "buy", trade_amount, current_data["close"] + ): + position = { + "in_position": True, + "entry_price": current_data["close"], + "entry_time": current_data.name, + "should_exit": False, + "exit_reason": None, + "position_type": "long", + "amount": trade_amount, + } + self.update_position(symbol, position) + + # Log balances + self.logger.info( + f"Entered position. New balances: " + f"USD: {self.get_balance('USDT')}, " + f"{symbol[:-4]}: {self.get_balance(symbol[:-4])}" + ) + + except Exception as e: + self.logger.error(f"Error in live trading loop: {str(e)}") + + time.sleep(interval_seconds) + + +def run_live_trading(mongodb_uri: str, symbols: List[str]): + """Run live trading for multiple symbols""" + for symbol in symbols: + strategy = LiveCryptoTradingStrategy( + mongodb_uri=mongodb_uri, + db_name="crypto_trading", + sma_short=20, + sma_long=100, + rsi_base_period=21, + rsi_base_overbought=70, + rsi_base_oversold=30, + volume_threshold=1.0, + stop_loss=0.01, + take_profit=0.020, + ) + + # Run in separate thread for each symbol + import threading + + thread = threading.Thread( + target=strategy.run_live_strategy, args=(symbol,), daemon=True + ) + thread.start() + + +if __name__ == "__main__": + mongodb_uri = "mongodb://localhost:27017/" + symbols = ["BTCUSDT", "ETHUSDT"] # Add your symbols here + run_live_trading(mongodb_uri, symbols)