import json
import os
import random
import time
from typing import Dict, List, Optional, Union, cast

import requests
from authlib.integrations.requests_client import OAuth2Auth
from locust import HttpUser
from locust.clients import ResponseContextManager as Response
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_fixed
from testapp_client import get_access_token_and_pseudonym, refresh_token

from performance_tests.common import TestData

# Silence urllib3 warnings (called without arguments, ``disable_warnings``
# disables every urllib3 ``HTTPWarning``). Requests in this module use
# ``verify=False``, and the certificate for login.kc1.broker.intranet has no
# `subjectAltName`, so urllib3 falls back to checking its `commonName`.
requests.packages.urllib3.disable_warnings()

# School authority name (as used in the test data) -> IdP hint used when
# requesting tokens for users of that authority.
# ``os.environ.get`` is used (instead of ``os.environ[...]``) so that a missing
# variable is reported by the explicit check below instead of a bare KeyError.
SCHOOL_AUTHORITIES_TRAEGER_MAPPING = {
    "Traeger1": os.environ.get("UCS_ENV_TRAEGER1_IDP_HINT", ""),
    "Traeger2": os.environ.get("UCS_ENV_TRAEGER2_IDP_HINT", ""),
}

# URL path templates of the self-disclosure API. ``{id}`` is the user's
# pseudonym for the ``users/...`` endpoints and a group identifier
# (presumably the group pseudonym) for the ``groups/...`` endpoint.
SD_METADATA = "/ucsschool/apis/self_disclosure/v1/users/{id}/metadata"
SD_GROUPS = "/ucsschool/apis/self_disclosure/v1/users/{id}/groups"
SD_USERS = "/ucsschool/apis/self_disclosure/v1/groups/{id}/users"

UCS_ENV_TEST_APP_FQDN = os.environ.get("UCS_ENV_TEST_APP_FQDN")
# Number of simulated locust users; determines how much test data is loaded.
LOCUST_USERS = os.environ.get("LOCUST_USERS", "2")

# Fail fast at import time if the required environment is missing or empty.
if (
    not UCS_ENV_TEST_APP_FQDN
    or not SCHOOL_AUTHORITIES_TRAEGER_MAPPING["Traeger1"]
    or not SCHOOL_AUTHORITIES_TRAEGER_MAPPING["Traeger2"]
):
    raise ValueError(
        "ENV not set, check: UCS_ENV_TEST_APP_FQDN, UCS_ENV_TRAEGER1_IDP_HINT and UCS_ENV_TRAEGER2_IDP_HINT"
    )


class ResponseError(Exception):
    """Raised when a self-disclosure API response is unusable.

    Used for an unexpected status code (including 401, after a token refresh)
    and for a response body that is not valid JSON.
    """


class SDMetadata(BaseModel):
    """User metadata as returned by the self-disclosure ``metadata`` endpoint."""

    user_id: str
    username: str
    firstname: str
    lastname: str
    type: str

    school_authority: str
    school_id: str

    def __str__(self):
        """Return the class name followed by the field values, without first/last name.

        A filtered copy of the field mapping is built. Deleting the keys from
        ``self.__dict__`` directly would remove the fields from the model
        itself and make any later ``str()`` call fail with ``KeyError``.
        """
        hidden = ("firstname", "lastname")
        data = {key: value for key, value in self.__dict__.items() if key not in hidden}
        return f"{self.__class__.__name__}{data}"


class SDUserGroups(BaseModel):
    """Groups of a user as returned by the self-disclosure ``groups`` endpoint."""

    groups: List[Dict[str, Union[str, int]]]

    def __str__(self):
        """Return the class name directly followed by the list of group dicts."""
        class_name = type(self).__name__
        return "{}{}".format(class_name, self.groups)


class SDGroupUsers(BaseModel):
    """Members of a group as returned by the self-disclosure ``users`` endpoint."""

    students: List[Dict[str, str]]
    teachers: List[Dict[str, str]]

    def __str__(self):
        """Return a representation in which the students' first/last names are omitted.

        New dicts are built for the students instead of deleting keys in place.
        A shallow ``list.copy()`` shares the dicts with the model, so deleting
        keys would corrupt the model's data and break repeated ``str()`` calls.
        Student entries without name keys are tolerated.
        """
        hidden = ("firstname", "lastname")
        str_students = [
            {key: value for key, value in student.items() if key not in hidden}
            for student in self.students
        ]
        return f"{self.__class__.__name__}(students: {str_students}, teachers: {self.teachers})"


