#-*- coding: utf-8 -*-

import time
from urllib.parse import urlencode

import requests

class staticConfig:
    """Static settings for the WorldQuant BRAIN API client.

    ``host`` is the API base URL; the remaining string attributes are
    endpoint paths that the client appends to it (``get_alpha_detail`` is a
    ``str.format`` template that takes an alpha id). ``simulate_type`` and
    ``simulate_settings`` make up the default simulation request.
    """

    # --- API base URL ---
    host = "https://api.worldquantbrain.com"

    # --- endpoint paths, relative to ``host`` ---
    login = "/authentication"
    simulation = "/simulations"
    get_data_fields = "/data-fields"
    get_user_alpha_list = '/users/self/alphas'
    get_alpha_detail = '/alphas/{}'  # fill in with an alpha id
    add_tag = '/tags'

    # --- default simulation request ---
    simulate_type = 'REGULAR'
    # NOTE: this is one class-level dict shared by every instance; mutating
    # it changes the defaults seen by all clients.
    simulate_settings = dict(
        instrumentType='EQUITY',
        region='USA',
        universe='TOP3000',
        delay=1,
        decay=0,
        neutralization='INDUSTRY',
        truncation=0.08,
        pasteurization='ON',
        unitHandling='VERIFY',
        nanHandling='OFF',
        language='FASTEXPR',
        visualization=False,
    )

