#!/usr/bin/python3
# -*- coding: utf-8 -*-
#
# Univention UCS@school
#
# SPDX-FileCopyrightText: 2015-2026 Univention GmbH
# SPDX-License-Identifier: AGPL-3.0-only

import csv
import re
import sys
from argparse import ArgumentParser, RawDescriptionHelpFormatter

import ldap
from ldap.dn import escape_dn_chars
from ldap.filter import filter_format

import univention.admin.handlers.groups.group
import univention.admin.handlers.users.user
import univention.admin.objects
from univention.admin.uexceptions import base as UDMError, noObject
from univention.admin.uldap import getAdminConnection, getMachineConnection
from univention.management.console.config import ucr


class Problem(Exception):
    """Base class for every discrepancy found while checking a CSV line.

    The CSV line that triggered the problem is kept in ``self.line``; all
    remaining positional arguments end up in ``self.args`` as usual for
    exceptions.
    """

    def __init__(self, line, *args, **kwargs):
        """Remember the offending CSV *line* and pass the rest to ``Exception``."""
        super().__init__(*args, **kwargs)
        self.line = line

    def fix(self, lo, po):
        """Repair the problem; the base implementation does nothing."""
        return None


class InvalidGroupDetected(Problem):
    """The class-group memberships of a user differ from the CSV line.

    ``args`` is ``(correct, invalid_groups)``: *correct* tells whether the
    user is a member of the class named in the CSV line, *invalid_groups*
    is the collection of class group DNs the user is unexpectedly a member of.
    """

    def __str__(self):
        """Return one ``ERROR:`` line per detected discrepancy."""
        correct, invalid_groups = self.args
        lines = []
        if not correct:
            lines.append(
                "ERROR: User %r is not a member of class %s"
                % (self.line["name"], self.line["school_class"])
            )
        # sorted() keeps the report stable between runs (the DNs arrive as a set).
        lines.extend(
            "ERROR: User %r is unexpected member of %s" % (self.line["name"], gdn)
            for gdn in sorted(invalid_groups)
        )
        return "\n".join(lines)

    def fix(self, lo, po):
        """Remove the user from every unexpected class group.

        Only the surplus memberships are removed; a missing membership in the
        expected class is not added here.

        :param lo: LDAP connection used for the UDM lookups and modifications
        :param po: UDM position object belonging to *lo*
        """
        username = self.line["name"]
        # filter_format() takes the template and the values separately; only
        # then are the values escaped for use inside an LDAP filter.
        user = univention.admin.handlers.users.user.lookup(
            None, lo, filter_format("uid=%s", (username,)), required=True
        )[0]
        for group_dn in sorted(self.args[1]):
            group = univention.admin.objects.get(
                univention.admin.handlers.groups.group, None, lo, po, group_dn
            )
            group.open()
            if user.dn not in group["users"]:
                msg("ERROR: cannot remove %r from %r" % (user.dn, group.dn))
                continue
            group["users"].remove(user.dn)
            group.modify()
            msg("FIXED: removed %r from %r" % (user.dn, group.dn))


class StudentIsInAnotherSchool(Problem):
    """The user exists in LDAP, but not below the school of the CSV line.

    ``args`` is ``(uid, dn)``.
    """

    def __str__(self):
        """Return the error line naming the user, the expected school and the DN."""
        uid = self.args[0]
        expected_school = self.line["school"]
        found_dn = self.args[1]
        template = "ERROR: User %r is not in school %s. DN: %s"
        return template % (uid, expected_school, found_dn)


class StudentDoesNotExists(Problem):
    """No LDAP object was found for the user name of the CSV line.

    ``args`` is ``(uid,)``.
    """

    def __str__(self):
        """Return the error line naming the missing user."""
        uid = self.args[0]
        template = "ERROR: User %r was not found."
        return template % (uid,)