class SelfDisclosureClient(HttpUser):
    """Abstract locust user that queries the UCS@school self-disclosure API.

    ``token`` and ``pseudonym`` (and usually ``username``, ``password`` and
    ``idp_hint``) must be set before the ``get_*`` methods are used; where that
    happens is not part of this class.
    """

    abstract = True  # locust does not spawn users of this class directly

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.username = None
        self.password = None
        self.idp_hint = None
        self._token = None
        self.pseudonym = None
        self._auth: Optional[OAuth2Auth] = None

    @property
    def token(self):
        """Current OAuth2 token; must contain ``expires_at`` (compared with ``time.time()``)."""
        return self._token

    @token.setter
    def token(self, token):
        # Rebuild the auth handler so requests always carry the current token
        # in the request header.
        self._token = token
        self._auth = OAuth2Auth(token, token_placement="header")

    @property
    def auth(self):
        """Return the request auth handler.

        The token is refreshed first if it expires in less than 30 seconds.
        """
        if (self.token["expires_at"] - time.time()) < 30:
            self.token = refresh_token(self.token)
        return self._auth

    def _check_get_data(self, response: Response):
        """Validate ``response`` and return its decoded JSON body.

        Raises:
            ResponseError: on status 401 (the token is refreshed and the request
                itself is marked as successful), on any other non-200 status, or
                if the body is not valid JSON. In the latter two cases the request
                is marked as failed in the locust statistics.
        """
        msg: str = f"status code: {response.status_code} Response text: {response.text}"
        if response.status_code == 401:
            # Not counted as a failure: refresh the token and abort this call.
            response.success()
            self.token = refresh_token(self.token)
            raise ResponseError(msg, " -> Token Refresh")
        if response.status_code != 200:
            response.failure(msg)
            raise ResponseError(msg)
        try:
            return json.loads(response.text)
        except json.JSONDecodeError as exc:
            # Mark the result explicitly. When our own exception is raised inside a
            # ``catch_response`` block without success()/failure(), the request is
            # not reported as failed.
            response.failure(msg)
            raise ResponseError(msg) from exc

    def _get_json(self, path: str, **path_params):
        """GET the self-disclosure endpoint ``path`` and return its decoded JSON body.

        ``path`` is a URL template (e.g. ``SD_METADATA``) formatted with
        ``path_params``. The unformatted template is used as the locust request
        name, so all calls to one endpoint are aggregated in the statistics.
        """
        url: str = f"https://{self.host}{path.format(**path_params)}"
        with self.client.get(
            url, name=path, auth=self.auth, verify=False, catch_response=True
        ) as response:
            response = cast(Response, response)
            return self._check_get_data(response)

    def get_user_metadata(self):
        """Fetch the metadata of the user identified by ``self.pseudonym``."""
        metadata = self._get_json(SD_METADATA, id=self.pseudonym)
        # Fall back to the IdP hint when the response has no school authority.
        metadata.setdefault("school_authority", self.idp_hint)
        return SDMetadata(**metadata)

    def get_groups_of_user(self):
        """Fetch the groups of the user identified by ``self.pseudonym``."""
        groups = self._get_json(SD_GROUPS, id=self.pseudonym)
        return SDUserGroups(groups=groups["groups"])

    def get_users_of_group(self, id):
        """Fetch the students and teachers of group ``id``.

        The parameter name shadows the builtin ``id``; it is kept for callers
        passing it by keyword.
        """
        users = self._get_json(SD_USERS, id=id)
        return SDGroupUsers(students=users["students"], teachers=users["teachers"])


@retry(stop=stop_after_attempt(10), wait=wait_fixed(12))
def get_student_data(get_token=True) -> List[Dict[str, str]]:
    """Load a random selection of students via ``_get_user_data``.

    Any exception triggers a new attempt: at most 10 attempts, 12 seconds apart.
    """
    role = "student"
    return _get_user_data(role, get_token=get_token)


@retry(stop=stop_after_attempt(10), wait=wait_fixed(12))
def get_teacher_data(get_groups=False, get_token=True) -> List[Dict[str, str]]:
    """Load a random selection of teachers via ``_get_user_data``.

    With ``get_groups`` the pseudonyms of each teacher's groups are included.
    Any exception triggers a new attempt: at most 10 attempts, 12 seconds apart.
    """
    return _get_user_data(get_role="teacher", get_token=get_token, get_groups=get_groups)