class BrainClient:
    """Minimal client for the WorldQuant BRAIN REST API.

    Authenticates with HTTP basic auth on construction and keeps the
    resulting ``requests.Session``. Every GET/POST/PATCH goes through
    ``_request``, which logs in again and retries once when the server
    answers 401, and raises ``Exception`` for any final status >= 400.
    """

    def __init__(self, username: str, password: str):
        self.username = username
        self.password = password
        self.conf = staticConfig()
        # NOTE: logging in here means constructing a client performs a network call.
        self.sess = self.login()

    def login(self) -> "requests.Session":
        """Open a new session and authenticate against the login endpoint.

        Returns:
            A session carrying the basic-auth credentials, after a successful
            POST to the login endpoint.

        Raises:
            Exception: if the login endpoint does not answer 201.
        """
        sess = requests.Session()
        sess.auth = (self.username, self.password)
        login_url = self.conf.host + self.conf.login
        response = sess.post(login_url)
        if response.status_code != 201:
            print("login response: ", response.text)
            sess.close()  # don't leak the connection pool of a failed login
            raise Exception("Login failed")
        return sess

    def _request(self, method: str, url: str, body: dict = None) -> "requests.Response":
        """Send an HTTP request, re-authenticating and retrying once on 401.

        Args:
            method: upper-case HTTP verb ('GET', 'POST', 'PATCH').
            url: absolute request URL.
            body: optional payload, sent as JSON.

        Returns:
            The response; its status code is always < 400.

        Raises:
            Exception: if the final response -- including the retry made after
                a re-login -- has a status code >= 400.
        """
        resp = self.sess.request(method, url, json=body)
        if resp.status_code == 401:
            # Presumably the session expired: log in again and retry once.
            # Log in first so a failed login leaves the current session intact.
            new_sess = self.login()
            self.sess.close()
            self.sess = new_sess
            resp = self.sess.request(method, url, json=body)
        # Checked after the retry as well, so a failed retry is never handed
        # back to the caller as if it had succeeded.
        if resp.status_code >= 400:
            raise Exception(
                f'{method.capitalize()} {url} failed, status_code = {resp.status_code}, '
                f'response_body = {resp.text}'
            )
        return resp

    def post(self, url: str, body: dict = None) -> "requests.Response":
        """POST ``body`` as JSON to ``url`` (error handling as in ``_request``)."""
        return self._request('POST', url, body)

    def get(self, url: str) -> "requests.Response":
        """GET ``url`` (error handling as in ``_request``)."""
        return self._request('GET', url)

    def patch(self, url: str, body: dict = None) -> "requests.Response":
        """PATCH ``url`` with ``body`` as JSON (error handling as in ``_request``)."""
        return self._request('PATCH', url, body)

    # Example payload of a finished simulation, as returned by simulate_fastexpr:
    # {'id': '3bHSZMfL74VKcq0S8AfLmxT', 'type': 'REGULAR', 'settings': {'instrumentType': 'EQUITY', 'region': 'USA', 'universe': 'TOP3000', 'delay': 1, 'decay': 0, 'neutralization': 'INDUSTRY', 'truncation': 0.08, 'pasteurization': 'ON', 'unitHandling': 'VERIFY', 'nanHandling': 'OFF', 'language': 'FASTEXPR', 'visualization': False}, 'regular': 'liabilities/assets', 'status': 'COMPLETE', 'alpha': 'ljwPdZl'}
    def simulate_fastexpr(self, regular: str, timeout: float = None) -> dict:
        """Submit a simulation of expression ``regular`` and poll until done.

        Args:
            regular: the expression to simulate.
            timeout: optional cap, in seconds, on total polling time. ``None``
                (the default) keeps polling for as long as the server returns
                a non-zero ``Retry-After`` header.

        Returns:
            The decoded JSON body of the final poll response.

        Raises:
            Exception: if the submission does not answer 201, or any request fails.
            TimeoutError: if waiting for the next poll would exceed ``timeout``.
        """
        body = {
            'type': self.conf.simulate_type,
            # Copy so the request body never aliases the shared config dict.
            'settings': dict(self.conf.simulate_settings),
            'regular': regular,
        }

        simulate_url = self.conf.host + self.conf.simulation
        response = self.post(simulate_url, body)
        if response.status_code != 201:
            print("simulate response: ", response.text)
            raise Exception("Simulate failed")

        # The URL to poll for this simulation comes back in the Location header.
        location = response.headers['Location']
        deadline = None if timeout is None else time.monotonic() + timeout

        while True:
            r = self.get(location)
            # A non-zero Retry-After is treated as "not finished yet: wait and poll again".
            retry_after_sec = float(r.headers.get("Retry-After", 0))
            if retry_after_sec == 0:
                break
            if deadline is not None and time.monotonic() + retry_after_sec > deadline:
                raise TimeoutError(
                    f'Simulation {location} still running after {timeout} seconds'
                )
            time.sleep(retry_after_sec)

        return r.json()

    def get_data_fields(self, dataset_id: str, offset: int = 0, limit: int = 50):
        """Return one page of data fields of ``dataset_id`` as decoded JSON.

        The instrument type, region, universe and delay filters are taken from
        the default simulation settings.
        """
        settings = self.conf.simulate_settings
        # urlencode escapes values that would otherwise corrupt the query string.
        query = urlencode({
            'instrumentType': settings['instrumentType'],
            'region': settings['region'],
            'universe': settings['universe'],
            'delay': str(settings['delay']),
            'dataset.id': dataset_id,
            'offset': offset,
            'limit': limit,
        })
        base_url = f'{self.conf.host}{self.conf.get_data_fields}?{query}'
        return self.get(base_url).json()

    def get_alpha_detail(self, alpha_id: str):
        """Return the decoded JSON detail of alpha ``alpha_id``."""
        path = self.conf.get_alpha_detail.format(alpha_id)
        base_url = f'{self.conf.host}{path}'
        return self.get(base_url).json()

    def add_tag(self, tag_name: str, *alphas: str):
        """Create a LIST tag named ``tag_name`` containing the given alpha ids.

        Returns:
            The decoded JSON response.
        """
        body = {
            'alphas': list(alphas),
            'name': tag_name,
            'type': 'LIST',
        }
        base_url = f'{self.conf.host}{self.conf.add_tag}'
        return self.post(base_url, body).json()

    def set_favorite(self, alpha_id: str, favorite: bool = True):
        """Set (or clear, with ``favorite=False``) the favorite flag of an alpha.

        Returns:
            The decoded JSON response of the PATCH.
        """
        body = {
            'favorite': favorite,
        }
        path = self.conf.get_alpha_detail.format(alpha_id)
        base_url = f'{self.conf.host}{path}'
        return self.patch(base_url, body).json()

    def get_user_unsubmited_alphas(self, offset: int = 0, limit: int = 50):
        """Return one page of the user's UNSUBMITTED, non-hidden alphas as JSON.

        NOTE: the misspelled method name is kept for backward compatibility.
        """
        query = urlencode({
            'status': 'UNSUBMITTED',
            'hidden': 'false',
            'offset': offset,
            'limit': limit,
        })
        base_url = f'{self.conf.host}{self.conf.get_user_alpha_list}?{query}'
        return self.get(base_url).json()