import sys
from typing import Dict, List, Tuple, TypeVar

from ucsschool.lib.roles import get_role_info

# Constrained type variable: either text (str) or raw bytes. A function
# annotated (StrBytes, ...) -> StrBytes returns the same one of the two types
# it was given.
StrBytes = TypeVar("StrBytes", str, bytes)


def is_school_user(roles):  # type: (List[str]) -> bool
    """Tell whether *roles* contains at least one school user role.

    Every role string is decoded with ``get_role_info`` into a 3-tuple whose
    first item is the role name. The user roles recognised are student,
    staff, school_admin, teacher and legal_guardian. Decoding stops at the
    first matching role.
    """
    user_role_names = ("student", "staff", "school_admin", "teacher", "legal_guardian")
    decoded = (get_role_info(role) for role in roles)
    return any(name in user_role_names for name, _, _ in decoded)


def is_school_group(roles):  # type: (List[str]) -> bool
    """Tell whether *roles* contains a school class or workgroup role.

    Role strings are decoded lazily with ``get_role_info``; only the role name
    (first tuple item) is compared, and decoding stops at the first match.
    """
    group_role_names = ("school_class", "workgroup")
    decoded = (get_role_info(role) for role in roles)
    return any(name in group_role_names for name, _, _ in decoded)


def is_school(roles):  # type: (List[str]) -> bool
    """Tell whether any entry of *roles* decodes to the ``school`` role.

    Role strings are decoded lazily with ``get_role_info``; decoding stops at
    the first match.
    """
    decoded = (get_role_info(role) for role in roles)
    return any(name == "school" for name, _, _ in decoded)


def is_school_or_school_group(roles):  # type: (List[str]) -> bool
    """Tell whether *roles* marks either a school group or a school.

    The group check runs first; the school check only runs if it fails.
    """
    if is_school_group(roles):
        return True
    return is_school(roles)


def _replace_string(obj, old, new):  # type: (StrBytes, str, str) -> StrBytes
    if isinstance(obj, bytes):
        old = old.encode()
        new = new.encode()
    if old.lower() in obj:
        return obj.replace(old.lower(), new)
    elif old in obj:
        return obj.replace(old, new)
    if old.upper() in obj:
        return obj.replace(old.upper(), new)
    else:
        return obj


def replace_string_in_ldap_entry(dn, entry, old, new):
    # type: (str, Dict[str, List[bytes]], str, str) -> Tuple[str, Dict[str, List[bytes]]]
    """Return a copy of an LDAP entry with *old* replaced by *new*.

    The substitution is applied to the DN, to every attribute name and to
    every attribute value. Case handling is delegated to ``_replace_string``.
    *entry* itself is not modified; a new dict with new value lists is built.

    :param dn: DN of the entry.
    :param entry: mapping of attribute name to a list of values.
    :param old: text to replace.
    :param new: replacement text.
    :return: tuple of (rewritten DN, rewritten attribute dict).
    """

    def _replace(obj, old, new):  # type: (StrBytes, str, str) -> StrBytes
        obj = _replace_string(obj, old, new)
        # _replace_string() only rewrites the first case variant it finds
        # (lowercase, then as given, then uppercase). So a value holding both a
        # lowercase and an exact-case occurrence keeps the exact-case one.
        # Values containing "TEMPLATE" get an extra exact-case pass.
        # NOTE(review): the original author was unsure why this was needed;
        # presumably because of that mixed-case situation.
        # old/new must match obj's type. Attribute values are bytes, and
        # bytes.replace() with str arguments raises TypeError.
        if isinstance(obj, bytes):
            marker, needle, replacement = b"TEMPLATE", old.encode(), new.encode()
        else:
            marker, needle, replacement = "TEMPLATE", old, new
        if marker in obj:
            obj = obj.replace(needle, replacement)
        return obj

    new_obj = {}
    new_dn = _replace(dn, old, new)
    for key, value in entry.items():
        new_key = _replace(key, old, new)
        new_obj[new_key] = [_replace(v, old, new) for v in value]

    return new_dn, new_obj


def status_bar(
    iteration,
    total,
    bar_length=100,
    out_stream=None,
):
    """Draw a one-line text progress bar, redrawn in place via ``\\r``.

    The bar is made of ``+`` for finished and ``-`` for pending steps,
    followed by the percentage with two decimals. A newline is written once
    ``iteration == total``.

    :param iteration: number of steps completed so far.
    :param total: total number of steps. ``0`` is drawn as a finished bar
        instead of raising ZeroDivisionError.
    :param bar_length: width of the bar in characters. The bar never exceeds
        this width, even if ``iteration > total``.
    :param out_stream: file-like object to write to. Defaults to the
        *current* ``sys.stdout``, looked up at call time so that redirection
        done after import is respected.
    """
    if out_stream is None:
        out_stream = sys.stdout

    if total:
        percent = 100 * (iteration / float(total))
        filled_length = int(round(bar_length * iteration / float(total)))
    else:
        # Nothing to do counts as done.
        percent = 100.0
        filled_length = bar_length
    # Keep the bar width constant when iteration is outside [0, total].
    filled_length = max(0, min(bar_length, filled_length))

    bar = "+" * filled_length + "-" * (bar_length - filled_length)
    out_stream.write("\r%s| %s%s " % (bar, "{0:.2f}".format(percent), "%"))
    if iteration == total:
        out_stream.write("\n")
    out_stream.flush()