def _load_profile(locust_users: int) -> tuple:
    """Return ``(school authorities, schools per authority, users per school)`` to load.

    Up to 20 locust users, the smallest fitting bracket (5/10/15/20) of a fixed
    table is used. Above that, the users per school are ``ceil(locust_users / 10)``,
    and above 200 locust users twice as many schools with half as many users
    (rounded up) are used.
    """
    small_profiles = ((5, (2, 3, 1)), (10, (2, 3, 2)), (15, (2, 4, 2)), (20, (2, 5, 2)))
    for limit, profile in small_profiles:
        if locust_users <= limit:
            return profile
    tens = -(-locust_users // 10)  # ceil(locust_users / 10) in integer arithmetic
    if locust_users <= 200:
        return 2, 5, tens
    return 2, 10, -(-tens // 2)  # ceil(tens / 2)


def _pick_users(td, get_role: str, sa_no: int, sch_no: int, user_no: int) -> list:
    """Randomly pick users with role ``get_role`` from the test data ``td``.

    ``sa_no`` school authorities are drawn, ``sch_no`` schools from each of them
    and ``user_no`` users from each school. All draws use ``random.choices``,
    i.e. they are with replacement, so duplicates are possible.

    Returns a list of ``(user, school, school_authority, school_groups)`` tuples.
    """
    picked = []
    for sa in random.choices(td.school_authorities, k=sa_no):
        for school in random.choices(list(td.school_authority(sa).schools.keys()), k=sch_no):
            school_data = td.school(sa, school)
            students = []
            teachers = []
            for name in school_data.users:
                user = school_data.users[name]
                candidate = (user, school, sa, school_data.groups)
                # A user matching both roles is treated as a student only.
                if "student" in user.ucsschoolRole:
                    students.append(candidate)
                elif "teacher" in user.ucsschoolRole:
                    teachers.append(candidate)
            pool = {"student": students, "teacher": teachers}.get(get_role, [])
            # random.choices raises IndexError on an empty pool, which would only
            # make the retrying callers start over; skip such schools instead.
            if pool:
                picked += random.choices(pool, k=user_no)
    return picked


def _build_user_entry(user_data, get_groups: bool, get_token: bool) -> Dict:
    """Turn one ``(user, school, school_authority, groups)`` tuple into a result dict.

    With ``get_token`` a progress line is printed and an access token plus the
    user's pseudonym are fetched; otherwise both are empty strings.
    """
    user, school, sa, groups = user_data
    user_groups = []
    if get_groups:
        # Pseudonyms of all school groups the user is a member of.
        user_groups = [
            groups[group].pseudonyms["idBrokerPseudonym0001"]
            for group in groups
            if user.name in groups[group].memberUid
        ]

    username = f"{school}-{user.name}"
    password = user.password
    role = user.ucsschoolRole
    idp_hint = SCHOOL_AUTHORITIES_TRAEGER_MAPPING[sa]

    if get_token:
        print(f"  user: {user.name}, sa: {sa}, school: {school}, role: {role}")
        token, pseudonym = get_access_token_and_pseudonym(username, password, idp_hint)
    else:
        token = ""
        pseudonym = ""

    return {
        "username": username,
        "password": password,
        "role": role,
        "idp_hint": idp_hint,
        "token": token,
        "pseudonym": pseudonym,
        "groups": user_groups,
    }


def _get_user_data(get_role: str, get_groups=False, get_token=True) -> List[Dict[str, str]]:
    """Load a random selection of test users with role ``get_role``.

    The amount of data depends on ``LOCUST_USERS`` (see ``_load_profile``).

    :param get_role: ``"student"`` or ``"teacher"``; any other value yields no users.
    :param get_groups: also collect the ``idBrokerPseudonym0001`` pseudonyms of
        the groups each user is a member of.
    :param get_token: fetch an access token and pseudonym for every user and
        print progress/timing; otherwise both are empty strings.
    :return: one dict per user with the keys ``username``, ``password``, ``role``,
        ``idp_hint``, ``token``, ``pseudonym`` and ``groups``.
    """
    sa_no, sch_no, user_no = _load_profile(int(LOCUST_USERS))
    users = _pick_users(TestData(), get_role, sa_no, sch_no, user_no)

    if get_token:
        print(f"\nRetrieve tokens for {len(users)} users:")

    t_start = time.time()
    data: List[Dict[str, str]] = [
        _build_user_entry(user_data, get_groups, get_token) for user_data in users
    ]

    if get_token:
        print(f"time elapsed: {round((time.time() - t_start), 2)} seconds")

    return data
