Source code for curvesim.templates.strategy

from abc import ABC, abstractmethod
from typing import Optional, Type

from curvesim.logging import get_logger

from .log import Log
from .trader import Trader

logger = get_logger(__name__)


class Strategy(ABC):
    """
    A Strategy defines the trading approach used during each step of a
    simulation. It executes the trades using an injected `Trader` class and
    then logs the changes using the injected `Log` class.

    Class Attributes
    ----------------
    trader_class : :class:`~curvesim.templates.Trader`
        Class for creating trader instances.
    log_class : :class:`~curvesim.templates.Log`
        Class for creating log instances.

    Attributes
    ----------
    metrics : List[Metric]
        A list of metrics used to evaluate the performance of the strategy.
    """

    # These classes should be injected in child classes
    # to create the desired behavior.  They default to None here, so a
    # subclass that forgets to set them is rejected when the strategy is
    # called (see `__call__`).
    trader_class: Optional[Type[Trader]] = None
    log_class: Optional[Type[Log]] = None

    def __init__(self, metrics):
        """
        Parameters
        ----------
        metrics : List[Metric]
            A list of metrics used to evaluate the performance of the strategy.
        """
        self.metrics = metrics

    def __call__(self, pool, parameters, price_sampler):
        """
        Computes and executes trades at each timestep.

        Parameters
        ----------
        pool : :class:`~curvesim.pipelines.templates.SimPool`
            The pool to be traded against.

        parameters : dict
            Current pool parameters from the param_sampler (only used for
            logging/display).  A falsy value is reported as
            "no parameter changes".

        price_sampler : iterable
            Iterable that for each timestep returns market data used by
            the trader.  It must also expose a `prices` attribute, which is
            handed to `pool.prepare_for_run`, and each yielded sample must
            have a `timestamp` attribute.

        Returns
        -------
        metrics : tuple of lists
            The value returned by the log's `compute_metrics()`.

        Raises
        ------
        TypeError
            If `trader_class` or `log_class` has not been injected by the
            subclass (i.e. is still None).
        """
        trader_class = self.trader_class
        log_class = self.log_class

        # Fail early, and with a readable message, instead of the opaque
        # "'NoneType' object is not callable" that calling None would give.
        # Checking both up front also avoids building a trader that would
        # be thrown away when only the log class is missing.
        if trader_class is None or log_class is None:
            raise TypeError(
                f"{type(self).__name__} must define both `trader_class` and "
                "`log_class` before it can be called."
            )

        trader = trader_class(pool)
        log = log_class(pool, self.metrics)

        parameters = parameters or "no parameter changes"
        logger.info("[%s] Simulating with %s", pool.symbol, parameters)

        pool.prepare_for_run(price_sampler.prices)

        for sample in price_sampler:
            # The pool is told the sample's timestamp before any trade is
            # computed for that sample.
            pool.prepare_for_trades(sample.timestamp)
            trader_args = self._get_trader_inputs(sample)
            trade_data = trader.process_time_sample(*trader_args)
            log.update(price_sample=sample, trade_data=trade_data)

        return log.compute_metrics()

    @abstractmethod
    def _get_trader_inputs(self, sample):
        """
        Process the price sample into appropriate inputs for the
        trader instance.

        The returned value is unpacked positionally into the trader's
        `process_time_sample`, so implementations must return an iterable
        of positional arguments.
        """
        raise NotImplementedError