from __future__ import annotations

import contextlib
import datetime
import os
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

from pydantic import validator
from redis import Redis
from redis_om import (
    Field,
    JsonModel,
    Migrator,
    NotFoundError as RedisNotFoundError,
    get_redis_connection,
)

from id_broker_common.utils import load_settings
from ucsschool.apis.utils import get_logger

from .sddb_client import (
    GroupType,
    NotFoundError,
    SdDBClient,
    SddbGroup,
    SddbInternalUser,
    SddbServiceProviderMapping,
    SddbUser,
)
from .settings import PLUGIN_SETTINGS_FILE

##############################################################################
#                                                                            #
# The `RedisUser`, `RedisGroup` and `RedisServiceProviderMapping` classes    #
# are copies from the SDDB code.                                             #
#                                                                            #
# Keep them in sync!                                                         #
#                                                                            #
##############################################################################


# API versions of the Redis object types (cf. the `api_version` field of the models below, which is
# "used for updating objects when attributes are added").
# NOTE(review): these constants are not referenced in this module's visible code — presumably they
# are compared against stored objects elsewhere; confirm before changing them.
API_VERSION_GROUP = 1
API_VERSION_USER = 1
API_VERSION_SP_MAPPING = 1

# Module-wide logger used by the client and the model classes below.
logger = get_logger()


def get_redis_url(path: Path = PLUGIN_SETTINGS_FILE) -> str:
    """
    Determine the URL of the Redis server to connect to.

    The environment variable ``REDIS_OM_URL`` wins whenever it is set (even if empty). Otherwise
    the ``redis_url`` entry of the settings file at `path` (read with `load_settings`) is used.
    The chosen URL is logged at INFO level together with where it came from.

    :param path: settings file consulted when ``REDIS_OM_URL`` is not set
    :return: the Redis URL
    :raises ValueError: if loading the settings raises `EnvironmentError`, `KeyError` or
        `ValueError`, or if the settings have no ``redis_url`` entry; the problem is also
        logged at CRITICAL level
    """
    with contextlib.suppress(KeyError):
        env_url = os.environ["REDIS_OM_URL"]
        logger.info("Connecting to Redis at %r (read from environment).", env_url)
        return env_url

    try:
        file_url = load_settings(path)["redis_url"]
        logger.info("Connecting to Redis at %r (read from '%s').", file_url, path)
    except (EnvironmentError, KeyError, ValueError) as exc:
        error_message = f"Missing, incomplete or malformed settings file '{path!s}': {exc!s}"
        logger.critical(error_message)
        raise ValueError(error_message) from exc
    return file_url