class StudentIsInAnotherClassInAnotherSchool(Problem):
    """The user is a member of a group that lives below a different school OU.

    ``args`` is ``(uid, user_dn, group_dn)``.
    """

    def __str__(self):
        """Return the error line naming the user, its DN and the foreign group."""
        template = "ERROR: User %r (%s) is group member in another school: %s"
        return template % self.args


def msg(*a):
    """Write a message line to stderr.

    All arguments are converted with ``str()`` and joined by a single space,
    so the function works for zero, one or several arguments. stdout is
    flushed first so that output already written there is emitted before the
    message.
    """
    sys.stdout.flush()
    sys.stderr.write("%s\n" % (" ".join(str(part) for part in a),))
    sys.stderr.flush()


def main():
    """Command-line entry point.

    Parses the command line, opens an LDAP connection (admin connection
    first, machine connection as fallback), reads the CSV file and prints
    every problem yielded by ``parse()``. With ``--fix`` the ``fix()`` method
    of each problem is called as well.

    The process exits with status 0 if no problem was yielded, 1 if at least
    one problem was yielded and 2 if the LDAP connection could not be
    established or the CSV file could not be read.
    """
    description = """ucs-school-verify-class-memberships checks the class group membership of
students object in LDAP against the memberships defined in a specified CSV
import file. Found differences are printed to stderr. Progress information
is printed to stdout. Unless --fix is given, the script does not alter the
LDAP - the test is performed read-only.

Example for showing only errors:
   ucs-school-verify-class-memberships students.csv > /dev/null"""
    parser = ArgumentParser(description=description, formatter_class=RawDescriptionHelpFormatter)
    parser.add_argument(
        "-l",
        "--csv-line",
        dest="columns",
        help="Defines the fields of the CSV-file (Default: %(default)s)",
        default="action,name,firstname,lastname,school,school_class",
    )
    parser.add_argument("--fix", action="store_true", help="Automatically repair the found problems.")
    parser.add_argument("filename", help="Specifies the path of the CSV file")
    ns = parser.parse_args()

    try:
        try:
            lo, po = getAdminConnection()
        except (ldap.LDAPError, UDMError, IOError):  # pylint: disable=E1101
            # Fall back to the machine account if the admin connection fails.
            lo, po = getMachineConnection()
    except (ldap.LDAPError, UDMError, IOError) as exc:  # pylint: disable=E1101
        msg("LDAP Error: %s" % (exc,))
        sys.exit(2)

    code = 0
    columns = ns.columns.split(",")
    try:
        # The csv module requires text (str) lines on Python 3; newline=""
        # leaves the line-ending handling to the csv reader.
        with open(ns.filename, "r", encoding="utf-8", newline="") as fd:
            lines = fd.readlines()
    except (IOError, OSError) as exc:
        msg("Could not open file %r: %s" % (ns.filename, exc))
        code = 2
    except UnicodeDecodeError as exc:
        msg("Could not read file %r: %s" % (ns.filename, exc))
        code = 2
    else:
        for problem in parse(lo, lines, columns):
            code = 1
            msg(str(problem))
            if not ns.fix:
                continue
            try:
                problem.fix(lo, po)
            except (ldap.LDAPError, UDMError) as exc:  # pylint: disable=E1101
                msg("LDAP Error: %s: %s" % (type(exc), exc))
    sys.exit(code)


def parse(lo, lines, columns):
    """Check every CSV row and yield the problems that were detected.

    :param lo: LDAP connection handed on to ``parse_line()``
    :param lines: iterable of tab-separated CSV lines
    :param columns: list of field names, in the order they appear in a line
    :returns: generator of ``Problem`` instances, one per faulty row

    LDAP/UDM errors raised while checking a row are reported through
    ``msg()`` and do not stop the iteration.
    """
    for row in csv.DictReader(lines, fieldnames=columns, delimiter="\t"):
        try:
            parse_line(lo, row)
        except Problem as problem:
            yield problem
        except (ldap.LDAPError, UDMError) as err:  # pylint: disable=E1101
            msg("LDAP Error: %s: %s" % (type(err), err))


