"""
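This module provides Redis-backed storages for the finite-state machine: the deprecated `RedisStorage` (based on the ``aioredis`` driver) and `RedisStorage2` (based on the `redis <https://pypi.org/project/redis/>`_ driver).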
"""

import asyncio
import logging
import typing

from ...dispatcher.storage import BaseStorage
from ...utils import json
from ...utils.deprecated import deprecated

if typing.TYPE_CHECKING:
    import aioredis

# Last component of the per-(chat, user) keys built by RedisStorage2.generate_key:
# one key each for the state, the state data and the bucket.
STATE_KEY, STATE_DATA_KEY, STATE_BUCKET_KEY = 'state', 'data', 'bucket'


class RedisStorage(BaseStorage):
    """
    Simple Redis-based storage for FSM built on the ``aioredis`` driver.

    Every (chat, user) pair is kept as a single JSON document under the key
    ``fsm:{chat}:{user}`` with the fields ``state``, ``data`` and ``bucket``.

    Usage:

    .. code-block:: python3

        storage = RedisStorage('localhost', 6379, db=5)
        dp = Dispatcher(bot, storage=storage)

    And need to close Redis connection when shutdown

    .. code-block:: python3

        await dp.storage.close()
        await dp.storage.wait_closed()

    """

    @deprecated("`RedisStorage` will be removed in aiogram v3.0. "
                "Use `RedisStorage2` instead.", stacklevel=3)
    def __init__(self, host='localhost', port=6379, db=None, password=None, ssl=None, loop=None, **kwargs):
        # Only the connection parameters are stored here; the connection itself is
        # opened lazily by `redis()`. `loop` is accepted but not used.
        self._host = host
        self._port = port
        self._db = db
        self._password = password
        self._ssl = ssl
        self._kwargs = kwargs

        self._redis: typing.Optional["aioredis.RedisConnection"] = None
        # Serializes (re)connecting and closing of the shared connection.
        self._connection_lock = asyncio.Lock()

    async def close(self):
        """
        Close the Redis connection if it is open.
        """
        async with self._connection_lock:
            if self._redis and not self._redis.closed:
                self._redis.close()

    async def wait_closed(self):
        """
        Wait until the Redis connection is closed.

        :return: result of the driver's ``wait_closed()``, or True if no connection was ever opened
        """
        async with self._connection_lock:
            if self._redis:
                return await self._redis.wait_closed()
            return True

    async def redis(self) -> "aioredis.RedisConnection":
        """
        Get Redis connection, (re)connecting if there is none or it was closed.
        """
        # Use thread-safe asyncio Lock because this method without that is not safe
        import aioredis

        async with self._connection_lock:
            if self._redis is None or self._redis.closed:
                self._redis = await aioredis.create_connection((self._host, self._port),
                                                               db=self._db, password=self._password, ssl=self._ssl,
                                                               **self._kwargs)
        return self._redis

    async def get_record(self, *,
                         chat: typing.Union[str, int, None] = None,
                         user: typing.Union[str, int, None] = None) -> typing.Dict:
        """
        Get record from storage

        :param chat: chat id
        :param user: user id
        :return: decoded JSON record; ``{'state': None, 'data': {}}`` when the key does not exist
        """
        chat, user = self.check_address(chat=chat, user=user)
        addr = f"fsm:{chat}:{user}"

        conn = await self.redis()
        data = await conn.execute('GET', addr)
        if data is None:
            return {'state': None, 'data': {}}
        return json.loads(data)

    async def set_record(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                         state=None, data=None, bucket=None):
        """
        Write record to storage.

        The key is deleted instead of written when there is no state and both
        data and bucket are empty.

        :param bucket: bucket dict (None is stored as {})
        :param chat: chat id
        :param user: user id
        :param state: state value
        :param data: data dict (None is stored as {})
        :return:
        """
        if data is None:
            data = {}
        if bucket is None:
            bucket = {}

        chat, user = self.check_address(chat=chat, user=user)
        addr = f"fsm:{chat}:{user}"

        conn = await self.redis()
        if state is None and data == bucket == {}:
            await conn.execute('DEL', addr)
        else:
            record = {'state': state, 'data': data, 'bucket': bucket}
            await conn.execute('SET', addr, json.dumps(record))

    async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                        default: typing.Optional[str] = None) -> typing.Optional[str]:
        """
        Get the stored state, or the resolved ``default`` when none is stored.
        """
        record = await self.get_record(chat=chat, user=user)
        state = record.get('state')
        # A missing record comes back as {'state': None, ...}, so the fallback
        # must be applied explicitly rather than via dict.get's default.
        if state is None:
            return self.resolve_state(default)
        return state

    async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                       default: typing.Optional[dict] = None) -> typing.Dict:
        """
        Get the stored data dict, or ``default`` (``{}`` if not given) when it is empty.
        """
        record = await self.get_record(chat=chat, user=user)
        return record.get('data') or default or {}

    async def set_state(self, *,
                        chat: typing.Union[str, int, None] = None,
                        user: typing.Union[str, int, None] = None,
                        state: typing.Optional[typing.AnyStr] = None):
        """
        Replace the state while keeping the stored data and bucket.
        """
        record = await self.get_record(chat=chat, user=user)
        state = self.resolve_state(state)
        # Forward the bucket too, otherwise set_record would reset it to {}.
        await self.set_record(chat=chat, user=user, state=state,
                              data=record['data'], bucket=record.get('bucket'))

    async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                       data: typing.Dict = None):
        """
        Replace the data dict while keeping the stored state and bucket.
        """
        record = await self.get_record(chat=chat, user=user)
        await self.set_record(chat=chat, user=user, state=record['state'],
                              data=data, bucket=record.get('bucket'))

    async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                          data: typing.Dict = None, **kwargs):
        """
        Merge ``data`` and ``kwargs`` into the stored data dict, keeping state and bucket.
        """
        if data is None:
            data = {}
        record = await self.get_record(chat=chat, user=user)
        record_data = record.get('data', {})
        record_data.update(data, **kwargs)
        await self.set_record(chat=chat, user=user, state=record['state'],
                              data=record_data, bucket=record.get('bucket'))

    async def get_states_list(self) -> typing.List[typing.Tuple[str, str]]:
        """
        Get list of all stored chat's and user's

        :return: list of tuples where first element is chat id and second is user id
        """
        conn = await self.redis()
        result = []

        keys = await conn.execute('KEYS', 'fsm:*')
        for item in keys:
            *_, chat, user = item.decode('utf-8').split(':')
            result.append((chat, user))

        return result

    async def reset_all(self, full=True):
        """
        Reset states in DB

        :param full: clean DB or clean only states
        :return:
        """
        conn = await self.redis()

        if full:
            await conn.execute('FLUSHDB')
        else:
            keys = await conn.execute('KEYS', 'fsm:*')
            # DEL needs at least one key; calling it with none is a Redis error.
            if keys:
                await conn.execute('DEL', *keys)

    def has_bucket(self):
        return True

    async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                         default: typing.Optional[dict] = None) -> typing.Dict:
        """
        Get the stored bucket, or ``default`` (``{}`` if not given) when it is empty.
        """
        record = await self.get_record(chat=chat, user=user)
        return record.get('bucket') or default or {}

    async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                         bucket: typing.Dict = None):
        """
        Replace the bucket while keeping the stored state and data.
        """
        record = await self.get_record(chat=chat, user=user)
        await self.set_record(chat=chat, user=user, state=record['state'], data=record['data'], bucket=bucket)

    async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
                            user: typing.Union[str, int, None] = None,
                            bucket: typing.Dict = None, **kwargs):
        """
        Merge ``bucket`` and ``kwargs`` into the stored bucket, keeping state and data.
        """
        record = await self.get_record(chat=chat, user=user)
        record_bucket = record.get('bucket', {})
        if bucket is None:
            bucket = {}
        record_bucket.update(bucket, **kwargs)
        # The merged bucket goes to `bucket`; the stored data is left untouched.
        await self.set_record(chat=chat, user=user, state=record['state'],
                              data=record['data'], bucket=record_bucket)


class RedisStorage2(BaseStorage):
    """
    Improved Redis-based storage for FSM.
    Works with Redis connection pool and customizable keys prefix.

    State, data and bucket of each (chat, user) pair are stored under separate
    keys of the form ``{prefix}:{chat}:{user}:{state|data|bucket}``, each with
    its own optional TTL.

    Usage:

    .. code-block:: python3

        storage = RedisStorage2('localhost', 6379, db=5, pool_size=10, prefix='my_fsm_key')
        dp = Dispatcher(bot, storage=storage)

    And need to close Redis connection when shutdown

    .. code-block:: python3

        await dp.storage.close()

    """

    def __init__(
            self,
            host: str = "localhost",
            port: int = 6379,
            db: typing.Optional[int] = None,
            password: typing.Optional[str] = None,
            ssl: typing.Optional[bool] = None,
            pool_size: int = 10,
            loop: typing.Optional[asyncio.AbstractEventLoop] = None,
            prefix: str = "fsm",
            state_ttl: typing.Optional[int] = None,
            data_ttl: typing.Optional[int] = None,
            bucket_ttl: typing.Optional[int] = None,
            **kwargs,
    ):
        from redis.asyncio import Redis

        # `loop` is accepted but not used. decode_responses=True makes the
        # client return str instead of bytes.
        self._redis: typing.Optional[Redis] = Redis(
            host=host,
            port=port,
            db=db,
            password=password,
            ssl=ssl,
            max_connections=pool_size,
            decode_responses=True,
            **kwargs,
        )

        self._prefix = (prefix,)
        self._state_ttl = state_ttl
        self._data_ttl = data_ttl
        self._bucket_ttl = bucket_ttl

    @deprecated("This method will be removed in aiogram v3.0. "
                "You should use your own instance of Redis.", stacklevel=3)
    async def redis(self) -> "redis.asyncio.Redis":
        """
        Return the underlying ``redis.asyncio.Redis`` client.
        """
        return self._redis

    def generate_key(self, *parts):
        """
        Build a key by joining the prefix and the stringified ``parts`` with ':'.
        """
        return ':'.join(self._prefix + tuple(map(str, parts)))

    async def close(self):
        """
        Close the Redis client.
        """
        await self._redis.close()

    async def wait_closed(self):
        """
        No-op; presumably kept for API parity with ``RedisStorage.wait_closed``.
        """
        pass

    async def get_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                        default: typing.Optional[str] = None) -> typing.Optional[str]:
        """
        Get the stored state, or the resolved ``default`` when none is stored.
        """
        chat, user = self.check_address(chat=chat, user=user)
        key = self.generate_key(chat, user, STATE_KEY)
        return await self._redis.get(key) or self.resolve_state(default)

    async def get_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                       default: typing.Optional[dict] = None) -> typing.Dict:
        """
        Get the stored data dict, or ``default`` (``{}`` if not given) when none is stored.
        """
        chat, user = self.check_address(chat=chat, user=user)
        key = self.generate_key(chat, user, STATE_DATA_KEY)
        raw_result = await self._redis.get(key)
        if raw_result:
            return json.loads(raw_result)
        return default or {}

    async def set_state(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                        state: typing.Optional[typing.AnyStr] = None):
        """
        Store the state with the configured state TTL, or delete it when ``state`` is None.
        """
        chat, user = self.check_address(chat=chat, user=user)
        key = self.generate_key(chat, user, STATE_KEY)
        if state is None:
            await self._redis.delete(key)
        else:
            await self._redis.set(key, self.resolve_state(state), ex=self._state_ttl)

    async def set_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                       data: typing.Dict = None):
        """
        Store the data dict as JSON with the data TTL, or delete the key when it is empty.
        """
        chat, user = self.check_address(chat=chat, user=user)
        key = self.generate_key(chat, user, STATE_DATA_KEY)
        if data:
            await self._redis.set(key, json.dumps(data), ex=self._data_ttl)
        else:
            await self._redis.delete(key)

    async def update_data(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                          data: typing.Dict = None, **kwargs):
        """
        Merge ``data`` and ``kwargs`` into the stored data dict.
        """
        if data is None:
            data = {}
        temp_data = await self.get_data(chat=chat, user=user, default={})
        temp_data.update(data, **kwargs)
        await self.set_data(chat=chat, user=user, data=temp_data)

    def has_bucket(self):
        return True

    async def get_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                         default: typing.Optional[dict] = None) -> typing.Dict:
        """
        Get the stored bucket, or ``default`` (``{}`` if not given) when none is stored.
        """
        chat, user = self.check_address(chat=chat, user=user)
        key = self.generate_key(chat, user, STATE_BUCKET_KEY)
        raw_result = await self._redis.get(key)
        if raw_result:
            return json.loads(raw_result)
        return default or {}

    async def set_bucket(self, *, chat: typing.Union[str, int, None] = None, user: typing.Union[str, int, None] = None,
                         bucket: typing.Dict = None):
        """
        Store the bucket as JSON with the bucket TTL, or delete the key when it is empty.
        """
        chat, user = self.check_address(chat=chat, user=user)
        key = self.generate_key(chat, user, STATE_BUCKET_KEY)
        if bucket:
            await self._redis.set(key, json.dumps(bucket), ex=self._bucket_ttl)
        else:
            await self._redis.delete(key)

    async def update_bucket(self, *, chat: typing.Union[str, int, None] = None,
                            user: typing.Union[str, int, None] = None,
                            bucket: typing.Dict = None, **kwargs):
        """
        Merge ``bucket`` and ``kwargs`` into the stored bucket.
        """
        if bucket is None:
            bucket = {}
        temp_bucket = await self.get_bucket(chat=chat, user=user)
        temp_bucket.update(bucket, **kwargs)
        await self.set_bucket(chat=chat, user=user, bucket=temp_bucket)

    async def reset_all(self, full=True):
        """
        Reset states in DB

        :param full: clean the whole DB, or only keys under this storage's prefix
        :return:
        """
        if full:
            await self._redis.flushdb()
        else:
            keys = await self._redis.keys(self.generate_key('*'))
            # DEL needs at least one key; calling it with none is a Redis error.
            if keys:
                await self._redis.delete(*keys)

    async def get_states_list(self) -> typing.List[typing.Tuple[str, str]]:
        """
        Get list of all stored chat's and user's

        :return: list of tuples where first element is chat id and second is user id
        """
        result = []

        keys = await self._redis.keys(self.generate_key('*', '*', STATE_KEY))
        for item in keys:
            # Key layout is {prefix}:{chat}:{user}:state
            *_, chat, user, _ = item.split(':')
            result.append((chat, user))

        return result

    async def import_redis1(self, redis1):
        """
        Copy all records from a legacy ``RedisStorage`` into this storage.
        """
        await migrate_redis1_to_redis2(redis1, self)


async def migrate_redis1_to_redis2(storage1: RedisStorage, storage2: RedisStorage2):
    """
    Helper for migrating from RedisStorage to RedisStorage2.

    Copies state, data and bucket of every (chat, user) pair listed by
    ``storage1.get_states_list()`` into ``storage2``.

    :param storage1: instance of RedisStorage (source)
    :param storage2: instance of RedisStorage2 (destination)
    :raises TypeError: if ``storage1`` is not a RedisStorage or ``storage2`` is not a RedisStorage2
    :return:
    """
    if not isinstance(storage1, RedisStorage):  # better than assertion
        raise TypeError(f"{type(storage1)} is not RedisStorage instance.")
    # The destination must be the new storage class, not the legacy one.
    if not isinstance(storage2, RedisStorage2):
        raise TypeError(f"{type(storage2)} is not RedisStorage2 instance.")

    log = logging.getLogger('aiogram.RedisStorage')

    for chat, user in await storage1.get_states_list():
        state = await storage1.get_state(chat=chat, user=user)
        await storage2.set_state(chat=chat, user=user, state=state)

        data = await storage1.get_data(chat=chat, user=user)
        await storage2.set_data(chat=chat, user=user, data=data)

        bucket = await storage1.get_bucket(chat=chat, user=user)
        await storage2.set_bucket(chat=chat, user=user, bucket=bucket)

        log.info("Migrated user %s in chat %s", user, chat)