class RedisSdDBClient(SdDBClient):
    """
    `SdDBClient` backed by Redis.

    Single objects are looked up through the RediSearch queries of the `redis_om` models defined in
    this module; `get_users` fetches RedisJSON documents with a single MGET on `self.redis_con`.
    Constructing an instance opens a connection, pings Redis and runs the `redis_om` migrations.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Prefix that turns a RedisUser PK into its Redis key (used by `get_users`): "sddb:user:".
        user_meta = RedisUser.Meta
        self.prefix = ":".join((user_meta.global_key_prefix, user_meta.model_key_prefix, ""))
        url = get_redis_url()
        self.redis_con = get_redis_connection(url=url)
        logger.debug(f"Opened new connection to Redis at {url!r}.")
        # Talk to the server right away so connection problems surface during construction.
        self.redis_con.ping()
        migration_start = time.time()
        Migrator().run()
        logger.debug("Ran migrations in %.03f seconds.", time.time() - migration_start)

    def get_user(self, pseudonym: str, service_provider_id: int) -> SddbUser:
        """
        Return the user whose pseudonym for `service_provider_id` equals `pseudonym`.

        The pseudonym is stored in the `RedisUser` field ``p<service_provider_id:04d>``; if the
        search yields several hits, the first one is used.

        :raises NotFoundError: if no matching user is found
        """
        pseudonym_field = getattr(RedisUser, f"p{service_provider_id:04d}")
        try:
            found_user: RedisUser = RedisUser.find(pseudonym_field == pseudonym).first()
        except RedisNotFoundError as exc:
            raise NotFoundError(f"Retrieving user {pseudonym=} {service_provider_id=}: {exc}") from exc
        return found_user.to_sddb_user(service_provider_id=service_provider_id)

    def get_users(self, pks: List[str]) -> List[Optional[SddbInternalUser]]:
        """
        Fetch the stored JSON documents of the users with the given redis-om PKs in one MGET.

        An empty `pks` returns ``[]`` without contacting Redis. Otherwise the result of redis-py's
        ``json().mget(keys, ".")`` is returned unchanged — presumably one decoded document per key,
        with ``None`` for missing keys (hence the ``Optional``); TODO confirm it matches
        `SddbInternalUser`.
        """
        if pks:
            keys = [f"{self.prefix}{pk}" for pk in pks]
            return self.redis_con.json().mget(keys, ".")
        return []

    def get_group_by_pseudonym(self, pseudonym: str, service_provider_id: int) -> SddbGroup:
        """
        Return the group whose pseudonym for `service_provider_id` equals `pseudonym`.

        :raises NotFoundError: if no matching group is found
        """
        pseudonym_field = getattr(RedisGroup, f"p{service_provider_id:04d}")
        try:
            found_group: RedisGroup = RedisGroup.find(pseudonym_field == pseudonym).first()
        except RedisNotFoundError as exc:
            raise NotFoundError(f"Retrieving group {pseudonym=} {service_provider_id=}: {exc}") from exc
        return found_group.to_sddb_group(service_provider_id)

    def get_group_by_id(self, group_id: str, service_provider_id: int) -> SddbGroup:
        """
        Return the group whose `id` field (not the redis-om PK) equals `group_id`.

        The result carries the group's pseudonym for `service_provider_id`.

        :raises NotFoundError: if no matching group is found
        """
        try:
            found_group: RedisGroup = RedisGroup.find(RedisGroup.id == group_id).first()
        except RedisNotFoundError as exc:
            raise NotFoundError(f"Retrieving group with ID {group_id!r}: {exc}") from exc
        return found_group.to_sddb_group(service_provider_id=service_provider_id)

    def get_service_provider_mapping(self) -> SddbServiceProviderMapping:
        """
        Return the first stored service provider mapping.

        NOTE(review): assumes a single mapping object exists in Redis — confirm.

        :raises NotFoundError: if no mapping is stored
        """
        try:
            found_mapping: RedisServiceProviderMapping = RedisServiceProviderMapping.find().first()
        except RedisNotFoundError as exc:
            raise NotFoundError(f"Retrieving service provider mapping: {exc}") from exc
        return found_mapping.to_sddb_sp_mapping()


class RedisUser(JsonModel):
    """
    Representation of a user object in Redis using RedisJSON and RediSearch through the `redis_om`
    library.

    Keep in sync with the `RedisUser` class in the SDDB code!

    Fields declared with `Field(index=True)` are searchable; `RedisSdDBClient` looks users up by
    their ``pXXXX`` pseudonym fields. Do not reorder or retype the field declarations without doing
    the same in the SDDB copy — presumably they define the index schema built by `Migrator`.
    """

    id: str = Field(index=True)  # This is _not_ the PK that redis-om generates as key for Redis.
    api_version: int = Field(index=True)  # used for updating objects when attributes are added
    name: str  # UID
    school: str  # OU
    schools: List[str]  # OU1, OU2, ...
    school_authority: str
    modifyTimestamp: str  # LDAP: "20230428142640Z" Redis: "2023-04-28T14:26:40"
    p0001: Optional[str] = Field(index=True)  # This is idBrokerPseudonym0001 in LDAP.
    p0002: Optional[str] = Field(index=True)  # As these field names are stored ~10 million times
    p0003: Optional[str] = Field(index=True)  # in Redis, shortening them saves RAM.
    p0004: Optional[str] = Field(index=True)
    p0005: Optional[str] = Field(index=True)
    p0006: Optional[str] = Field(index=True)
    p0007: Optional[str] = Field(index=True)
    p0008: Optional[str] = Field(index=True)
    p0009: Optional[str] = Field(index=True)
    p0010: Optional[str] = Field(index=True)
    p0011: Optional[str] = Field(index=True)
    p0012: Optional[str] = Field(index=True)
    p0013: Optional[str] = Field(index=True)
    p0014: Optional[str] = Field(index=True)
    p0015: Optional[str] = Field(index=True)
    p0016: Optional[str] = Field(index=True)
    p0017: Optional[str] = Field(index=True)
    p0018: Optional[str] = Field(index=True)
    p0019: Optional[str] = Field(index=True)
    p0020: Optional[str] = Field(index=True)
    p0021: Optional[str] = Field(index=True)
    p0022: Optional[str] = Field(index=True)
    p0023: Optional[str] = Field(index=True)
    p0024: Optional[str] = Field(index=True)
    p0025: Optional[str] = Field(index=True)
    p0026: Optional[str] = Field(index=True)
    p0027: Optional[str] = Field(index=True)
    p0028: Optional[str] = Field(index=True)
    p0029: Optional[str] = Field(index=True)
    p0030: Optional[str] = Field(index=True)
    groups: Dict[str, List[str]]  # GroupType to list of Groups.id (entryUUID) mapping
    data: Dict[str, Any]  # complete DataSource object (Kelvin User) without pseudonyms

    _conn: Optional[Redis] = None  # class-level connection cache, filled lazily by `db()`

    class Meta:
        global_key_prefix = "sddb"
        model_key_prefix = "user"
        MAX_PSEUDONYMS = 30  # number of pXXXX fields declared above

    def __str__(self) -> str:
        # NOTE(review): the comma after `id` was missing; mirror this fix in the SDDB copy.
        return f"{self.__class__.__name__}(pk={self.pk!r}, id={self.id!r}, name={self.name!r})"

    @classmethod
    def db(cls):
        """
        Return the Redis connection used by this model.

        The connection is opened on first use (URL from `get_redis_url()`), cached in `cls._conn`
        and installed as `cls._meta.database` before delegating to `JsonModel.db()`.
        """
        if not cls._conn:
            redis_url = get_redis_url()
            cls._conn = get_redis_connection(url=redis_url)
            logger.debug(f"Opened new connection to Redis at {redis_url!r}.")
        cls._meta.database = cls._conn
        return super().db()

    @validator("modifyTimestamp")
    def correct_timestamp(cls, v):  # pragma: no cover
        """Reject `modifyTimestamp` values that `datetime.fromisoformat` cannot parse."""
        try:
            datetime.datetime.fromisoformat(v)
        except ValueError as exc:
            raise ValueError(f"Expected ISO format (e.g. '2023-04-28T14:26:40'): {v!r}.") from exc
        return v

    def to_sddb_user(self, service_provider_id: int) -> SddbUser:
        """
        Convert to an `SddbUser` carrying only the pseudonym for `service_provider_id`.

        The pseudonym is read from the field ``p<service_provider_id:04d>`` (only p0001-p0030
        exist). Raises `KeyError` if `groups` lacks the "school_class" or "workgroup" key.
        """
        return SddbUser(
            pseudonym=getattr(self, f"p{service_provider_id:04d}"),
            api_version=self.api_version,
            name=self.name,
            school=self.school,
            schools=self.schools,
            school_authority=self.school_authority,
            school_classes=self.groups["school_class"],
            workgroups=self.groups["workgroup"],
            modifyTimestamp=datetime.datetime.fromisoformat(self.modifyTimestamp),
            data=self.data,
        )


class RedisGroup(JsonModel):
    """
    Representation of a group object in Redis using RedisJSON and RediSearch through the `redis_om`
    library.

    Keep in sync with the `RedisGroup` class in the SDDB code!

    Fields declared with `Field(index=True)` are searchable; `RedisSdDBClient` looks groups up by
    `id` and by their ``pXXXX`` pseudonym fields. Do not reorder or retype the field declarations
    without doing the same in the SDDB copy — presumably they define the index schema built by
    `Migrator`.
    """

    id: str = Field(index=True)  # This is _not_ the PK that redis-om generates as key for Redis.
    api_version: int = Field(index=True)  # used for updating objects when attributes are added
    type: GroupType
    name: str  # CN
    school: str  # OU
    school_authority: str
    modifyTimestamp: str  # LDAP: "20230428142640Z" Redis: "2023-04-28T14:26:40"
    p0001: Optional[str] = Field(index=True)  # This is idBrokerPseudonym0001 in LDAP.
    p0002: Optional[str] = Field(index=True)  # As these field names are stored ~10 million times
    p0003: Optional[str] = Field(index=True)  # in Redis, shortening them saves RAM.
    p0004: Optional[str] = Field(index=True)
    p0005: Optional[str] = Field(index=True)
    p0006: Optional[str] = Field(index=True)
    p0007: Optional[str] = Field(index=True)
    p0008: Optional[str] = Field(index=True)
    p0009: Optional[str] = Field(index=True)
    p0010: Optional[str] = Field(index=True)
    p0011: Optional[str] = Field(index=True)
    p0012: Optional[str] = Field(index=True)
    p0013: Optional[str] = Field(index=True)
    p0014: Optional[str] = Field(index=True)
    p0015: Optional[str] = Field(index=True)
    p0016: Optional[str] = Field(index=True)
    p0017: Optional[str] = Field(index=True)
    p0018: Optional[str] = Field(index=True)
    p0019: Optional[str] = Field(index=True)
    p0020: Optional[str] = Field(index=True)
    p0021: Optional[str] = Field(index=True)
    p0022: Optional[str] = Field(index=True)
    p0023: Optional[str] = Field(index=True)
    p0024: Optional[str] = Field(index=True)
    p0025: Optional[str] = Field(index=True)
    p0026: Optional[str] = Field(index=True)
    p0027: Optional[str] = Field(index=True)
    p0028: Optional[str] = Field(index=True)
    p0029: Optional[str] = Field(index=True)
    p0030: Optional[str] = Field(index=True)
    students: List[str]  # Redis PKs, use MGET to retrieve all at once
    teachers: List[str]  # Redis PKs, use MGET to retrieve all at once
    data: Dict[str, Any]  # complete DataSource object (Kelvin SchoolClass) without pseudonyms

    _conn: Optional[Redis] = None  # class-level connection cache, filled lazily by `db()`

    class Meta:
        global_key_prefix = "sddb"
        model_key_prefix = "group"
        MAX_PSEUDONYMS = 30  # number of pXXXX fields declared above

    def __str__(self) -> str:
        # NOTE(review): the comma after `id` was missing; mirror this fix in the SDDB copy.
        return f"{self.__class__.__name__}(pk={self.pk!r}, id={self.id!r}, name={self.name!r})"

    @classmethod
    def db(cls):
        """
        Return the Redis connection used by this model.

        The connection is opened on first use (URL from `get_redis_url()`), cached in `cls._conn`
        and installed as `cls._meta.database` before delegating to `JsonModel.db()`.
        """
        if not cls._conn:
            redis_url = get_redis_url()
            cls._conn = get_redis_connection(url=redis_url)
            logger.debug(f"Opened new connection to Redis at {redis_url!r}.")
        cls._meta.database = cls._conn
        return super().db()

    @validator("modifyTimestamp")
    def correct_timestamp(cls, v):  # pragma: no cover
        """Reject `modifyTimestamp` values that `datetime.fromisoformat` cannot parse."""
        try:
            datetime.datetime.fromisoformat(v)
        except ValueError as exc:
            raise ValueError(f"Expected ISO format (e.g. '2023-04-28T14:26:40'): {v!r}.") from exc
        return v

    def to_sddb_group(self, service_provider_id: int) -> SddbGroup:
        """
        Convert to an `SddbGroup` carrying only the pseudonym for `service_provider_id`.

        The pseudonym is read from the field ``p<service_provider_id:04d>`` (only p0001-p0030
        exist). `students` and `teachers` are passed on as stored (Redis PKs).
        """
        return SddbGroup(
            pseudonym=getattr(self, f"p{service_provider_id:04d}"),
            api_version=self.api_version,
            type=self.type,
            name=self.name,
            school=self.school,
            school_authority=self.school_authority,
            modifyTimestamp=datetime.datetime.fromisoformat(self.modifyTimestamp),
            students=self.students,
            teachers=self.teachers,
            data=self.data,
        )


class RedisServiceProviderMapping(JsonModel):
    """
    Representation of a service provider mapping object in Redis using RedisJSON and RediSearch
    through the `redis_om` library.

    Keep in sync with the `RedisServiceProviderMapping` class in the SDDB code!
    """

    id: str = Field(index=True)  # This is _not_ the PK that redis-om generates as key for Redis.
    api_version: int = Field(index=True)  # used for updating objects when attributes are added
    name: str  # CN
    mapping: Dict[str, str]  # SP to LDAP attribute name mapping

    _conn: Optional[Redis] = None  # class-level connection cache, filled lazily by `db()`

    class Meta:
        global_key_prefix = "sddb"
        model_key_prefix = "sp_mapping"

    def __str__(self) -> str:
        # NOTE(review): the comma after `id` was missing; mirror this fix in the SDDB copy.
        return f"{self.__class__.__name__}(pk={self.pk!r}, id={self.id!r}, name={self.name!r})"

    @classmethod
    def db(cls):
        """
        Return the Redis connection used by this model.

        The connection is opened on first use (URL from `get_redis_url()`), cached in `cls._conn`
        and installed as `cls._meta.database` before delegating to `JsonModel.db()`.
        """
        if not cls._conn:
            redis_url = get_redis_url()
            cls._conn = get_redis_connection(url=redis_url)
            logger.debug(f"Opened new connection to Redis at {redis_url!r}.")
        cls._meta.database = cls._conn
        return super().db()

    def to_sddb_sp_mapping(self) -> SddbServiceProviderMapping:
        """Convert to an `SddbServiceProviderMapping` (the `id` field is not carried over)."""
        return SddbServiceProviderMapping(
            api_version=self.api_version,
            name=self.name,
            mapping=self.mapping,
        )
