from __future__ import annotations

import asyncio
import datetime
import logging
import time
from functools import lru_cache
from typing import Any, Dict, List, Sequence

from async_lru import alru_cache
from ldap3 import AUTO_BIND_TLS_BEFORE_BIND, BASE, RESTARTABLE, SIMPLE, SUBTREE, Connection, Entry
from ldap3.core.exceptions import LDAPBindError, LDAPExceptionError
from ldap3.utils.conv import escape_filter_chars

from id_broker_common.pseudonyms_udm_rest import SP_MAPPINGS_DN, get_settings_data_content
from ucsschool.apis.utils import LDAPAccess, LDAPCredentials, LDAPSettings, get_logger

from .models import RawUser
from .utils import ldap_credentials, ldap_settings

# Module-level logger obtained from ucsschool.apis.utils.get_logger().
logger: logging.Logger = get_logger()
# Maximum number of usernames OR-ed together into a single LDAP filter by get_users();
# longer username sequences are split into batches of this size.
# NOTE(review): "CONCATINATED" is a misspelling of "CONCATENATED"; kept as-is because the
# name may be imported elsewhere.
CONCATINATED_USERS_SEARCH_LIMIT = 50


# TODO: modify LDAPAccess in the ucsschool-apis app to allow choosing bind credentials
class LDAPAccessWithCnAdminSearch(LDAPAccess):
    """
    ``LDAPAccess`` variant whose ``search`` uses connections bound with the cn=admin credentials.

    Two connections are opened on construction, one to the master server
    (``ldap3_conn_primary``) and one to the local host server (``ldap3_conn_host``).
    Both are read-only, start TLS before a simple bind (``AUTO_BIND_TLS_BEFORE_BIND``),
    and use ldap3's ``RESTARTABLE`` client strategy.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ldap3_conn_primary = self._cn_admin_connection(self.server_master)
        self.ldap3_conn_host = self._cn_admin_connection(self.server_host)

    def _cn_admin_connection(self, server) -> Connection:
        """Open a read-only, auto-bound connection to `server` using the cn=admin credentials."""
        return Connection(
            server,
            user=self.credentials.cn_admin_dn,
            password=self.credentials.cn_admin_password,
            auto_bind=AUTO_BIND_TLS_BEFORE_BIND,
            authentication=SIMPLE,
            read_only=True,
            client_strategy=RESTARTABLE,
        )

    async def search(
        self,
        search_filter: str,
        search_base: str = None,
        search_scope: str = SUBTREE,
        attributes: List[str] = None,
        use_master: bool = False,
        raise_on_bind_error: bool = True,
    ) -> List[Entry]:
        """
        Function to search the LDAP for arbitrary entries.

        :param search_filter: The LDAP search filter
        :param search_base: The LDAP search base (defaults to ``settings.ldap_base``)
        :param search_scope: The LDAP search scope
        :param attributes: The attributes to return in the entries
        :param use_master: If True the connection is made to the master of the domain
        :param raise_on_bind_error: If True an exception is raised on failed bind, else an empty list is returned
        :return: The list of entries corresponding to the search parameters
        :raises LDAPExceptionError: On any ldap3 error (except a bind error when
            `raise_on_bind_error` is False); the error is logged before re-raising.
        """
        search_base = search_base or self.settings.ldap_base
        conn = self.ldap3_conn_primary if use_master else self.ldap3_conn_host
        try:
            # The ldap3 RESTARTABLE strategy is synchronous, so this call blocks the event
            # loop. Because there is no ``await`` between the search and reading
            # ``conn.entries`` below, no other coroutine can reuse the shared connection
            # in between.
            conn.search(
                search_base=search_base,
                search_filter=search_filter,
                search_scope=search_scope,
                attributes=attributes,
            )
        except LDAPExceptionError as exc:
            if isinstance(exc, LDAPBindError) and not raise_on_bind_error:
                return []
            # Report the server this connection talks to (master or local host) and the
            # DN it actually binds with (cn=admin), so bind problems can be diagnosed.
            self.logger.exception(
                "When connecting to %r with bind_dn %r: %s",
                conn.server.host,
                self.credentials.cn_admin_dn,
                exc,
            )
            raise
        return conn.entries


@lru_cache(maxsize=1)
def ldap_access(settings: LDAPSettings = None, credentials: LDAPCredentials = None) -> LDAPAccess:
    """
    Return a cached ``LDAPAccessWithCnAdminSearch`` instance.

    Omitted (or falsy) `settings` / `credentials` fall back to ``ldap_settings()`` and
    ``ldap_credentials()``. Only the most recent argument combination is memoized
    (``maxsize=1``), so repeated calls with the same arguments share one object.
    """
    effective_settings = settings or ldap_settings()
    effective_credentials = credentials or ldap_credentials()
    return LDAPAccessWithCnAdminSearch(effective_settings, effective_credentials)


@alru_cache(maxsize=1)
async def get_synonym_attributes(cache_bust: Any) -> List[str]:
    """
    Cached loading of synonym attributes from LDAP, as a list of the mapping's values.

    Derived from ``get_synonym_attributes_dict`` so both helpers share a single fetch of
    the settings object at ``SP_MAPPINGS_DN`` for the same `cache_bust` value.
    Refresh by changing `cache_bust`.

    :param cache_bust: Any hashable value; a new value forces a reload.
    :return: A new list with the values of the synonym mapping.
    """
    mapping = await get_synonym_attributes_dict(cache_bust)
    return list(mapping.values())


@alru_cache(maxsize=1)
async def get_synonym_attributes_dict(cache_bust: Any) -> Dict[str, str]:
    """
    Load the synonym attribute mapping from the LDAP settings object at ``SP_MAPPINGS_DN``.

    The result is cached (a single entry); pass a different `cache_bust` value to force a
    reload.

    NOTE: while cached, every caller receives the same dict object — do not mutate it.
    """
    mapping_dn = SP_MAPPINGS_DN.format(ldap_base=ldap_settings().ldap_base)
    return await get_settings_data_content(mapping_dn)


def _as_value_list(value: Any) -> Any:
    """Wrap a single string LDAP attribute value in a list; return anything else unchanged."""
    # ldap3's ``Attribute.value`` is a plain str when there is one value and a list when
    # there are several; this makes both cases a list.
    return [value] if isinstance(value, str) else value


def _pseudonyms_from_attributes(attributes: Dict[str, List[Any]]) -> Dict[str, Any]:
    """Map each non-empty ``idBrokerPseudonym00*`` attribute name to its first value."""
    # A requested attribute that is not set on the entry can come back with an empty value
    # list (ldap3 ``return_empty_attributes``); skip it instead of failing with IndexError.
    return {
        name: values[0]
        for name, values in attributes.items()
        if name.startswith("idBrokerPseudonym00") and values
    }


async def get_users(usernames: Sequence[str]) -> List[RawUser]:
    """
    Retrieve data of multiple users.

    The usernames are looked up in batches of ``CONCATINATED_USERS_SEARCH_LIMIT`` so that
    each LDAP filter stays bounded in size. The result keeps the batch order.

    :param usernames: ``uid`` values of the users to look up.
    :return: One ``RawUser`` per matching ``users/user`` LDAP entry; usernames without a
        matching entry are not represented in the result.
    """
    if not usernames:
        return []
    ldap = ldap_access()
    # The value changes every 5 minutes. Used as the cache key, it makes the alru_cache'd
    # synonym mapping reload at most every 5 min.
    cache_bust = int(datetime.datetime.now().strftime("%Y%m%d%H%M")) // 5
    attrs = [
        "givenName",
        "sn",
        "ucsschoolRole",
        "ucsschoolSchool",
        "ucsschoolRecordUID",
        "ucsschoolSourceUID",
        "uid",
    ]
    # Fetched once for all batches rather than once per batch.
    attrs.extend((await get_synonym_attributes_dict(cache_bust)).values())
    res = []
    # Iterate over fixed-size batches instead of recursing, so very long inputs cannot
    # exhaust the recursion limit or repeatedly copy the remaining tail.
    for start in range(0, len(usernames), CONCATINATED_USERS_SEARCH_LIMIT):
        batch = usernames[start : start + CONCATINATED_USERS_SEARCH_LIMIT]
        # Escape each uid so user-supplied names cannot alter the filter structure.
        filter_uids = "".join(f"(uid={escape_filter_chars(uid)})" for uid in batch)
        filter_s = f"(&(univentionObjectType=users/user)(|{filter_uids}))"
        for result in await ldap.search(filter_s, attributes=attrs):
            res.append(
                RawUser(
                    dn=result.entry_dn,
                    firstname=result.givenName.value,
                    lastname=result.sn.value,
                    record_uid=result.ucsschoolRecordUID.value,
                    schools=_as_value_list(result.ucsschoolSchool.value),
                    source_uid=result.ucsschoolSourceUID.value,
                    synonyms=_pseudonyms_from_attributes(result.entry_attributes_as_dict),
                    ucsschool_roles=_as_value_list(result.ucsschoolRole.value),
                    username=result.uid.value,
                )
            )
    return res


async def wait_for_replication(
    ldap: LDAPAccess, dn: str, should_exist: bool = True, ttl: float = 300.0, interval: float = 1.0
):
    """
    Wait for the (non-)existence of an LDAP object.

    Polls the local (non-master) LDAP with a BASE-scope search on `dn` every `interval`
    seconds. One result counts as "exists"; an empty result counts as "absent".

    :param LDAPAccess ldap: LDAPAccess object to use.
    :param str dn: The DN of the object to wait for.
    :param bool should_exist: Whether to wait for its existence or absence.
    :param float ttl: Seconds after which to give up.
    :param float interval: Time to wait between retries.
    :return: None
    :raise TimeoutError: When `ttl` seconds expired without the object existing/disappearing in LDAP.
    """
    # Monotonic clock: wall-clock adjustments (NTP steps, manual changes) must not shorten
    # or stretch the wait.
    start = time.monotonic()
    deadline = start + ttl

    while time.monotonic() < deadline:
        results = await ldap.search(
            search_filter="(objectClass=*)",
            search_base=dn,
            search_scope=BASE,
            attributes=None,
            use_master=False,
            raise_on_bind_error=True,
        )
        if should_exist and len(results) == 1:
            logger.debug("%r exists in local LDAP after %.3f sec.", dn, time.monotonic() - start)
            return
        if not should_exist and not results:
            logger.debug("%r absent in local LDAP after %.3f sec.", dn, time.monotonic() - start)
            return
        await asyncio.sleep(interval)

    # Describe the state that persisted: still absent when waiting for existence, and vice versa.
    exist_txt = "absent" if should_exist else "exists"
    raise TimeoutError(f"{dn!r} {exist_txt} in local LDAP after {time.monotonic() - start:.3f} sec.")