def parse_line(lo, line):
    """Verify the class memberships of the user described by one CSV line.

    The user is searched by ``uid`` below the OU of the school named in the
    line. For students, all groups that list the user as ``uniqueMember`` are
    inspected: class groups of the school must match ``school_class``, and
    groups below another school OU are not allowed.

    :param lo: LDAP connection providing ``searchDn()`` and ``search()``
    :param line: mapping of CSV column name to value for one row
    :returns: None if nothing is wrong or the line is skipped
    :raises StudentDoesNotExists: no user with that uid exists (unless the
        action of the line is ``D``)
    :raises StudentIsInAnotherSchool: the user exists outside the school OU
    :raises StudentIsInAnotherClassInAnotherSchool: the user is a member of a
        group below another school OU
    :raises InvalidGroupDetected: the user is missing from the expected class
        or is a member of other class groups of the school
    """
    ldap_base = ucr["ldap/base"]
    oubase = "ou=%s,%s" % (escape_dn_chars(line["school"]), ldap_base)
    uid = line["name"]
    uid_filter = filter_format("uid=%s", (uid,))
    try:
        dn = lo.searchDn(uid_filter, oubase, unique=True)[0]
    except (IndexError, noObject):
        try:
            dn = lo.searchDn(uid_filter, ldap_base, unique=True)[0]
        except (IndexError, noObject):
            if line["action"].upper() == "D":
                # A deleted user is expected to be absent.
                return
            raise StudentDoesNotExists(line, uid)
        else:
            raise StudentIsInAnotherSchool(line, uid, dn)
    if not dn.endswith(",cn=schueler,cn=users,%s" % (oubase,)):
        # Teachers and staff are skipped silently; anything else is reported.
        # The previous test ("not teacher or not staff") was always true, which
        # made the error branch below unreachable.
        # NOTE(review): accounts in any other container below the school (e.g.
        # a combined teacher-and-staff container, if the deployment has one)
        # are now reported as errors - confirm this is intended.
        non_student_suffixes = (
            ",cn=lehrer,cn=users,%s" % (oubase,),
            ",cn=mitarbeiter,cn=users,%s" % (oubase,),
        )
        if dn.endswith(non_student_suffixes):
            print("Ignoring teacher/staff %r" % (uid,))
            return
        msg(
            "ERROR: %s (%s %s) is not a student/teacher/staff."
            % (uid, line["firstname"], line["lastname"])
        )
        return
    if line["action"].upper() == "D":
        msg("ERROR: User %r (%s) should not exist (but does)." % (uid, dn))
    print("Found user %s as %s " % (uid, dn))
    groups = lo.search(filter_format("uniqueMember=%s", (dn,)), ldap_base)

    # Loop invariants; the base DN is escaped because it is literal text, not
    # a pattern (it typically contains characters such as ".").
    class_suffix = ",cn=klassen,cn=schueler,cn=groups,%s" % (oubase,)
    class_prefix = "cn=%s," % (line["school_class"],)
    any_school_group = re.compile(r",ou=[^,]+,%s$" % (re.escape(ldap_base),), re.IGNORECASE)

    correct = False
    invalid_groups = set()
    for gdn, _group in groups:  # pylint: disable=W0612
        if not gdn.endswith(class_suffix):
            if not gdn.endswith(oubase) and any_school_group.search(gdn):
                raise StudentIsInAnotherClassInAnotherSchool(line, uid, dn, gdn)
            continue  # ignore workgroups / Domain Users
        if gdn.startswith(class_prefix):
            correct = True
        else:
            invalid_groups.add(gdn)
    if not correct or invalid_groups:
        raise InvalidGroupDetected(line, correct, invalid_groups)


# Run the check only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
