Support refresh entities

This commit is contained in:
foolcage
2023-02-20 18:11:23 +08:00
parent 5dffad4fe2
commit 5351207817
38 changed files with 1536 additions and 514 deletions

View File

@@ -29,7 +29,7 @@ class BullAndUpFactor(MacdFactor):
level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
category_field: str = "entity_id",
time_field: str = "timestamp",
computing_window: int = None,
keep_window: int = None,
keep_all_timestamp: bool = False,
fill_method: str = "ffill",
effective_number: int = None,
@@ -63,7 +63,7 @@ class BullAndUpFactor(MacdFactor):
level,
category_field,
time_field,
computing_window,
keep_window,
keep_all_timestamp,
fill_method,
effective_number,

View File

@@ -5,13 +5,12 @@ from typing import Type
from examples.utils import add_to_eastmoney
from zvt import zvt_config
from zvt.api import get_top_volume_entities, get_top_performance_entities, TopType
from zvt.api import get_top_volume_entities, TopType
from zvt.api.kdata import get_latest_kdata_date, get_kdata_schema, default_adjust_type
from zvt.api.selector import get_entity_ids_by_filter
from zvt.api.stats import get_top_performance_entities_by_periods
from zvt.contract import IntervalLevel
from zvt.contract.api import get_entities, get_entity_schema
from zvt.contract.factor import Factor
from zvt.factors import TargetSelector, SelectMode
from zvt.contract.factor import Factor, TargetType
from zvt.informer import EmailInformer
from zvt.utils import next_date
@@ -109,9 +108,6 @@ def report_targets(
current_entity_pool = set(factor_kv.pop("entity_ids"))
# add the factor
my_selector = TargetSelector(
start_timestamp=start_timestamp, end_timestamp=target_date, select_mode=SelectMode.condition_or
)
entity_schema = get_entity_schema(entity_type=entity_type)
tech_factor = factor_cls(
entity_schema=entity_schema,
@@ -123,11 +119,8 @@ def report_targets(
adjust_type=adjust_type,
**factor_kv,
)
my_selector.add_factor(tech_factor)
my_selector.run()
long_stocks = my_selector.get_open_long_targets(timestamp=target_date)
long_stocks = tech_factor.get_targets(timestamp=target_date, target_type=TargetType.positive)
inform(
informer,
@@ -174,79 +167,25 @@ def report_top_entities(
while error_count <= 10:
try:
if periods is None:
periods = [7, 30, 365]
if not adjust_type:
adjust_type = default_adjust_type(entity_type=entity_type)
kdata_schema = get_kdata_schema(entity_type=entity_type, adjust_type=adjust_type)
entity_schema = get_entity_schema(entity_type=entity_type)
target_date = get_latest_kdata_date(
provider=data_provider, entity_type=entity_type, adjust_type=adjust_type
)
filter_entity_ids = get_entity_ids_by_filter(
provider=entity_provider,
ignore_st=ignore_st,
selected = get_top_performance_entities_by_periods(
entity_provider=entity_provider,
data_provider=data_provider,
periods=periods,
ignore_new_stock=ignore_new_stock,
entity_schema=entity_schema,
target_date=target_date,
ignore_st=ignore_st,
entity_ids=entity_ids,
entity_type=entity_type,
adjust_type=adjust_type,
top_count=top_count,
turnover_threshold=turnover_threshold,
turnover_rate_threshold=turnover_rate_threshold,
return_type=return_type,
)
if not filter_entity_ids:
msg = f"{entity_type} no entity_ids selected"
logger.error(msg)
informer.send_message(zvt_config["email_username"], "report_top_stats error", msg)
return
filter_turnover_df = kdata_schema.query_data(
filters=[
kdata_schema.turnover >= turnover_threshold,
kdata_schema.turnover_rate >= turnover_rate_threshold,
],
provider=data_provider,
start_timestamp=target_date,
index="entity_id",
columns=["entity_id", "code"],
)
if filter_entity_ids:
filter_entity_ids = set(filter_entity_ids) & set(filter_turnover_df.index.tolist())
else:
filter_entity_ids = filter_turnover_df.index.tolist()
if not filter_entity_ids:
msg = f"{entity_type} no entity_ids selected"
logger.error(msg)
informer.send_message(zvt_config["email_username"], "report_top_stats error", msg)
return
logger.info(f"{entity_type} filter_entity_ids size: {len(filter_entity_ids)}")
filters = [kdata_schema.entity_id.in_(filter_entity_ids)]
selected = []
for i, period in enumerate(periods):
interval = period
if target_date.weekday() + 1 < interval:
interval = interval + 2
start = next_date(target_date, -interval)
positive_df, negative_df = get_top_performance_entities(
entity_type=entity_type,
start_timestamp=start,
kdata_filters=filters,
pct=1,
show_name=True,
entity_provider=entity_provider,
data_provider=data_provider,
return_type=return_type,
)
if return_type == TopType.positive:
df = positive_df
else:
df = negative_df
selected = selected + df.index[:top_count].tolist()
selected = list(dict.fromkeys(selected))
inform(
informer,
entity_ids=selected,

View File

@@ -51,7 +51,7 @@ class DragonTigerFactor(TechnicalFactor):
level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
category_field: str = "entity_id",
time_field: str = "timestamp",
computing_window: int = None,
keep_window: int = None,
keep_all_timestamp: bool = False,
fill_method: str = "ffill",
effective_number: int = None,
@@ -80,7 +80,7 @@ class DragonTigerFactor(TechnicalFactor):
level,
category_field,
time_field,
computing_window,
keep_window,
keep_all_timestamp,
fill_method,
effective_number,

View File

@@ -6,7 +6,6 @@ import pandas as pd
from zvt.contract import IntervalLevel
from zvt.contract.factor import Factor, Transformer, Accumulator
from zvt.domain import Stock, DragonAndTiger
from zvt.factors import TargetSelector
from zvt.trader import StockTrader
@@ -27,7 +26,7 @@ class DragonTigerFactor(Factor):
level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
category_field: str = "entity_id",
time_field: str = "timestamp",
computing_window: int = None,
keep_window: int = None,
keep_all_timestamp: bool = False,
fill_method: str = "ffill",
effective_number: int = None,
@@ -56,7 +55,7 @@ class DragonTigerFactor(Factor):
level,
category_field,
time_field,
computing_window,
keep_window,
keep_all_timestamp,
fill_method,
effective_number,
@@ -75,20 +74,10 @@ class DragonTigerFactor(Factor):
class MyTrader(StockTrader):
def init_selectors(
def init_factors(
self, entity_ids, entity_schema, exchanges, codes, start_timestamp, end_timestamp, adjust_type=None
):
myselector = TargetSelector(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
provider="em",
)
myselector.add_factor(
return [
DragonTigerFactor(
entity_ids=entity_ids,
exchanges=exchanges,
@@ -96,9 +85,7 @@ class MyTrader(StockTrader):
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
)
)
self.selectors.append(myselector)
]
if __name__ == "__main__":

View File

@@ -5,7 +5,7 @@ from zvt.api import get_top_volume_entities
from zvt.api.stats import get_top_fund_holding_stocks
from zvt.api.trader_info_api import clear_trader
from zvt.contract import IntervalLevel
from zvt.factors import TargetSelector, GoldCrossFactor, BullFactor
from zvt.factors import GoldCrossFactor, BullFactor
from zvt.trader import StockTrader
from zvt.utils.time_utils import split_time_interval, next_date
@@ -13,62 +13,33 @@ logger = logging.getLogger(__name__)
class MultipleLevelTrader(StockTrader):
def init_selectors(
def init_factors(
self, entity_ids, entity_schema, exchanges, codes, start_timestamp, end_timestamp, adjust_type=None
):
start_timestamp = next_date(start_timestamp, -50)
# 周线策略
week_selector = TargetSelector(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=next_date(start_timestamp, -200),
end_timestamp=end_timestamp,
long_threshold=0.7,
level=IntervalLevel.LEVEL_1WEEK,
provider="joinquant",
)
week_bull_factor = BullFactor(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=next_date(start_timestamp, -200),
end_timestamp=end_timestamp,
provider="joinquant",
level=IntervalLevel.LEVEL_1WEEK,
)
week_selector.add_factor(week_bull_factor)
# 日线策略
day_selector = TargetSelector(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
long_threshold=0.7,
level=IntervalLevel.LEVEL_1DAY,
provider="joinquant",
)
day_gold_cross_factor = GoldCrossFactor(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
provider="joinquant",
level=IntervalLevel.LEVEL_1DAY,
)
day_selector.add_factor(day_gold_cross_factor)
# 同时使用日线,周线级别
self.selectors.append(day_selector)
self.selectors.append(week_selector)
return [
BullFactor(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=next_date(start_timestamp, -200),
end_timestamp=end_timestamp,
provider="joinquant",
level=IntervalLevel.LEVEL_1WEEK,
),
GoldCrossFactor(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
provider="joinquant",
level=IntervalLevel.LEVEL_1DAY,
),
]
if __name__ == "__main__":

View File

@@ -1,27 +1,16 @@
# -*- coding: utf-8 -*-
from zvt.contract import IntervalLevel
from zvt.factors import CrossMaFactor
from zvt.factors.target_selector import TargetSelector
from zvt.factors.macd import BullFactor
from zvt.trader.trader import StockTrader
class MyMaTrader(StockTrader):
def init_selectors(
def init_factors(
self, entity_ids, entity_schema, exchanges, codes, start_timestamp, end_timestamp, adjust_type=None
):
myselector = TargetSelector(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
provider="joinquant",
)
myselector.add_factor(
return [
CrossMaFactor(
entity_ids=entity_ids,
entity_schema=entity_schema,
@@ -32,26 +21,14 @@ class MyMaTrader(StockTrader):
windows=[5, 10],
need_persist=False,
)
)
self.selectors.append(myselector)
]
class MyBullTrader(StockTrader):
def init_selectors(
def init_factors(
self, entity_ids, entity_schema, exchanges, codes, start_timestamp, end_timestamp, adjust_type=None
):
myselector = TargetSelector(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
provider="joinquant",
)
myselector.add_factor(
return [
BullFactor(
entity_ids=entity_ids,
entity_schema=entity_schema,
@@ -61,9 +38,7 @@ class MyBullTrader(StockTrader):
end_timestamp=end_timestamp,
adjust_type="hfq",
)
)
self.selectors.append(myselector)
]
if __name__ == "__main__":

View File

@@ -4,7 +4,8 @@ from typing import List, Tuple
import pandas as pd
from zvt.contract import IntervalLevel
from zvt.factors import TargetSelector, GoldCrossFactor
from zvt.contract.factor import Factor
from zvt.factors import GoldCrossFactor
from zvt.trader import TradingSignal
from zvt.trader.trader import StockTrader
@@ -15,35 +16,23 @@ from zvt.utils import next_date
class MacdDayTrader(StockTrader):
def init_selectors(
def init_factors(
self, entity_ids, entity_schema, exchanges, codes, start_timestamp, end_timestamp, adjust_type=None
):
# 日线策略
start_timestamp = next_date(start_timestamp, -50)
day_selector = TargetSelector(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
long_threshold=0.7,
level=IntervalLevel.LEVEL_1DAY,
provider="joinquant",
)
day_gold_cross_factor = GoldCrossFactor(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
provider="joinquant",
level=IntervalLevel.LEVEL_1DAY,
)
day_selector.add_factor(day_gold_cross_factor)
self.selectors.append(day_selector)
return [
GoldCrossFactor(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
provider="joinquant",
level=IntervalLevel.LEVEL_1DAY,
)
]
def on_profit_control(self):
# 覆盖该函数做止盈 止损
@@ -86,10 +75,10 @@ class MacdDayTrader(StockTrader):
return super().short_position_control()
def on_targets_filtered(
self, timestamp, level, selector: TargetSelector, long_targets: List[str], short_targets: List[str]
self, timestamp, level, factor: Factor, long_targets: List[str], short_targets: List[str]
) -> Tuple[List[str], List[str]]:
# 过滤某级别选出的 标的
return super().on_targets_filtered(timestamp, level, selector, long_targets, short_targets)
return super().on_targets_filtered(timestamp, level, factor, long_targets, short_targets)
if __name__ == "__main__":

View File

@@ -2,7 +2,7 @@
from typing import List, Tuple
from zvt.contract import IntervalLevel
from zvt.factors import TargetSelector, GoldCrossFactor
from zvt.factors import GoldCrossFactor
from zvt.trader.trader import StockTrader
@@ -10,60 +10,32 @@ from zvt.trader.trader import StockTrader
# dataschema: Stock1dHfqKdata Stock1wkHfqKdata
# provider: joinquant
class MultipleLevelTrader(StockTrader):
def init_selectors(
def init_factors(
self, entity_ids, entity_schema, exchanges, codes, start_timestamp, end_timestamp, adjust_type=None
):
# 线策略
week_selector = TargetSelector(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
long_threshold=0.7,
level=IntervalLevel.LEVEL_1WEEK,
provider="joinquant",
)
week_gold_cross_factor = GoldCrossFactor(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
provider="joinquant",
level=IntervalLevel.LEVEL_1WEEK,
)
week_selector.add_factor(week_gold_cross_factor)
# 日线策略
day_selector = TargetSelector(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
long_threshold=0.7,
level=IntervalLevel.LEVEL_1DAY,
provider="joinquant",
)
day_gold_cross_factor = GoldCrossFactor(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
provider="joinquant",
level=IntervalLevel.LEVEL_1DAY,
)
day_selector.add_factor(day_gold_cross_factor)
# 同时使用日线,周线级别
self.selectors.append(day_selector)
self.selectors.append(week_selector)
# 同时使用周线和日线策略
return [
GoldCrossFactor(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
provider="joinquant",
level=IntervalLevel.LEVEL_1WEEK,
),
GoldCrossFactor(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
provider="joinquant",
level=IntervalLevel.LEVEL_1DAY,
),
]
def on_targets_selected_from_levels(self, timestamp) -> Tuple[List[str], List[str]]:
# 过滤多级别做 多/空 的标的

View File

@@ -2,10 +2,13 @@ requests==2.28.2
SQLAlchemy == 1.4.46
pandas==1.5.3
arrow==1.2.3
openpyxl==3.1.0
openpyxl==3.1.1
demjson3==3.0.6
marshmallow-sqlalchemy==0.28.1
marshmallow==3.19.0
plotly==5.13.0
dash==2.8.1
jqdatapy==0.1.8
dash-bootstrap-components==1.3.1
dash_daq==0.5.0
scikit-learn==1.2.1

6
sql/reduce_size.sql Normal file
View File

@@ -0,0 +1,6 @@
-- k线数据去除索引方便传输原始数据重跑zvt会重建
-- 再压缩一下大小为原来的1/10
drop index stock_1d_hfq_kdata_entity_id_index;
drop index stock_1d_hfq_kdata_code_index;
drop index stock_1d_hfq_kdata_timestamp_index;
VACUUM;

View File

@@ -1,7 +1,9 @@
# -*- coding: utf-8 -*-
import importlib
import json
import logging
import os
import pkgutil
import pprint
from logging.handlers import RotatingFileHandler
@@ -153,6 +155,16 @@ def init_config(pkg_name: str = None, current_config: dict = None, **kwargs) ->
return current_config
def init_plugins():
for finder, name, ispkg in pkgutil.iter_modules():
if name.startswith("zvt_"):
try:
_plugins[name] = importlib.import_module(name)
except Exception as e:
logger.warning(f"failed to load plugin {name}", e)
logger.info(f"loaded plugins:{_plugins}")
if os.getenv("TESTING_ZVT"):
init_env(zvt_home=ZVT_TEST_HOME)

View File

@@ -6,15 +6,15 @@ from typing import Union
import pandas as pd
from zvt.api.kdata import get_kdata_schema, default_adjust_type
from zvt.api.kdata import get_kdata_schema, default_adjust_type, get_latest_kdata_date
from zvt.api.selector import get_entity_ids_by_filter
from zvt.api.utils import get_recent_report_date
from zvt.contract import Mixin, AdjustType
from zvt.contract.api import decode_entity_id, get_entity_schema, get_entity_ids
from zvt.contract.drawer import Drawer
from zvt.domain import FundStock, StockValuation, BlockStock, Block
from zvt.factors import TechnicalFactor
from zvt.utils import now_pd_timestamp, next_date, pd_is_not_null
from zvt.utils.time_utils import month_start_end_ranges, to_time_str, pre_month_end_date
from zvt.utils.time_utils import month_start_end_ranges, to_time_str
logger = logging.getLogger(__name__)
@@ -53,6 +53,91 @@ def get_top_performance_by_month(
yield start_timestamp, end_timestamp, top
def get_top_performance_entities_by_periods(
entity_provider,
data_provider,
target_date=None,
periods=None,
ignore_new_stock=True,
ignore_st=True,
entity_ids=None,
entity_type="stock",
adjust_type=None,
top_count=50,
turnover_threshold=100000000,
turnover_rate_threshold=0.02,
return_type=TopType.positive,
):
if periods is None:
periods = [*range(1, 21)]
if not adjust_type:
adjust_type = default_adjust_type(entity_type=entity_type)
kdata_schema = get_kdata_schema(entity_type=entity_type, adjust_type=adjust_type)
entity_schema = get_entity_schema(entity_type=entity_type)
if not target_date:
target_date = get_latest_kdata_date(provider=data_provider, entity_type=entity_type, adjust_type=adjust_type)
filter_entity_ids = get_entity_ids_by_filter(
provider=entity_provider,
ignore_st=ignore_st,
ignore_new_stock=ignore_new_stock,
entity_schema=entity_schema,
target_date=target_date,
entity_ids=entity_ids,
)
if not filter_entity_ids:
return []
filter_turnover_df = kdata_schema.query_data(
filters=[
kdata_schema.turnover >= turnover_threshold,
kdata_schema.turnover_rate >= turnover_rate_threshold,
],
provider=data_provider,
start_timestamp=target_date,
index="entity_id",
columns=["entity_id", "code"],
)
if filter_entity_ids:
filter_entity_ids = set(filter_entity_ids) & set(filter_turnover_df.index.tolist())
else:
filter_entity_ids = filter_turnover_df.index.tolist()
if not filter_entity_ids:
return []
logger.info(f"{entity_type} filter_entity_ids size: {len(filter_entity_ids)}")
filters = [kdata_schema.entity_id.in_(filter_entity_ids)]
selected = []
for i, period in enumerate(periods):
interval = period
if target_date.weekday() + 1 < interval:
interval = interval + 2
start = next_date(target_date, -interval)
positive_df, negative_df = get_top_performance_entities(
entity_type=entity_type,
start_timestamp=start,
end_timestamp=target_date,
kdata_filters=filters,
pct=1,
show_name=True,
entity_provider=entity_provider,
data_provider=data_provider,
return_type=return_type,
)
if return_type == TopType.positive:
df = positive_df
else:
df = negative_df
if pd_is_not_null(df):
selected = selected + df.index[:top_count].tolist()
selected = list(dict.fromkeys(selected))
return selected
def get_top_performance_entities(
entity_type="stock",
start_timestamp=None,

View File

@@ -104,7 +104,7 @@ class AccountStatsReader(DataReader):
level,
category_field="trader_name",
time_field="timestamp",
computing_window=None,
keep_window=None,
)
def draw_line(self, show=True):
@@ -155,7 +155,7 @@ class OrderReader(DataReader):
level,
category_field="trader_name",
time_field="timestamp",
computing_window=None,
keep_window=None,
)

View File

@@ -2,6 +2,7 @@
import json
import logging
import time
from enum import Enum
from typing import List, Union, Optional, Type
import pandas as pd
@@ -13,10 +14,17 @@ from zvt.contract.base_service import EntityStateService
from zvt.contract.reader import DataReader, DataListener
from zvt.contract.schema import Mixin, TradableEntity
from zvt.contract.zvt_info import FactorState
from zvt.utils import to_pd_timestamp
from zvt.utils.pd_utils import pd_is_not_null, drop_continue_duplicate, is_filter_result_df, is_score_result_df
from zvt.utils.str_utils import to_snake_str
class TargetType(Enum):
positive = "positive"
negative = "negative"
keep = "keep"
class Indicator(object):
def __init__(self) -> None:
self.logger = logging.getLogger(self.__class__.__name__)
@@ -206,7 +214,7 @@ class Factor(DataReader, EntityStateService, DataListener):
level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
category_field: str = "entity_id",
time_field: str = "timestamp",
computing_window: int = None,
keep_window: int = None,
keep_all_timestamp: bool = False,
fill_method: str = "ffill",
effective_number: int = None,
@@ -257,7 +265,7 @@ class Factor(DataReader, EntityStateService, DataListener):
level,
category_field,
time_field,
computing_window,
keep_window,
)
EntityStateService.__init__(self, entity_ids=entity_ids)
@@ -279,7 +287,7 @@ class Factor(DataReader, EntityStateService, DataListener):
self.accumulator = self.__class__.accumulator
self.need_persist = need_persist
self.dry_run = only_compute_factor
self.only_compute_factor = only_compute_factor
#: 中间结果,不持久化
#: data_df->pipe_df
@@ -334,7 +342,7 @@ class Factor(DataReader, EntityStateService, DataListener):
super().load_data()
def load_factor(self):
if self.dry_run:
if self.only_compute_factor:
#: 如果只是为了计算因子只需要读取acc_window的factor_df
if self.accumulator is not None:
self.factor_df = self.load_window_df(
@@ -350,11 +358,14 @@ class Factor(DataReader, EntityStateService, DataListener):
index=[self.category_field, self.time_field],
)
self.decode_factor_df(self.factor_df)
def decode_factor_df(self, df):
col_map_object_hook = self.factor_col_map_object_hook()
if pd_is_not_null(self.factor_df) and col_map_object_hook:
if pd_is_not_null(df) and col_map_object_hook:
for col in col_map_object_hook:
if col in self.factor_df.columns:
self.factor_df[col] = self.factor_df[col].apply(
if col in df.columns:
df[col] = df[col].apply(
lambda x: json.loads(x, object_hook=col_map_object_hook.get(col)) if x else None
)
@@ -495,6 +506,48 @@ class Factor(DataReader, EntityStateService, DataListener):
self.result_df = self.result_df.reindex(new_index)
self.result_df = self.result_df.groupby(level=0).fillna(method=self.fill_method, limit=self.effective_number)
def update_entities(self, entity_ids):
if (self.entity_ids and entity_ids) and (set(self.entity_ids) == set(entity_ids)):
self.logger.info(f"current: {self.entity_ids}")
self.logger.info(f"refresh: {entity_ids}")
return
new_entity_ids = None
if entity_ids:
new_entity_ids = list(set(entity_ids) - set(self.entity_ids))
self.entity_ids = list(set(self.entity_ids + entity_ids))
if new_entity_ids:
self.logger.info(f"added new entity: {new_entity_ids}")
if not self.only_load_factor:
new_data_df = self.data_schema.query_data(
entity_ids=new_entity_ids,
provider=self.provider,
columns=self.columns,
start_timestamp=self.start_timestamp,
end_timestamp=self.end_timestamp,
filters=self.filters,
order=self.order,
limit=self.limit,
level=self.level,
index=[self.category_field, self.time_field],
time_field=self.time_field,
)
self.data_df = pd.concat([self.data_df, new_data_df], sort=False)
self.data_df.sort_index(level=[0, 1], inplace=True)
new_factor_df = get_data(
provider="zvt",
data_schema=self.factor_schema,
start_timestamp=self.start_timestamp,
entity_ids=new_entity_ids,
end_timestamp=self.end_timestamp,
index=[self.category_field, self.time_field],
)
self.decode_factor_df(new_factor_df)
self.factor_df = pd.concat([self.factor_df, new_factor_df], sort=False)
self.factor_df.sort_index(level=[0, 1], inplace=True)
def on_data_loaded(self, data: pd.DataFrame):
self.compute()
@@ -542,6 +595,66 @@ class Factor(DataReader, EntityStateService, DataListener):
else:
df_to_db(df=df, data_schema=self.factor_schema, provider="zvt", force_update=False)
def get_filter_df(self):
if is_filter_result_df(self.result_df):
return self.result_df[["filter_result"]]
def get_score_df(self):
if is_score_result_df(self.result_df):
return self.result_df[["score_result"]]
def get_targets(
self,
timestamp=None,
start_timestamp=None,
end_timestamp=None,
target_type: TargetType = TargetType.positive,
positive_threshold=0.8,
negative_threshold=-0.8,
):
if timestamp and (start_timestamp or end_timestamp):
raise ValueError("Use timestamp or (start_timestamp, end_timestamp)")
# select by filter
filter_df = self.get_filter_df()
selected_df = None
target_df = None
if pd_is_not_null(filter_df):
if target_type == TargetType.positive:
selected_df = filter_df[filter_df["filter_result"] == True]
elif target_type == TargetType.negative:
selected_df = filter_df[filter_df["filter_result"] == False]
else:
selected_df = filter_df[filter_df["filter_result"].isna()]
# select by score
score_df = self.get_score_df()
if pd_is_not_null(score_df):
if pd_is_not_null(selected_df):
# filter at first
score_df = score_df.loc[selected_df.index, :]
if target_type == TargetType.positive:
selected_df = score_df[score_df["score_result"] >= positive_threshold]
elif target_type == TargetType.negative:
selected_df = score_df[score_df["score_result"] <= negative_threshold]
else:
selected_df = score_df[
(score_df["score_result"] > negative_threshold) & (score_df["score"] < positive_threshold)
]
print(selected_df)
if pd_is_not_null(selected_df):
selected_df = selected_df.reset_index(level="entity_id")
if timestamp:
if to_pd_timestamp(timestamp) in selected_df.index:
target_df = selected_df.loc[[to_pd_timestamp(timestamp)], ["entity_id"]]
else:
target_df = selected_df.loc[
slice(to_pd_timestamp(start_timestamp), to_pd_timestamp(end_timestamp)), ["entity_id"]
]
if pd_is_not_null(target_df):
return target_df["entity_id"].tolist()
return []
class ScoreFactor(Factor):
scorer: Scorer = None

View File

@@ -59,7 +59,7 @@ class DataReader(Drawable):
level: IntervalLevel = None,
category_field: str = "entity_id",
time_field: str = "timestamp",
computing_window: int = None,
keep_window: int = None,
) -> None:
self.logger = logging.getLogger(self.__class__.__name__)
@@ -94,7 +94,7 @@ class DataReader(Drawable):
self.category_field = category_field
self.time_field = time_field
self.computing_window = computing_window
self.computing_window = keep_window
self.category_col = eval("self.data_schema.{}".format(self.category_field))
self.time_col = eval("self.data_schema.{}".format(self.time_field))
@@ -182,13 +182,13 @@ class DataReader(Drawable):
:return:
:rtype:
"""
if not pd_is_not_null(self.data_df):
self.load_data()
return
start_time = time.time()
#: FIXME:we suppose history data should be there at first
has_got = []
dfs = []
changed = False
@@ -275,9 +275,9 @@ if __name__ == "__main__":
from zvt.domain import Stock1dKdata, Stock
data_reader = DataReader(
codes=["002572", "000338"],
data_schema=Stock1dKdata,
entity_schema=Stock,
codes=["002572", "000338"],
start_timestamp="2017-01-01",
end_timestamp="2019-06-10",
)

View File

@@ -29,7 +29,7 @@ class FinanceBaseFactor(Factor):
level: Union[str, IntervalLevel] = None,
category_field: str = "entity_id",
time_field: str = "timestamp",
computing_window: int = None,
keep_window: int = None,
keep_all_timestamp: bool = False,
fill_method: str = "ffill",
effective_number: int = None,
@@ -60,7 +60,7 @@ class FinanceBaseFactor(Factor):
level,
category_field,
time_field,
computing_window,
keep_window,
keep_all_timestamp,
fill_method,
effective_number,
@@ -110,7 +110,7 @@ class GoodCompanyFactor(FinanceBaseFactor):
level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
category_field: str = "entity_id",
time_field: str = "timestamp",
computing_window: int = None,
keep_window: int = None,
keep_all_timestamp: bool = True,
fill_method: str = "ffill",
effective_number: int = None,
@@ -153,7 +153,7 @@ class GoodCompanyFactor(FinanceBaseFactor):
level,
category_field,
time_field,
computing_window,
keep_window,
keep_all_timestamp,
fill_method,
effective_number,

View File

@@ -41,7 +41,7 @@ class MaFactor(TechnicalFactor):
level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
category_field: str = "entity_id",
time_field: str = "timestamp",
computing_window: int = None,
keep_window: int = None,
keep_all_timestamp: bool = False,
fill_method: str = "ffill",
effective_number: int = None,
@@ -77,7 +77,7 @@ class MaFactor(TechnicalFactor):
level,
category_field,
time_field,
computing_window,
keep_window,
keep_all_timestamp,
fill_method,
effective_number,
@@ -124,7 +124,7 @@ class VolumeUpMaFactor(TechnicalFactor):
level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
category_field: str = "entity_id",
time_field: str = "timestamp",
computing_window: int = None,
keep_window: int = None,
keep_all_timestamp: bool = False,
fill_method: str = "ffill",
effective_number: int = None,
@@ -172,7 +172,7 @@ class VolumeUpMaFactor(TechnicalFactor):
level,
category_field,
time_field,
computing_window,
keep_window,
keep_all_timestamp,
fill_method,
effective_number,
@@ -248,34 +248,16 @@ class CrossMaVolumeFactor(VolumeUpMaFactor):
if __name__ == "__main__":
print("start")
parser = argparse.ArgumentParser()
parser.add_argument("--level", help="trading level", default="1d", choices=[item.value for item in IntervalLevel])
parser.add_argument("--start", help="start code", default="000001")
parser.add_argument("--end", help="end code", default="000005")
args = parser.parse_args()
level = IntervalLevel(args.level)
start = args.start
end = args.end
entities = get_entities(
provider="eastmoney",
entity_type="stock",
columns=[Stock.entity_id, Stock.code],
filters=[Stock.code >= start, Stock.code < end],
)
codes = entities.index.to_list()
factor = VolumeUpMaFactor(
entity_provider="em",
provider="em",
entity_ids=["stock_sz_000338"],
start_timestamp="2020-01-01",
end_timestamp=now_pd_timestamp(),
level=level,
need_persist=False,
)
print(factor.result_df)
selected = factor.get_targets(timestamp="2021-12-30")
print(selected)
# the __all__ is generated
__all__ = ["get_ma_factor_schema", "MaFactor", "CrossMaFactor", "VolumeUpMaFactor", "CrossMaVolumeFactor"]

View File

@@ -85,7 +85,7 @@ class MaStatsFactor(TechnicalFactor):
level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
category_field: str = "entity_id",
time_field: str = "timestamp",
computing_window: int = None,
keep_window: int = None,
keep_all_timestamp: bool = False,
fill_method: str = "ffill",
effective_number: int = None,
@@ -129,7 +129,7 @@ class MaStatsFactor(TechnicalFactor):
level,
category_field,
time_field,
computing_window,
keep_window,
keep_all_timestamp,
fill_method,
effective_number,
@@ -161,7 +161,7 @@ class TFactor(MaStatsFactor):
level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
category_field: str = "entity_id",
time_field: str = "timestamp",
computing_window: int = None,
keep_window: int = None,
keep_all_timestamp: bool = False,
fill_method: str = "ffill",
effective_number: int = None,
@@ -189,7 +189,7 @@ class TFactor(MaStatsFactor):
level,
category_field,
time_field,
computing_window,
keep_window,
keep_all_timestamp,
fill_method,
effective_number,

View File

@@ -49,7 +49,7 @@ class TopBottomFactor(TechnicalFactor):
level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
category_field: str = "entity_id",
time_field: str = "timestamp",
computing_window: int = None,
keep_window: int = None,
keep_all_timestamp: bool = False,
fill_method: str = "ffill",
effective_number: int = None,
@@ -81,7 +81,7 @@ class TopBottomFactor(TechnicalFactor):
level,
category_field,
time_field,
computing_window,
keep_window,
keep_all_timestamp,
fill_method,
effective_number,
@@ -106,7 +106,7 @@ if __name__ == "__main__":
)
print(factor.factor_df)
data_reader1 = DataReader(codes=["601318"], data_schema=Stock1dKdata, entity_schema=Stock)
data_reader1 = DataReader(data_schema=Stock1dKdata, entity_schema=Stock, codes=["601318"])
drawer = Drawer(main_df=data_reader1.data_df, factor_df_list=[factor.factor_df[["top", "bottom"]]])
drawer.draw_kline(show=True)

View File

@@ -60,5 +60,11 @@ class GoldCrossFactor(MacdFactor):
self.result_df = s.to_frame(name="filter_result")
if __name__ == "__main__":
f = BullFactor(provider="em", entity_provider="em", entity_ids=["stock_sz_000338"])
print(f.data_df)
f.update_entity_ids(["stock_sz_000338", "stock_sh_600000"])
f.move_on()
print(f.data_df)
# the __all__ is generated
__all__ = ["MacdFactor", "BullFactor", "KeepBullFactor", "LiveOrDeadFactor", "GoldCrossFactor"]

View File

@@ -7,14 +7,13 @@ import pandas as pd
from pandas import DataFrame
from zvt.contract import IntervalLevel
from zvt.contract.drawer import Drawer
from zvt.contract.factor import Factor
from zvt.domain.meta.stock_meta import Stock
from zvt.utils.pd_utils import index_df, pd_is_not_null, is_filter_result_df, is_score_result_df
from zvt.utils.time_utils import to_pd_timestamp, now_pd_timestamp
class TargetType(Enum):
class TradeType(Enum):
# open_long 代表开多,并应该平掉相应标的的空单
open_long = "open_long"
# open_short 代表开空,并应该平掉相应标的的多单
@@ -129,12 +128,12 @@ class TargetSelector(object):
self.generate_targets()
def get_targets(self, timestamp, target_type: TargetType = TargetType.open_long) -> List[str]:
if target_type == TargetType.open_long:
def get_targets(self, timestamp, trade_type: TradeType = TradeType.open_long) -> List[str]:
if trade_type == TradeType.open_long:
df = self.open_long_df
elif target_type == TargetType.open_short:
elif trade_type == TradeType.open_short:
df = self.open_short_df
elif target_type == TargetType.keep:
elif trade_type == TradeType.keep:
df = self.keep_df
else:
assert False
@@ -146,13 +145,13 @@ class TargetSelector(object):
return []
def get_targets_between(
self, start_timestamp, end_timestamp, target_type: TargetType = TargetType.open_long
self, start_timestamp, end_timestamp, trade_type: TradeType = TradeType.open_long
) -> List[str]:
if target_type == TargetType.open_long:
if trade_type == TradeType.open_long:
df = self.open_long_df
elif target_type == TargetType.open_short:
elif trade_type == TradeType.open_short:
df = self.open_short_df
elif target_type == TargetType.keep:
elif trade_type == TradeType.keep:
df = self.keep_df
else:
assert False
@@ -163,10 +162,10 @@ class TargetSelector(object):
return []
def get_open_long_targets(self, timestamp):
return self.get_targets(timestamp=timestamp, target_type=TargetType.open_long)
return self.get_targets(timestamp=timestamp, trade_type=TradeType.open_long)
def get_open_short_targets(self, timestamp):
return self.get_targets(timestamp=timestamp, target_type=TargetType.open_short)
return self.get_targets(timestamp=timestamp, trade_type=TradeType.open_short)
# overwrite it to generate targets
def generate_targets(self):
@@ -214,31 +213,6 @@ class TargetSelector(object):
df = df.sort_values(by=["score", "entity_id"])
return df
def draw(
self,
render="html",
file_name=None,
width=None,
height=None,
title=None,
keep_ui_state=True,
annotation_df=None,
target_type: TargetType = TargetType.open_long,
):
if target_type == TargetType.open_long:
df = self.open_long_df.copy()
elif target_type == TargetType.open_short:
df = self.open_short_df.copy()
df["target_type"] = target_type.value
if pd_is_not_null(df):
df = df.reset_index(drop=False)
drawer = Drawer(df)
drawer.draw_table(width=width, height=height, title=title, keep_ui_state=keep_ui_state)
# the __all__ is generated
__all__ = ["TargetType", "SelectMode", "TargetSelector"]
__all__ = ["TradeType", "SelectMode", "TargetSelector"]

View File

@@ -26,7 +26,7 @@ class TechnicalFactor(Factor, metaclass=FactorMeta):
level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
category_field: str = "entity_id",
time_field: str = "timestamp",
computing_window: int = None,
keep_window: int = None,
keep_all_timestamp: bool = False,
fill_method: str = "ffill",
effective_number: int = None,
@@ -84,7 +84,7 @@ class TechnicalFactor(Factor, metaclass=FactorMeta):
level,
category_field,
time_field,
computing_window,
keep_window,
keep_all_timestamp,
fill_method,
effective_number,

View File

@@ -479,7 +479,7 @@ class ZFactor(TechnicalFactor):
level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
category_field: str = "entity_id",
time_field: str = "timestamp",
computing_window: int = None,
keep_window: int = None,
keep_all_timestamp: bool = False,
fill_method: str = "ffill",
effective_number: int = None,
@@ -509,7 +509,7 @@ class ZFactor(TechnicalFactor):
level,
category_field,
time_field,
computing_window,
keep_window,
keep_all_timestamp,
fill_method,
effective_number,

45
src/zvt/main.py Normal file
View File

@@ -0,0 +1,45 @@
import dash_bootstrap_components as dbc
from dash import html
from dash.dependencies import Input, Output
from zvt import init_plugins
from zvt.ui import zvt_app
from zvt.ui.apps import factor_app
def serve_layout():
layout = html.Div(
children=[
# banner
html.Div(className="zvt-banner", children=html.H2(className="h2-title", children="ZVT")),
dbc.CardHeader(
dbc.Tabs(
[dbc.Tab(label="factor", tab_id="tab-factor", label_style={}, tab_style={"width": "100px"})],
id="card-tabs",
active_tab="tab-factor",
)
),
dbc.CardBody(html.P(id="card-content", className="card-text")),
]
)
return layout
@zvt_app.callback(Output("card-content", "children"), [Input("card-tabs", "active_tab")])
def tab_content(active_tab):
if "tab-factor" == active_tab:
return factor_app.factor_layout()
zvt_app.layout = serve_layout
def main():
init_plugins()
zvt_app.run_server(debug=True)
# zvt_app.run_server()
if __name__ == "__main__":
main()

View File

@@ -1,26 +1,15 @@
# -*- coding: utf-8 -*-
from zvt.contract import IntervalLevel
from zvt.factors.target_selector import TargetSelector
from zvt.factors.ma.ma_factor import CrossMaFactor
from zvt.factors import BullFactor
from zvt.factors.ma.ma_factor import CrossMaFactor
from zvt.trader.trader import StockTrader
class MyMaTrader(StockTrader):
def init_selectors(
def init_factors(
self, entity_ids, entity_schema, exchanges, codes, start_timestamp, end_timestamp, adjust_type=None
):
myselector = TargetSelector(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
provider="joinquant",
)
myselector.add_factor(
return [
CrossMaFactor(
entity_ids=entity_ids,
entity_schema=entity_schema,
@@ -32,26 +21,14 @@ class MyMaTrader(StockTrader):
need_persist=False,
adjust_type=adjust_type,
)
)
self.selectors.append(myselector)
]
class MyBullTrader(StockTrader):
def init_selectors(
def init_factors(
self, entity_ids, entity_schema, exchanges, codes, start_timestamp, end_timestamp, adjust_type=None
):
myselector = TargetSelector(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
provider="joinquant",
)
myselector.add_factor(
return [
BullFactor(
entity_ids=entity_ids,
entity_schema=entity_schema,
@@ -61,9 +38,7 @@ class MyBullTrader(StockTrader):
end_timestamp=end_timestamp,
adjust_type=adjust_type,
)
)
self.selectors.append(myselector)
]
if __name__ == "__main__":

View File

@@ -8,12 +8,12 @@ import pandas as pd
from zvt.api.trader_info_api import AccountStatsReader
from zvt.contract import IntervalLevel, TradableEntity, AdjustType
from zvt.contract.drawer import Drawer
from zvt.contract.factor import Factor, TargetType
from zvt.contract.normal_data import NormalData
from zvt.domain import Stock, AccountStats, Position
from zvt.factors.target_selector import TargetSelector
from zvt.trader import TradingSignal, TradingSignalType, TradingListener
from zvt.trader.account import SimAccountService
from zvt.utils.time_utils import to_pd_timestamp, now_pd_timestamp, to_time_str, is_same_date
from zvt.utils.time_utils import to_pd_timestamp, now_pd_timestamp, to_time_str, is_same_date, next_date
class Trader(object):
@@ -52,7 +52,7 @@ class Trader(object):
self.exchanges = exchanges
self.codes = codes
self.provider = provider
# make sure the min level selector correspond to the provider and level
# make sure the min level factor correspond to the provider and level
self.level = IntervalLevel(level)
self.real_time = real_time
self.start_timestamp = to_pd_timestamp(start_timestamp)
@@ -68,6 +68,8 @@ class Trader(object):
)
assert self.end_timestamp >= now_pd_timestamp()
# false: 收到k线时该k线已完成
# true: 收到k线时该k线可能未完成
self.kdata_use_begin_time = kdata_use_begin_time
self.draw_result = draw_result
self.rich_mode = rich_mode
@@ -80,7 +82,6 @@ class Trader(object):
self.level_map_short_targets = {}
self.trading_signals: List[TradingSignal] = []
self.trading_signal_listeners: List[TradingListener] = []
self.selectors: List[TargetSelector] = []
self.account_service = SimAccountService(
entity_schema=self.entity_schema,
@@ -95,21 +96,21 @@ class Trader(object):
self.register_trading_signal_listener(self.account_service)
self.init_selectors(
self.factors = self.init_factors(
entity_ids=self.entity_ids,
entity_schema=self.entity_schema,
exchanges=self.exchanges,
codes=self.codes,
start_timestamp=self.start_timestamp,
start_timestamp=next_date(self.start_timestamp, -365),
end_timestamp=self.end_timestamp,
adjust_type=self.adjust_type,
)
if self.selectors:
self.trading_level_asc = list(set([IntervalLevel(selector.level) for selector in self.selectors]))
if self.factors:
self.trading_level_asc = list(set([IntervalLevel(selector.level) for selector in self.factors]))
self.trading_level_asc.sort()
self.logger.info(f"trader level:{self.level},selectors level:{self.trading_level_asc}")
self.logger.info(f"trader level:{self.level},factors level:{self.trading_level_asc}")
if self.level != self.trading_level_asc[0]:
raise Exception("trader level should be the min of the selectors")
@@ -117,20 +118,19 @@ class Trader(object):
self.trading_level_desc = list(self.trading_level_asc)
self.trading_level_desc.reverse()
# run selectors for history data at first
for selector in self.selectors:
selector.run()
self.on_init()
self.on_start()
def on_start(self):
def on_init(self):
self.logger.info(f"trader:{self.trader_name} on_start")
def init_selectors(
def on_refresh_entities(self, timestamp):
self.logger.info(f"trader: {self.trader_name} timestamp: {timestamp} on_refresh_entities")
def init_factors(
self, entity_ids, entity_schema, exchanges, codes, start_timestamp, end_timestamp, adjust_type=None
):
"""
overwrite it to init selectors if you want to use selector/factor computing model
overwrite it to init factors if you want to use factor computing model
:param adjust_type:
"""
@@ -316,14 +316,14 @@ class Trader(object):
drawer.draw_line(show=True)
def on_targets_filtered(
self, timestamp, level, selector: TargetSelector, long_targets: List[str], short_targets: List[str]
self, timestamp, level, factor: Factor, long_targets: List[str], short_targets: List[str]
) -> Tuple[List[str], List[str]]:
"""
overwrite it to filter the targets from selector
:param timestamp: the event time
:param level: the level
:param selector: the selector
:param factor: the factor
:param long_targets: the long targets from the selector
:param short_targets: the short targets from the selector
:return: filtered long targets, filtered short targets
@@ -385,6 +385,9 @@ class Trader(object):
if not self.in_trading_date(timestamp=timestamp):
continue
for factor in self.factors:
factor.update_entities(entity_ids=self.entity_ids)
waiting_seconds = 0
if self.level == IntervalLevel.LEVEL_1DAY:
@@ -398,7 +401,7 @@ class Trader(object):
break
elif self.real_time:
# all selector move on to handle the coming data
# all factor move on to handle the coming data
if self.kdata_use_begin_time:
real_end_timestamp = timestamp + pd.Timedelta(seconds=self.level.to_second())
else:
@@ -409,25 +412,25 @@ class Trader(object):
# meaning the future kdata not ready yet,we could move on to check
if waiting_seconds > 0:
# iterate the selector from min to max which in finished timestamp kdata
# iterate the factor from min to max which in finished timestamp kdata
for level in self.trading_level_asc:
if self.entity_schema.is_finished_kdata_timestamp(timestamp=timestamp, level=level):
for selector in self.selectors:
if selector.level == level:
selector.move_on(timestamp, self.kdata_use_begin_time, timeout=waiting_seconds + 20)
for factor in self.factors:
if factor.level == level:
factor.move_on(to_timestamp=timestamp, timeout=waiting_seconds + 20)
# on_trading_open to setup the account
# on_trading_open to set the account
if self.level >= IntervalLevel.LEVEL_1DAY or (
self.level != IntervalLevel.LEVEL_1DAY and self.entity_schema.is_open_timestamp(timestamp)
):
self.on_trading_open(timestamp)
self.on_trading_open(timestamp=timestamp)
self.on_time(timestamp=timestamp)
# 一般来说selector(factors)计算 多标的 历史数据比较快,多级别的计算也比较方便,常用于全市场标的粗过滤
# 一般来说factor计算 多标的 历史数据比较快,多级别的计算也比较方便,常用于全市场标的粗过滤
# 更细节的控制可以在on_targets_filtered里进一步处理
# 也可以在on_time里面设计一些自己的逻辑配合过滤
if self.selectors:
if self.factors:
# 多级别的遍历算法要点:
# 1)计算各级别的 标的,通过 on_targets_filtered 过滤缓存在level_map_long_targetslevel_map_short_targets
# 2)在最小的level通过 on_targets_selected_from_levels 根据多级别的缓存标的,生成最终的选中标的
@@ -439,16 +442,16 @@ class Trader(object):
all_short_targets = []
# 从该level的selector中过滤targets
for selector in self.selectors:
if selector.level == level:
long_targets = selector.get_open_long_targets(timestamp=timestamp)
short_targets = selector.get_open_short_targets(timestamp=timestamp)
for factor in self.factors:
if factor.level == level:
long_targets = factor.get_targets(timestamp=timestamp, target_type=TargetType.positive)
short_targets = factor.get_targets(timestamp=timestamp, target_type=TargetType.negative)
if long_targets or short_targets:
long_targets, short_targets = self.on_targets_filtered(
timestamp=timestamp,
level=level,
selector=selector,
factor=factor,
long_targets=long_targets,
short_targets=short_targets,
)
@@ -501,6 +504,7 @@ class Trader(object):
self.level != IntervalLevel.LEVEL_1DAY and self.entity_schema.is_close_timestamp(timestamp)
):
self.on_trading_close(timestamp)
self.on_refresh_entities(timestamp=timestamp)
self.on_finish(timestamp)

View File

@@ -1,2 +1,18 @@
# -*- coding: utf-8 -*-
# Plotly is enough for research, so remove main dash UI
import os
import dash
import dash_bootstrap_components as dbc
assets_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "assets"))
zvt_app = dash.Dash(
__name__,
meta_tags=[{"name": "viewport", "content": "width=device-width"}],
assets_folder=assets_path,
external_stylesheets=[dbc.themes.BOOTSTRAP],
)
zvt_app.config.suppress_callback_exceptions = True
server = zvt_app.server

View File

@@ -0,0 +1 @@
# -*- coding: utf-8 -*-

View File

@@ -0,0 +1,341 @@
# -*- coding: utf-8 -*-
from typing import List
from dash import dcc
import dash_daq as daq
from dash import html
from dash import dash
from dash.dependencies import Input, Output, State
from zvt.api.trader_info_api import AccountStatsReader, OrderReader, get_order_securities
from zvt.api.trader_info_api import get_trader_info
from zvt.contract import Mixin
from zvt.contract import zvt_context, IntervalLevel
from zvt.contract.api import get_entities, get_schema_by_name, get_schema_columns
from zvt.contract.drawer import StackedDrawer
from zvt.domain import TraderInfo
from zvt.ui import zvt_app
from zvt.ui.components.dcc_components import get_account_stats_figure
from zvt.utils import pd_is_not_null
account_readers = []
order_readers = []
# init the data
traders: List[TraderInfo] = []
trader_names: List[str] = []
def order_type_flag(order_type):
if order_type == "order_long" or order_type == "order_close_short":
return "B"
else:
return "S"
def order_type_color(order_type):
if order_type == "order_long" or order_type == "order_close_short":
return "#ec0000"
else:
return "#00da3c"
def load_traders():
global traders
global trader_names
traders = get_trader_info(return_type="domain")
account_readers.clear()
order_readers.clear()
for trader in traders:
account_readers.append(AccountStatsReader(level=trader.level, trader_names=[trader.trader_name]))
order_readers.append(
OrderReader(start_timestamp=trader.start_timestamp, level=trader.level, trader_names=[trader.trader_name])
)
trader_names = [item.trader_name for item in traders]
load_traders()
def factor_layout():
layout = html.Div(
[
# controls
html.Div(
className="three columns card",
children=[
html.Div(
className="bg-white user-control",
children=[
html.Div(
className="padding-top-bot",
children=[
html.H6("select trader:"),
dcc.Dropdown(
id="trader-selector",
placeholder="select the trader",
options=[{"label": item, "value": i} for i, item in enumerate(trader_names)],
),
],
),
# select entity type
html.Div(
className="padding-top-bot",
children=[
html.H6("select entity type:"),
dcc.Dropdown(
id="entity-type-selector",
placeholder="select entity type",
options=[
{"label": name, "value": name}
for name in zvt_context.tradable_schema_map.keys()
],
value="stock",
clearable=False,
),
],
),
# select entity provider
html.Div(
className="padding-top-bot",
children=[
html.H6("select entity provider:"),
dcc.Dropdown(id="entity-provider-selector", placeholder="select entity provider"),
],
),
# select entity
html.Div(
className="padding-top-bot",
children=[
html.H6("select entity:"),
dcc.Dropdown(id="entity-selector", placeholder="select entity"),
],
),
# select levels
html.Div(
className="padding-top-bot",
children=[
html.H6("select levels:"),
dcc.Dropdown(
id="levels-selector",
options=[
{"label": level.name, "value": level.value}
for level in (IntervalLevel.LEVEL_1WEEK, IntervalLevel.LEVEL_1DAY)
],
value="1d",
multi=True,
),
],
),
# select factor
html.Div(
className="padding-top-bot",
children=[
html.H6("select factor:"),
dcc.Dropdown(
id="factor-selector",
placeholder="select factor",
options=[
{"label": name, "value": name}
for name in zvt_context.factor_cls_registry.keys()
],
value="TechnicalFactor",
),
],
),
# select data
html.Div(
children=[
html.Div(
[
html.H6(
"related/all data to show in sub graph",
style={"display": "inline-block"},
),
daq.BooleanSwitch(
id="data-switch",
on=True,
style={
"display": "inline-block",
"float": "right",
"vertical-align": "middle",
"padding": "8px",
},
),
],
),
dcc.Dropdown(id="data-selector", placeholder="schema"),
],
style={"padding-top": "12px"},
),
# select properties
html.Div(
children=[dcc.Dropdown(id="schema-column-selector", placeholder="properties")],
style={"padding-top": "6px"},
),
],
)
],
),
# Graph
html.Div(
className="nine columns card-left",
children=[
html.Div(
id="trader-details",
className="bg-white",
),
html.Div(id="factor-details"),
],
),
]
)
return layout
@zvt_app.callback(
[
Output("trader-details", "children"),
Output("entity-type-selector", "options"),
Output("entity-provider-selector", "options"),
Output("entity-selector", "options"),
],
[
Input("trader-selector", "value"),
Input("entity-type-selector", "value"),
Input("entity-provider-selector", "value"),
],
)
def update_trader_details(trader_index, entity_type, entity_provider):
if trader_index is not None:
# change entity_type options
entity_type = traders[trader_index].entity_type
if not entity_type:
entity_type = "stock"
entity_type_options = [{"label": entity_type, "value": entity_type}]
# account stats
account_stats = get_account_stats_figure(account_stats_reader=account_readers[trader_index])
providers = zvt_context.tradable_schema_map.get(entity_type).providers
entity_provider_options = [{"label": name, "value": name} for name in providers]
# entities
entity_ids = get_order_securities(trader_name=trader_names[trader_index])
df = get_entities(
provider=entity_provider,
entity_type=entity_type,
entity_ids=entity_ids,
columns=["entity_id", "code", "name"],
index="entity_id",
)
entity_options = [
{"label": f'{entity_id}({entity["name"]})', "value": entity_id} for entity_id, entity in df.iterrows()
]
return account_stats, entity_type_options, entity_provider_options, entity_options
else:
entity_type_options = [{"label": name, "value": name} for name in zvt_context.tradable_schema_map.keys()]
account_stats = None
providers = zvt_context.tradable_schema_map.get(entity_type).providers
entity_provider_options = [{"label": name, "value": name} for name in providers]
df = get_entities(
provider=entity_provider, entity_type=entity_type, columns=["entity_id", "code", "name"], index="entity_id"
)
entity_options = [
{"label": f'{entity_id}({entity["name"]})', "value": entity_id} for entity_id, entity in df.iterrows()
]
return account_stats, entity_type_options, entity_provider_options, entity_options
@zvt_app.callback(
Output("data-selector", "options"), [Input("entity-type-selector", "value"), Input("data-switch", "on")]
)
def update_entity_selector(entity_type, related):
if entity_type is not None:
if related:
schemas = zvt_context.entity_map_schemas.get(entity_type)
else:
schemas = zvt_context.schemas
return [{"label": schema.__name__, "value": schema.__name__} for schema in schemas]
raise dash.PreventUpdate()
@zvt_app.callback(Output("schema-column-selector", "options"), [Input("data-selector", "value")])
def update_column_selector(schema_name):
if schema_name:
schema = get_schema_by_name(name=schema_name)
cols = get_schema_columns(schema=schema)
return [{"label": col, "value": col} for col in cols]
raise dash.PreventUpdate()
@zvt_app.callback(
Output("factor-details", "children"),
[
Input("factor-selector", "value"),
Input("entity-type-selector", "value"),
Input("entity-selector", "value"),
Input("levels-selector", "value"),
Input("schema-column-selector", "value"),
],
[State("trader-selector", "value"), State("data-selector", "value")],
)
def update_factor_details(factor, entity_type, entity, levels, columns, trader_index, schema_name):
if factor and entity_type and entity and levels:
sub_df = None
# add sub graph
if columns:
if type(columns) == str:
columns = [columns]
columns = columns + ["entity_id", "timestamp"]
schema: Mixin = get_schema_by_name(name=schema_name)
sub_df = schema.query_data(entity_id=entity, columns=columns)
# add trading signals as annotation
annotation_df = None
if trader_index is not None:
order_reader = order_readers[trader_index]
annotation_df = order_reader.data_df.copy()
annotation_df = annotation_df[annotation_df.entity_id == entity].copy()
if pd_is_not_null(annotation_df):
annotation_df["value"] = annotation_df["order_price"]
annotation_df["flag"] = annotation_df["order_type"].apply(lambda x: order_type_flag(x))
annotation_df["color"] = annotation_df["order_type"].apply(lambda x: order_type_color(x))
print(annotation_df.tail())
if type(levels) is list and len(levels) >= 2:
levels.sort()
drawers = []
for level in levels:
drawers.append(
zvt_context.factor_cls_registry[factor](
entity_schema=zvt_context.tradable_schema_map[entity_type], level=level, entity_ids=[entity]
).drawer()
)
stacked = StackedDrawer(*drawers)
return dcc.Graph(id=f"{factor}-{entity_type}-{entity}", figure=stacked.draw_kline(show=False, height=900))
else:
if type(levels) is list:
level = levels[0]
else:
level = levels
drawer = zvt_context.factor_cls_registry[factor](
entity_schema=zvt_context.tradable_schema_map[entity_type],
level=level,
entity_ids=[entity],
need_persist=False,
).drawer()
if pd_is_not_null(sub_df):
drawer.add_sub_df(sub_df)
if pd_is_not_null(annotation_df):
drawer.annotation_df = annotation_df
return dcc.Graph(id=f"{factor}-{entity_type}-{entity}", figure=drawer.draw_kline(show=False, height=800))
raise dash.PreventUpdate()

View File

@@ -0,0 +1 @@
# -*- coding: utf-8 -*-

414
src/zvt/ui/assets/base.css Normal file
View File

@@ -0,0 +1,414 @@
/* Table of contents
- Plotly.js
- Grid
- Base Styles
- Typography
- Links
- Buttons
- Forms
- Lists
- Code
- Tables
- Spacing
- Utilities
- Clearing
- Media Queries
*/
/* PLotly.js
*/
/* plotly.js's modebar's z-index is 1001 by default
* https://github.com/plotly/plotly.js/blob/7e4d8ab164258f6bd48be56589dacd9bdd7fded2/src/css/_modebar.scss#L5
* In case a dropdown is above the graph, the dropdown's options
* will be rendered below the modebar
* Increase the select option's z-index
*/
/* This was actually not quite right -
dropdowns were overlapping each other (edited October 26)
.Select {
z-index: 1002;
}*/
/* Grid
*/
.container {
position: relative;
width: 100%;
max-width: 960px;
margin: 0 auto;
padding: 0 20px;
box-sizing: border-box; }
.column,
.columns {
width: 100%;
float: left;
box-sizing: border-box; }
/* For devices larger than 400px */
@media (min-width: 400px) {
.container {
width: 85%;
padding: 0; }
}
/* For devices larger than 550px */
@media (min-width: 550px) {
.container {
width: 80%; }
.column,
.columns {
margin-left: 4%; }
.column:first-child,
.columns:first-child {
margin-left: 0; }
.one.column,
.one.columns { width: 4.66666666667%; }
.two.columns { width: 13.3333333333%; }
.three.columns { width: 22%; }
.four.columns { width: 30.6666666667%; }
.five.columns { width: 39.3333333333%; }
.six.columns { width: 48%; }
.seven.columns { width: 56.6666666667%; }
.eight.columns { width: 65.3333333333%; }
.nine.columns { width: 74.0%; }
.ten.columns { width: 82.6666666667%; }
.eleven.columns { width: 91.3333333333%; }
.twelve.columns { width: 100%; margin-left: 0; }
.one-third.column { width: 30.6666666667%; }
.two-thirds.column { width: 65.3333333333%; }
.one-half.column { width: 48%; }
/* Offsets */
.offset-by-one.column,
.offset-by-one.columns { margin-left: 8.66666666667%; }
.offset-by-two.column,
.offset-by-two.columns { margin-left: 17.3333333333%; }
.offset-by-three.column,
.offset-by-three.columns { margin-left: 26%; }
.offset-by-four.column,
.offset-by-four.columns { margin-left: 34.6666666667%; }
.offset-by-five.column,
.offset-by-five.columns { margin-left: 43.3333333333%; }
.offset-by-six.column,
.offset-by-six.columns { margin-left: 52%; }
.offset-by-seven.column,
.offset-by-seven.columns { margin-left: 60.6666666667%; }
.offset-by-eight.column,
.offset-by-eight.columns { margin-left: 69.3333333333%; }
.offset-by-nine.column,
.offset-by-nine.columns { margin-left: 78.0%; }
.offset-by-ten.column,
.offset-by-ten.columns { margin-left: 86.6666666667%; }
.offset-by-eleven.column,
.offset-by-eleven.columns { margin-left: 95.3333333333%; }
.offset-by-one-third.column,
.offset-by-one-third.columns { margin-left: 34.6666666667%; }
.offset-by-two-thirds.column,
.offset-by-two-thirds.columns { margin-left: 69.3333333333%; }
.offset-by-one-half.column,
.offset-by-one-half.columns { margin-left: 52%; }
}
/* Base Styles
*/
/* NOTE
html is set to 62.5% so that all the REM measurements throughout Skeleton
are based on 10px sizing. So basically 1.5rem = 15px :) */
html {
font-size: 62.5%; }
body {
font-size: 1.5em; /* currently ems cause chrome bug misinterpreting rems on body element */
line-height: 1.6;
font-weight: 400;
font-family: "Open Sans", "HelveticaNeue", "Helvetica Neue", Helvetica, Arial, sans-serif;
color: rgb(50, 50, 50); }
/* Typography
*/
h1, h2, h3, h4, h5, h6 {
margin-top: 0;
margin-bottom: 0;
font-weight: 300; }
h1 { font-size: 4.5rem; line-height: 1.2; letter-spacing: -.1rem; margin-bottom: 2rem; }
h2 { font-size: 3.6rem; line-height: 1.25; letter-spacing: -.1rem; margin-bottom: 1.8rem; margin-top: 1.8rem;}
h3 { font-size: 3.0rem; line-height: 1.3; letter-spacing: -.1rem; margin-bottom: 1.5rem; margin-top: 1.5rem;}
h4 { font-size: 2.6rem; line-height: 1.35; letter-spacing: -.08rem; margin-bottom: 1.2rem; margin-top: 1.2rem;}
h5 { font-size: 2.2rem; line-height: 1.5; letter-spacing: -.05rem; margin-bottom: 0.6rem; margin-top: 0.6rem;}
h6 { font-size: 2.0rem; line-height: 1.6; letter-spacing: 0; margin-bottom: 0.75rem; margin-top: 0.75rem;}
p {
margin-top: 0; }
/* Blockquotes
*/
blockquote {
border-left: 4px lightgrey solid;
padding-left: 1rem;
margin-top: 2rem;
margin-bottom: 2rem;
margin-left: 0rem;
}
/* Links
*/
a {
color: #1EAEDB;
text-decoration: underline;
cursor: pointer;}
a:hover {
color: #0FA0CE; }
/* Buttons
*/
.button,
button,
input[type="submit"],
input[type="reset"],
input[type="button"] {
display: inline-block;
height: 38px;
padding: 0 30px;
color: #555;
text-align: center;
font-size: 11px;
font-weight: 600;
line-height: 38px;
letter-spacing: .1rem;
text-transform: uppercase;
text-decoration: none;
white-space: nowrap;
background-color: transparent;
border-radius: 4px;
border: 1px solid #bbb;
cursor: pointer;
box-sizing: border-box; }
.button:hover,
button:hover,
input[type="submit"]:hover,
input[type="reset"]:hover,
input[type="button"]:hover,
.button:focus,
button:focus,
input[type="submit"]:focus,
input[type="reset"]:focus,
input[type="button"]:focus {
color: #333;
border-color: #888;
outline: 0; }
.button.button-primary,
button.button-primary,
input[type="submit"].button-primary,
input[type="reset"].button-primary,
input[type="button"].button-primary {
color: #FFF;
background-color: #33C3F0;
border-color: #33C3F0; }
.button.button-primary:hover,
button.button-primary:hover,
input[type="submit"].button-primary:hover,
input[type="reset"].button-primary:hover,
input[type="button"].button-primary:hover,
.button.button-primary:focus,
button.button-primary:focus,
input[type="submit"].button-primary:focus,
input[type="reset"].button-primary:focus,
input[type="button"].button-primary:focus {
color: #FFF;
background-color: #1EAEDB;
border-color: #1EAEDB; }
/* Forms
*/
input[type="email"],
input[type="number"],
input[type="search"],
input[type="text"],
input[type="tel"],
input[type="url"],
input[type="password"],
textarea,
select {
height: 38px;
padding: 6px 10px; /* The 6px vertically centers text on FF, ignored by Webkit */
background-color: #fff;
border: 1px solid #D1D1D1;
border-radius: 4px;
box-shadow: none;
box-sizing: border-box;
font-family: inherit;
font-size: inherit; /*https://stackoverflow.com/questions/6080413/why-doesnt-input-inherit-the-font-from-body*/}
/* Removes awkward default styles on some inputs for iOS */
input[type="email"],
input[type="number"],
input[type="search"],
input[type="text"],
input[type="tel"],
input[type="url"],
input[type="password"],
textarea {
-webkit-appearance: none;
-moz-appearance: none;
appearance: none; }
textarea {
min-height: 65px;
padding-top: 6px;
padding-bottom: 6px; }
input[type="email"]:focus,
input[type="number"]:focus,
input[type="search"]:focus,
input[type="text"]:focus,
input[type="tel"]:focus,
input[type="url"]:focus,
input[type="password"]:focus,
textarea:focus,
select:focus {
border: 1px solid #33C3F0;
outline: 0; }
label,
legend {
display: block;
margin-bottom: 0px; }
fieldset {
padding: 0;
border-width: 0; }
input[type="checkbox"],
input[type="radio"] {
display: inline; }
label > .label-body {
display: inline-block;
margin-left: .5rem;
font-weight: normal; }
/* Lists
*/
ul {
list-style: circle inside; }
ol {
list-style: decimal inside; }
ol, ul {
padding-left: 0;
margin-top: 0; }
ul ul,
ul ol,
ol ol,
ol ul {
margin: 1.5rem 0 1.5rem 3rem;
font-size: 90%; }
li {
margin-bottom: 1rem; }
/* Tables
*/
table {
border-collapse: collapse;
}
th,
td {
padding: 12px 15px;
text-align: left;
border-bottom: 1px solid #E1E1E1; }
th:first-child,
td:first-child {
padding-left: 0; }
th:last-child,
td:last-child {
padding-right: 0; }
/* Spacing
*/
button,
.button {
margin-bottom: 0rem; }
input,
textarea,
select,
fieldset {
margin-bottom: 0rem; }
pre,
dl,
figure,
table,
form {
margin-bottom: 0rem; }
p,
ul,
ol {
margin-bottom: 0.75rem; }
/* Utilities
*/
.u-full-width {
width: 100%;
box-sizing: border-box; }
.u-max-full-width {
max-width: 100%;
box-sizing: border-box; }
.u-pull-right {
float: right; }
.u-pull-left {
float: left; }
/* Misc
*/
hr {
margin-top: 3rem;
margin-bottom: 3.5rem;
border-width: 0;
border-top: 1px solid #E1E1E1; }
/* Clearing
*/
/* Self Clearing Goodness */
.container:after,
.row:after,
.u-cf {
content: "";
display: table;
clear: both; }
/* Media Queries
*/
/*
Note: The best way to structure the use of media queries is to create the queries
near the relevant code. For example, if you wanted to change the styles for buttons
on small devices, paste the mobile query code up in the buttons section and style it
there.
*/
/* Larger than mobile */
@media (min-width: 400px) {}
/* Larger than phablet (also point when grid becomes active) */
@media (min-width: 550px) {}
/* Larger than tablet */
@media (min-width: 750px) {}
/* Larger than desktop */
@media (min-width: 1000px) {}
/* Larger than Desktop HD */
@media (min-width: 1200px) {}

View File

@@ -0,0 +1,173 @@
/*Fonts */
@import url('https://fonts.googleapis.com/css?family=Roboto&display=swap');
body {
margin: 0px;
padding: 0px;
background-color: #F3F4F9;
font-family: 'Roboto';
color: #203cb3;
}
.zvt-banner {
color: #7f7f7f;
font-weight: 600;
font-size: 20px;
background: #fafbfc;
padding: 12px;
padding-left: 24px;
border-bottom: 2px solid lightgray;
}
.zvt-nav {
color: #7f7f7f;
background: #fafbfc;
padding: 12px;
padding-left: 36px;
}
.div-logo {
display: inline-block;
float: right;
}
.logo {
height: 35px;
padding: 6px;
margin-top: 3px;
}
.h2-title, .h2-title-mobile {
font-family: 'Roboto';
display: inline-block;
letter-spacing: 3.8px;
font-weight: 800;
font-size: 20px;
}
.h2-title-mobile {
display: none;
}
h5, h6 {
font-family: 'Roboto';
font-weight: 600;
font-size: 16px;
}
h5 {
padding-left: 42px;
}
.alert {
padding: 20px;
background-color: #f44336;
color: white;
}
.bg-white {
background-color: white;
padding: 24px 32px;
}
.card {
padding: 24px 12px 24px 12px;
margin-left: 4%;
}
.card-left {
padding: 24px 12px 24px 12px;
margin-left: 0px;
}
.padding-top-bot {
padding-top: 12px;
padding-bottom: 18px;
}
.upload {
width: 100%;
line-height: 60px;
border-width: 1px;
border-style: dashed;
border-radius: 5px;
text-align: center;
}
.upload p, .upload a {
display: inline;
}
.Select-control {
border: 1px solid #203cb3;
}
@media only screen and (max-width: 320px) {
.Select-menu-outer, .Select-value {
font-size: 10.5px;
}
.upload {
padding: 5px;
}
}
/* mobile */
@media only screen and (max-width: 768px) {
.upload {
line-height: 60px;
border-width: 1px;
border-style: dashed;
border-radius: 5px;
text-align: center;
font-size: small;
}
.columns {
width: 100%;
}
.card, .card-left {
padding: 12px;
margin: 0px;
}
.bg-white {
height: auto;
}
.logo {
height: 28px;
padding-left: 0px;
padding-bottom: 0px;
}
.div-logo {
float: left;
display: block;
width: 100%;
}
.h2-title {
display: none;
}
.h2-title-mobile {
display: block;
float: left;
}
.app-body {
margin-left: 0px;
}
.columns {
text-align: center;
}
.user-control {
padding-top: 24px;
padding-bottom: 24px;
}
}

View File

@@ -0,0 +1 @@
# -*- coding: utf-8 -*-

View File

@@ -0,0 +1,71 @@
# -*- coding: utf-8 -*-
from dash import dcc
from zvt.api.kdata import get_kdata_schema
from zvt.api.trader_info_api import OrderReader, AccountStatsReader
from zvt.contract.api import decode_entity_id
from zvt.contract.drawer import Drawer
from zvt.contract.reader import DataReader
from zvt.contract import zvt_context
from zvt.utils.pd_utils import pd_is_not_null
def order_type_color(order_type):
if order_type == "order_long" or order_type == "order_close_short":
return "#ec0000"
else:
return "#00da3c"
def order_type_flag(order_type):
if order_type == "order_long" or order_type == "order_close_short":
return "B"
else:
return "S"
def get_trading_signals_figure(
order_reader: OrderReader, entity_id: str, start_timestamp=None, end_timestamp=None, adjust_type=None
):
entity_type, _, _ = decode_entity_id(entity_id)
data_schema = get_kdata_schema(entity_type=entity_type, level=order_reader.level, adjust_type=adjust_type)
if not start_timestamp:
start_timestamp = order_reader.start_timestamp
if not end_timestamp:
end_timestamp = order_reader.end_timestamp
kdata_reader = DataReader(
data_schema=data_schema,
entity_schema=zvt_context.tradable_schema_map.get(entity_type),
entity_ids=[entity_id],
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
level=order_reader.level,
)
# generate the annotation df
order_reader.move_on(timeout=0)
df = order_reader.data_df.copy()
df = df[df.entity_id == entity_id].copy()
if pd_is_not_null(df):
df["value"] = df["order_price"]
df["flag"] = df["order_type"].apply(lambda x: order_type_flag(x))
df["color"] = df["order_type"].apply(lambda x: order_type_color(x))
print(df.tail())
drawer = Drawer(main_df=kdata_reader.data_df, annotation_df=df)
return drawer.draw_kline(show=False, height=800)
def get_account_stats_figure(account_stats_reader: AccountStatsReader):
graph_list = []
# 账户统计曲线
if account_stats_reader:
fig = account_stats_reader.draw_line(show=False)
for trader_name in account_stats_reader.trader_names:
graph_list.append(dcc.Graph(id="{}-account".format(trader_name), figure=fig))
return graph_list

View File

@@ -15,12 +15,12 @@ from zvt.contract import IntervalLevel
def test_china_stock_reader():
data_reader = DataReader(
codes=["002572", "000338"],
data_schema=Stock1dKdata,
entity_schema=Stock,
entity_provider="eastmoney",
codes=["002572", "000338"],
start_timestamp="2019-01-01",
end_timestamp="2019-06-10",
entity_provider="eastmoney",
)
categories = data_reader.data_df.index.levels[0].to_list()
@@ -48,12 +48,12 @@ def test_china_stock_reader():
def test_reader_move_on():
data_reader = DataReader(
codes=["002572", "000338"],
data_schema=Stock1dKdata,
entity_schema=Stock,
entity_provider="eastmoney",
codes=["002572", "000338"],
start_timestamp="2019-06-13",
end_timestamp="2019-06-14",
entity_provider="eastmoney",
)
data_reader.move_on(to_timestamp="2019-06-15")

View File

@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
from zvt.contract import IntervalLevel
from zvt.factors.ma.ma_factor import CrossMaFactor
from zvt.contract.factor import TargetType
from zvt.factors import BullFactor
from ..context import init_test_context
init_test_context()
def test_cross_ma_select_targets():
entity_ids = ["stock_sz_000338"]
start_timestamp = "2018-01-01"
end_timestamp = "2019-06-30"
factor = CrossMaFactor(
entity_ids=entity_ids,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
keep_window=10,
windows=[5, 10],
need_persist=False,
level=IntervalLevel.LEVEL_1DAY,
adjust_type="hfq",
)
assert "stock_sz_000338" in factor.get_targets(timestamp="2018-01-19")
def test_bull_select_targets():
factor = BullFactor(
start_timestamp="2019-01-01", end_timestamp="2019-06-10", level=IntervalLevel.LEVEL_1DAY, provider="joinquant"
)
targets = factor.get_targets(timestamp="2019-05-08", target_type=TargetType.positive)
assert "stock_sz_000338" not in targets
assert "stock_sz_002572" not in targets
targets = factor.get_targets("2019-05-08", target_type=TargetType.negative)
assert "stock_sz_000338" in targets
assert "stock_sz_002572" not in targets
factor.move_on(timeout=0)
targets = factor.get_targets(timestamp="2019-06-19", target_type=TargetType.positive)
assert "stock_sz_000338" in targets
assert "stock_sz_002572" not in targets

View File

@@ -1,83 +0,0 @@
# -*- coding: utf-8 -*-
from zvt.contract import IntervalLevel
from zvt.factors.target_selector import TargetSelector
from zvt.factors.ma.ma_factor import CrossMaFactor
from zvt.factors import BullFactor
from ..context import init_test_context
init_test_context()
class TechnicalSelector(TargetSelector):
def init_factors(self, entity_ids, entity_schema, exchanges, codes, start_timestamp, end_timestamp, level):
bull_factor = BullFactor(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
provider="joinquant",
level=level,
adjust_type="qfq",
)
self.factors = [bull_factor]
def test_cross_ma_selector():
entity_ids = ["stock_sz_000338"]
entity_type = "stock"
start_timestamp = "2018-01-01"
end_timestamp = "2019-06-30"
my_selector = TargetSelector(
entity_ids=entity_ids, entity_schema=entity_type, start_timestamp=start_timestamp, end_timestamp=end_timestamp
)
# add the factors
my_selector.add_factor(
CrossMaFactor(
entity_ids=entity_ids,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
computing_window=10,
windows=[5, 10],
need_persist=False,
level=IntervalLevel.LEVEL_1DAY,
adjust_type="hfq",
)
)
my_selector.run()
print(my_selector.open_long_df)
print(my_selector.open_short_df)
assert "stock_sz_000338" in my_selector.get_open_short_targets("2018-01-29")
def test_technical_selector():
selector = TechnicalSelector(
start_timestamp="2019-01-01", end_timestamp="2019-06-10", level=IntervalLevel.LEVEL_1DAY, provider="joinquant"
)
selector.run()
print(selector.get_result_df())
targets = selector.get_open_long_targets("2019-06-04")
assert "stock_sz_000338" not in targets
assert "stock_sz_000338" not in targets
assert "stock_sz_002572" not in targets
assert "stock_sz_002572" not in targets
targets = selector.get_open_short_targets("2019-06-04")
assert "stock_sz_000338" in targets
assert "stock_sz_000338" in targets
assert "stock_sz_002572" in targets
assert "stock_sz_002572" in targets
selector.move_on(timeout=0)
targets = selector.get_open_long_targets("2019-06-19")
assert "stock_sz_000338" in targets
assert "stock_sz_002572" not in targets

View File

@@ -16,7 +16,7 @@ def test_ma():
start_timestamp="2019-01-01",
end_timestamp="2019-06-10",
level=IntervalLevel.LEVEL_1DAY,
computing_window=30,
keep_window=30,
transformer=MaTransformer(windows=[5, 10, 30]),
adjust_type="qfq",
)
@@ -49,7 +49,7 @@ def test_macd():
start_timestamp="2019-01-01",
end_timestamp="2019-06-10",
level=IntervalLevel.LEVEL_1DAY,
computing_window=None,
keep_window=None,
transformer=MacdTransformer(),
adjust_type="qfq",
)