#!/usr/bin/env python3
#-*- coding: utf-8 -*-
import argparse
import json
import logging
import os
import shlex
import subprocess
import sys
from calendar import monthrange
from configparser import ConfigParser
from datetime import datetime, timedelta, date, time

from icalendar import Calendar, Event, vDDDTypes

# Handler writing to standard output; it is attached to the logger by the
# command-line entry point, not at import time.
handler = logging.StreamHandler(sys.stdout)
logger = logging.getLogger("flag-switch")

class FlagNotConfiguredError(BaseException):
    """Raised when the requested flag has no section in the configuration.

    Building an instance sets ``sys.tracebacklimit`` to 0 for the whole
    process, so an uncaught instance is reported without stack frames.
    """

    def __init__(self, message):
        sys.tracebacklimit = 0
        super().__init__(message)
        # Hide the exception that was being handled when this one was raised.
        self.__suppress_context__ = True

class FlagPlanningError(BaseException):
    """Raised when a flag's schedule file cannot be loaded (missing file,
    unsupported format).

    Building an instance sets ``sys.tracebacklimit`` to 0 for the whole
    process, so an uncaught instance is reported without stack frames.
    """

    def __init__(self, message):
        sys.tracebacklimit = 0
        super().__init__(message)
        # Hide the exception that was being handled when this one was raised.
        self.__suppress_context__ = True

def schedule_from_json(file, name):
    """Build a ``Schedule`` from a JSON planning file.

    The file holds a mapping whose optional keys ``unique``, ``yearly``,
    ``monthly`` and ``weekly`` each list events shaped like
    ``[day, state, "HH:MM" start, "HH:MM" end]``:

    * ``day`` is ``YYYY-MM-DD`` (unique), ``MM-DD`` (yearly), the day of the
      month (monthly) or the ISO weekday, 1 = Monday (weekly);
    * ``state`` is ``on`` or ``off``; other states are logged and ignored;
    * an end of ``24:00`` means the following midnight.

    Recurring events are placed in the current year, month or week.  An
    ``off`` event is stored as the ``on`` ranges of the rest of its day, or
    as a single ``clear`` event when it covers the whole day.

    Args:
        file: path of the JSON file (read as UTF-8).
        name: name given to the returned schedule.

    Returns:
        Schedule: the loaded events, one calendar per frequency.
    """

    def dual_events(event):
        """Return the 'on' ranges equivalent to the 'off' ``event``."""
        day, start, end = event[0], event[2], event[3]
        if start == '00:00' and end == '24:00':
            return [(day, 'clear', start, end)]
        events = []
        if start != '00:00':
            events.append((day, 'on', '00:00', start))
        if end != '24:00':
            events.append((day, 'on', end, '24:00'))
        return events

    def day_range(day, h_start, m_start, h_end, m_end):
        """Return (start, end) on ``day`` (a midnight datetime); hour 24 is the next midnight."""
        if h_end == 24:
            end = day + timedelta(days=1)
        else:
            end = day.replace(hour=h_end, minute=m_end)
        return day.replace(hour=h_start, minute=m_start), end

    def add_component(schedule, event, ref_date, frequency):
        """Add ``event`` to the ``frequency`` calendar of ``schedule``."""
        h_start, m_start = [int(t) for t in event[2].split(":")]
        h_end, m_end = [int(t) for t in event[3].split(":")]

        if frequency == "unique":
            day = datetime.strptime(event[0], "%Y-%m-%d")
        elif frequency == "yearly":
            month, day_of_month = [int(t) for t in event[0].split("-")]
            # The date may not exist this year (e.g. 02-29): skip it.
            if day_of_month > monthrange(ref_date.year, month)[1]:
                return
            day = ref_date.replace(month=month, day=day_of_month)
        elif frequency == "monthly":
            day_of_month = int(event[0])
            # The day may not exist this month (e.g. 31): skip it.
            if day_of_month > monthrange(ref_date.year, ref_date.month)[1]:
                return
            day = ref_date.replace(day=day_of_month)
        elif frequency == "weekly":
            day = ref_date + timedelta(days=int(event[0]) - ref_date.isoweekday())
        else:
            raise ValueError(f"Unknown frequency {frequency}")

        start_date, end_date = day_range(day, h_start, m_start, h_end, m_end)
        vevent = Event()
        vevent.add('summary', event[1])
        vevent.add('dtstamp', ref_date)
        if frequency != "unique":
            vevent.add('rrule', {'freq': frequency})
        vevent.add('dtstart', start_date)
        vevent.add('dtend', end_date)
        getattr(schedule, frequency).add_component(vevent)

    logger.info("Loading dictionnary from json file")
    schedule = Schedule(name)
    # JSON interchange is UTF-8 (RFC 8259); do not depend on the locale.
    with open(file, "r", encoding="utf-8") as json_fh:
        raw_schedule = json.load(json_fh)
    # One reference midnight for every event.  Microseconds must be zeroed
    # too, otherwise "midnight" keeps the sub-second part of now().
    ref_date = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
    for frequency in ("unique", "yearly", "monthly", "weekly"):
        for event in raw_schedule.get(frequency, ()):
            if event[1] == 'on':
                add_component(schedule, event, ref_date, frequency)
            elif event[1] == 'off':
                for on_event in dual_events(event):
                    add_component(schedule, on_event, ref_date, frequency)
            else:
                logger.warning(f"Ignoring {frequency} event with unknown state: {event}")
    return schedule


def schedule_from_ics(file, name):
    """Build a ``Schedule`` from an iCalendar (.ics) file.

    VEVENTs without RRULE go to the ``unique`` calendar.  The others go to
    the calendar named after their first FREQ value, lower-cased.

    Args:
        file: path of the .ics file (read as UTF-8).
        name: name given to the returned schedule.

    Returns:
        Schedule: the loaded events.

    Raises:
        FlagPlanningError: when an RRULE frequency has no matching calendar.
    """
    schedule = Schedule(name)
    # RFC 5545 mandates UTF-8; do not depend on the locale encoding.
    with open(file, "r", encoding="utf-8") as ics_fh:
        calendar = Calendar.from_ical(ics_fh.read())
    for event in calendar.walk('VEVENT'):
        recurrence = event.get('RRULE', None)
        if not recurrence:
            schedule.unique.add_component(event)
            continue
        frequency = str(recurrence.get('FREQ')[0]).lower()
        target = getattr(schedule, frequency, None)
        # Unsupported frequencies (e.g. HOURLY) have no calendar: report them
        # clearly instead of failing with an AttributeError.
        if not isinstance(target, Calendar):
            raise FlagPlanningError(f"Frequency {frequency} not supported in {file}")
        target.add_component(event)

    return schedule


# Schedule loader for each supported file format, keyed by the ``format``
# option of a flag section.
SUPPORTED_FORMAT = dict(
    json=schedule_from_json,
    ics=schedule_from_ics,
)


def project(event, day):
    """Clip ``event`` to the calendar day of ``day``.

    All-day events (DATE values) are expanded to midnight boundaries.  Their
    DTEND is exclusive, as in RFC 5545; a DTEND not after DTSTART still
    yields a one-day event.  When DTEND is absent, DURATION is used.  Without
    either, RFC 5545 defaults apply: one day for a DATE and zero length for
    a DATE-TIME.  UTC date-times are converted to naive local time.

    Args:
        event: icalendar VEVENT with a DTSTART.
        day: datetime whose date is used; expected naive local time, like
            the ``datetime.now()`` values callers pass.

    Returns:
        (start, end): naive datetimes of the overlap with the day, or
        (None, None) when the event does not touch that day.
    """
    def _naive(prop):
        value = vDDDTypes.from_ical(prop.to_ical().decode("utf-8"))
        # A "...Z" value comes back timezone-aware while ``day`` is naive:
        # comparing them would raise TypeError.
        if isinstance(value, datetime) and value.tzinfo is not None:
            value = value.astimezone().replace(tzinfo=None)
        return value

    midnight = time(hour=0, minute=0, second=0, microsecond=0)
    dtstart = _naive(event.get('DTSTART'))
    if event.get('DTEND') is not None:
        dtend = _naive(event.get('DTEND'))
    elif event.get('DURATION') is not None:
        dtend = dtstart + event.get('DURATION').dt
    elif isinstance(dtstart, datetime):
        dtend = dtstart
    else:
        dtend = dtstart + timedelta(days=1)
    logger.debug(f"Event has dtstart {dtstart} and dtend {dtend}")

    day_start = day.replace(hour=0, minute=0, second=0, microsecond=0)
    day_end = day_start + timedelta(days=1)
    if not isinstance(dtstart, datetime):
        dtstart = datetime.combine(dtstart, midnight)
        dtend = datetime.combine(dtend, midnight)
        # DTEND of an all-day event is already the following midnight.
        # Adding a day here would spill the event onto the next day.
        if dtend <= dtstart:
            dtend = dtstart + timedelta(days=1)
    if dtstart < day_end and dtend > day_start:
        return max(day_start, dtstart), min(day_end, dtend)
    logger.debug("No overlapping between day and event")
    return None, None


class Schedule:
    """Representation of schedule related to flag managment.

    Events are kept in one icalendar ``Calendar`` per periodicity
    (``unique``, ``yearly``, ``monthly``, ``weekly``, ``daily``).
    ``schedule`` turns them into the 'in'/'out' switches of one day.
    """

    # Calendars consulted by ``schedule``, highest priority first.  ``daily``
    # can be filled by the .ics loader (FREQ=DAILY rules) and must be read too.
    _PERIODICITIES = ("unique", "yearly", "monthly", "weekly", "daily")

    def __init__(self, name):
        self.name = name
        self.unique = Calendar()
        self.yearly = Calendar()
        self.monthly = Calendar()
        self.weekly = Calendar()
        self.daily = Calendar()

    def schedule(self, day=None, exclusive=True):
        """Compute the switches of ``day`` from the stored events.

        Args:
            day: datetime whose date is scheduled.  It defaults to now,
                evaluated at each call; a ``datetime.now()`` default would be
                frozen at import time.
            exclusive: when True, stop at the first calendar, in priority
                order, that yields events for the day.

        Returns:
            list of (datetime, 'in' | 'out') sorted by time.  Overlapping
            and contiguous ranges are merged.  The list starts at the day's
            midnight.  Unless its last switch is at midnight, it ends with a
            sentinel at the next midnight that carries the opposite state.
        """
        if day is None:
            day = datetime.now()
        logger.info("Selecting events in schedule for specific date")
        events = self._select_events(day, exclusive)
        events.sort()
        logger.debug(f"events from schedule: {events}")
        return self._to_switches(events, day)

    def _select_events(self, day, exclusive):
        """Return the (start, end) ranges of stored events clipped to ``day``."""
        events = []
        for periodicity in self._PERIODICITIES:
            calendar = getattr(self, periodicity)
            if calendar.is_empty():
                logger.debug(f"{periodicity} calendar is empty")
                continue
            for event in calendar.walk('VEVENT'):
                in_time, out_time = project(event, day)
                logger.debug(f"{in_time}, {out_time}")
                if in_time is None:
                    continue
                if event['summary'] == 'clear':
                    # A 'clear' event ends the lookup.  The rest of this
                    # calendar and all lower-priority calendars are ignored.
                    return events
                events.append((in_time, out_time))
            if events and exclusive:
                break
        return events

    @staticmethod
    def _to_switches(events, day):
        """Turn sorted (start, end) ranges of ``day`` into merged switches."""
        midnight = time(hour=0, minute=0)
        day_start = day.replace(hour=0, minute=0, second=0, microsecond=0)
        day_end = day_start + timedelta(days=1)

        all_switches = []
        for index, (start, end) in enumerate(events):
            all_switches.append((start, index, 'in'))
            # A range ending at midnight lasts until the end of the day, so
            # it gets no explicit 'out'.
            if end.time() != midnight:
                all_switches.append((end, index, 'out'))
        if not all_switches:
            return [(day_start, "out"), (day_end, "in")]
        all_switches.sort()

        switches = [all_switches[0]]
        in_range = [all_switches[0][1]]  # indexes of the ranges currently open
        for switch in all_switches[1:]:
            moment, index, switch_type = switch
            prev_time, _, prev_switch_type = switches[-1]
            if switch_type == 'in':
                in_range.append(index)
                if prev_switch_type == 'out':
                    if prev_time == moment:
                        # Contiguous ranges: drop the 'out' instead of
                        # adding an 'in'.
                        switches.pop()
                    elif prev_time < moment:
                        switches.append(switch)
            elif switch_type == 'out':
                in_range.remove(index)
                # Only the closing of the last open range is a real switch.
                if prev_switch_type == 'in' and not in_range:
                    switches.append(switch)
        switches = [(moment, switch_type) for moment, _, switch_type in switches]

        # Pad to the day boundaries with the opposite state.
        if switches[0][0].time() != midnight:
            switches.insert(0, (day_start, 'in' if switches[0][1] == 'out' else 'out'))
        if switches[-1][0].time() != midnight:
            switches.append((day_end, 'in' if switches[-1][1] == 'out' else 'out'))
        return switches


class Flag:
    """Representation of a flag with methods to retrieve scheduled from associated file
    and effectively schedule via system at command.

    Attributes:
        name: flag (configuration section) name.
        flag_path: path of the flag file that is created or deleted.
        path: path of the schedule file.
        format: schedule file format, a key of ``SUPPORTED_FORMAT``.
        present_on_event: True if the flag file must exist during scheduled
            ranges, False if it must be absent during them.
        schedule: the loaded ``Schedule``.
    """
    def __init__(self, name, flag_path, store_path, present_on_event, store_format='json'):
        self.name = name
        self.flag_path = flag_path
        self.path = store_path
        self.format = store_format
        self.present_on_event = present_on_event
        self.schedule = self.get_schedule()

    def get_schedule(self):
        """Retrieve schedule from associated file.

        Raises:
            FlagPlanningError: if the file is missing or its format unsupported.
        """
        if not os.path.exists(self.path):
            raise FlagPlanningError(f"File not found {self.path}")
        if self.format not in SUPPORTED_FORMAT:
            raise FlagPlanningError(f"Format {self.format} not supported")
        return SUPPORTED_FORMAT[self.format](self.path, self.name)

    def _should_exist(self, switch_type):
        """Return True if the flag file must exist after a ``switch_type`` switch."""
        return (switch_type == 'in') == self.present_on_event

    def in_schedule(self):
        """Tell whether the flag file's presence matches today's schedule.

        Returns:
            bool: True when the file exists exactly when the schedule wants it.
        """
        switches = self.schedule.schedule(exclusive=True)
        now = datetime.now()
        flag_present = os.path.exists(self.flag_path)
        applied = None
        for moment, switch_type in switches:
            if moment < now:
                # Past switch: its state is in force unless a later one is.
                applied = flag_present == self._should_exist(switch_type)
                continue
            # First upcoming switch: the state in force is its opposite.
            return flag_present != self._should_exist(switch_type)
        return applied

    def display_schedule(self):
        """Return today's planning as (start "HH:MM", end "HH:MM", 'on'|'off') ranges.

        'on' means the flag file exists during the range.  An end of "00:00"
        is the end of the day.
        """
        def status(switch_type, reverse=False):
            if reverse:
                switch_type = 'in' if switch_type == 'out' else 'out'
            return 'on' if self._should_exist(switch_type) else 'off'

        switches = self.schedule.schedule(exclusive=True)
        logger.debug(f"switches to display: {switches}")
        midnight = time(hour=0, minute=0)
        last = len(switches) - 1
        dual = []
        for index, (moment, switch_type) in enumerate(switches):
            start = moment.strftime('%H:%M')
            if index == 0 and moment.time() != midnight:
                logger.debug("prepend range")
                # Formatted as HH:MM like every other bound.
                dual.append(("00:00", start, status(switch_type, reverse=True)))
            if index == last:
                if moment.time() != midnight or len(switches) == 1:
                    logger.debug("append range")
                    dual.append((start, "00:00", status(switch_type)))
            else:
                end = switches[index + 1][0].strftime('%H:%M')
                dual.append((start, end, status(switch_type)))

        logger.debug(f"dual representation: {dual}")
        return dual

    def schedule_at(self, dry_run=False):
        """Program today's remaining switches as ``at`` jobs.

        Each job runs ``touch`` (flag must exist) or ``rm -f`` (flag must be
        absent) on the flag file.  When some switches are already past, the
        state in force right now is applied at once with an ``at now`` job.
        The end-of-day sentinel switch is not programmed.

        Args:
            dry_run: only log what would be programmed.
        """
        def at_simplist_wrapper(command, when, dry_run=False):
            logger.info("Creating at command tasks")
            if when != 'now':
                when = f"{when.strftime('%H:%M')} today"
            logger.debug(f"{command} {when}")
            if not dry_run:
                # Feed the job to ``at`` on stdin: no shared temporary file
                # that concurrent runs could clobber or an attacker symlink.
                subprocess.run(['at', when], input=command.encode("utf-8"))
            else:
                logger.info("Mode dry-run")

        switches = self.schedule.schedule(exclusive=True)
        logger.debug(f"switches from schedule: {switches}")

        logger.info("Cleaning past switches")
        now = datetime.now()
        upcoming = [sw for sw in switches if sw[0] > now]
        logger.debug(f"switches after cleaning past ones: {upcoming}")
        if not upcoming:
            # Every switch is past (e.g. one all-day switch at midnight).
            # There is no sentinel to drop: apply the last state now.
            to_program = [('now', switches[-1][1])]
        else:
            # The last switch is the end-of-day sentinel: leave it to the
            # next run.
            to_program = upcoming[:-1]
            if len(upcoming) != len(switches):
                # Restore the state in force now: the opposite of the next
                # upcoming switch.
                to_program.insert(0, ('now', 'in' if upcoming[0][1] == 'out' else 'out'))

        logger.debug(f"switches after adding extrems: {to_program}")
        logger.info(f"Programmation pour le drapeau {self.name}")
        if self.present_on_event:
            logger.info("Drapeau présent dans les plages horaires spécifiées")
        else:
            logger.info("Drapeau absent dans les plages horaires spécifiées")
        # Quote the path so spaces or shell metacharacters cannot alter the job.
        target = shlex.quote(self.flag_path)
        for when, switch_type in to_program:
            cmd = 'touch' if self._should_exist(switch_type) else 'rm -f'
            logger.debug(f"command for switch {(when, switch_type)}: {cmd}")
            at_simplist_wrapper(f"{cmd} {target}", when, dry_run=dry_run)


def valid_config(config, permissive=False):
    """Check that ``config`` holds everything the flag commands need.

    Expected layout:

    * a ``Global`` section with a ``path`` option;
    * every other section has ``mode`` equal to ``manuelle`` or ``fichier``;
    * ``fichier`` sections also need ``path``, ``format`` and
      ``present_on_event``.

    Args:
        config: ConfigParser to validate.
        permissive: when True, every problem is logged as an error and the
            function returns normally; otherwise the first problem aborts.

    Raises:
        AssertionError: on the first problem when ``permissive`` is False.
            It is raised explicitly rather than with ``assert``, so the check
            still runs under ``python -O``.
    """
    logger.info("Validating configuration")
    for message in _config_errors(config):
        if not permissive:
            raise AssertionError(message)
        logger.error(message)


def _config_errors(config):
    """Yield one message per configuration problem, in check order."""
    if not config.has_section('Global'):
        yield "Missing section Global"
    elif not config.has_option('Global', 'path'):
        yield "Missing option path in section Global"
    for section in config.sections():
        if section == 'Global':
            continue
        if not config.has_option(section, 'mode'):
            yield f"Missing option mode in section {section}"
            continue
        mode = config.get(section, 'mode')
        if mode not in ('manuelle', 'fichier'):
            yield f"Unsupported mode {mode}"
        elif mode == 'fichier':
            for option in ('path', 'format', 'present_on_event'):
                if not config.has_option(section, option):
                    yield f"Missing option {option} in section {section}"

def validate(args):
    """Check the loaded configuration strictly; the first problem aborts."""
    valid_config(args.config)

def schedule(args):
    """Re-create today's ``at`` jobs for every ``fichier`` flag.

    Pending jobs of all configured flags are removed first.  Flags whose
    schedule source is unset or missing are skipped with a warning.
    """
    logger.info("Scheduling for configured flags")
    args.name = None  # unschedule every flag, not a single one
    unschedule(args)
    config = args.config
    for section in config.sections():
        if section == 'Global':
            continue
        mode = config.get(section, 'mode')
        if mode == 'manuelle':
            continue
        on_event = bool(int(config.get(section, 'present_on_event')))
        if mode != 'fichier':
            continue
        if not config.has_option(section, 'path'):
            logger.warning(f"Schedule not set for {section}.")
            continue
        source = config.get(section, 'path')
        if not os.path.exists(source):
            logger.warning(f"Schedule source {source} not available for {section}.")
            continue
        store_format = config.get(section, 'format')
        target = os.path.join(config.get('Global', 'path'), section)
        Flag(section, target, source, on_event, store_format=store_format).schedule_at()


def display(args):
    """Print today's planning of flag ``args.name``, or of every flag.

    Only ``fichier`` flags have a planning; other flags print nothing.
    """
    config = args.config

    def display_one(flag_name):
        # present_on_event is only mandatory for 'fichier' flags (see
        # valid_config).  Read it for them only, so a 'manuelle' flag
        # without it does not raise NoOptionError.
        if config.get(flag_name, 'mode') != 'fichier':
            return
        schedule_path = config.get(flag_name, 'path')
        if not os.path.exists(schedule_path):
            print(f"No schedule for flag {flag_name}")
            logger.error(f"No schedule for flag {flag_name}")
            return
        present_on_event = bool(int(config.get(flag_name, 'present_on_event')))
        flag_path = os.path.join(config.get('Global', 'path'), flag_name)
        flag = Flag(flag_name, flag_path, schedule_path, present_on_event,
                    store_format=config.get(flag_name, 'format'))
        print(flag_name, flag.display_schedule())

    if not args.name:
        for flag_name in config.sections():
            if flag_name != 'Global':
                display_one(flag_name)
    elif args.name in config.sections() and args.name != 'Global':
        display_one(args.name)
    else:
        # 'Global' holds settings, not a flag.
        print(f"No such flag {args.name} configured")
        logger.error(f"No such flag {args.name} configured")


def list_flags(config):
    """Return the flag file path of every configured flag.

    Each section other than ``Global`` names a flag.  Its file lives in the
    directory given by the ``path`` option of ``Global``.
    """
    flags_root = config.get('Global', 'path')
    flags = []
    for section in config.sections():
        if section == 'Global':
            continue
        flags.append(os.path.join(flags_root, section))
    return flags


def at_filtered_list(flags):
    """Return the ids of pending ``at`` jobs that act on one of ``flags``.

    This tool's jobs end with ``touch <path>`` or ``rm -f <path>``.  A job is
    selected when any argument of its last command line is one of ``flags``.
    Checking only the second word missed every ``rm -f`` job, because that
    word is ``-f``.

    Args:
        flags: flag file paths.

    Returns:
        list of job ids as printed by ``atq``.
    """
    wanted = set(flags)
    filtered_at_jobs = []
    queue = subprocess.check_output(['atq']).decode('utf-8').strip().split('\n')
    at_jobs = [line.split('\t')[0] for line in queue if line]
    for at_job in at_jobs:
        script = subprocess.check_output(['at', '-c', at_job]).decode('utf-8').strip()
        last_line = script.split('\n')[-1]
        try:
            arguments = shlex.split(last_line)[1:]
        except ValueError:
            # Unbalanced quotes: not a command written by this tool.
            continue
        if any(argument in wanted for argument in arguments):
            filtered_at_jobs.append(at_job)
    return filtered_at_jobs


def unschedule(args):
    """Remove the pending ``at`` jobs of flag ``args.name``.

    When ``args.name`` is empty, the jobs of every configured flag are
    removed.
    """
    logger.info("Unscheduling remaining tasks for configured flags")
    config = args.config
    flags = ([os.path.join(config.get('Global', 'path'), args.name)]
             if args.name else list_flags(config))
    for job_id in at_filtered_list(flags):
        subprocess.check_call(['atrm', job_id])


def clean(args):
    """Delete every existing flag file after removing all flags' ``at`` jobs."""
    logger.info("Deleting all flags")
    args.name = None  # unschedule every flag, not a single one
    unschedule(args)
    existing = filter(os.path.exists, list_flags(args.config))
    for flag_file in existing:
        os.unlink(flag_file)


def test_flag(args):
    """Report the state of flag ``args.name``.

    Args:
        args: namespace with ``config`` (ConfigParser) and ``name``.

    Returns:
        dict with ``name``, ``mode`` and ``location`` (the flag file path),
        plus ``applied``.  For a ``manuelle`` flag, ``applied`` is whether
        the file exists.  Otherwise it is whether the file's presence matches
        the schedule (``Flag.in_schedule``), and ``planning`` holds the day's
        ranges.

    Raises:
        FlagNotConfiguredError: when no section ``args.name`` exists.
    """
    # Use the parsed configuration from ``args``; the bare name ``config``
    # only resolved through a global defined when run as a script.
    config = args.config
    flag_name = args.name
    if not config.has_section(flag_name):
        raise FlagNotConfiguredError(f"flag {flag_name} not configured")
    mode = config.get(flag_name, 'mode')
    location = os.path.join(config.get("Global", "path"), flag_name)
    if mode == "manuelle":
        report = {"name": flag_name, "mode": mode, "location": location,
                  "applied": os.path.exists(location)}
    else:
        present_on_event = bool(int(config.get(flag_name, 'present_on_event')))
        flag_format = config.get(flag_name, 'format')
        flag = Flag(flag_name, location, config.get(flag_name, 'path'),
                    present_on_event, store_format=flag_format)
        report = {"name": flag_name, "mode": mode, "location": location,
                  "planning": flag.display_schedule(),
                  "applied": flag.in_schedule()}

    logger.debug(json.dumps(report))
    return report

def output_test_flag(args):
    """Print the state report of flag ``args.name`` as JSON on stdout."""
    print(json.dumps(test_flag(args)))

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", help="Configuration path", required=True)
    parser.add_argument("-l", "--log-level", help="Log level")

    # A sub-command is mandatory.  Without ``required``, an empty command
    # line is accepted and the script then crashes on ``args.func``.
    # ``dest`` gives argparse a name to report in the usage error.
    subparsers = parser.add_subparsers(dest="command", help='sub-command help')
    subparsers.required = True

    validate_subparser = subparsers.add_parser("validate", help="Validate global configuration and flags configuration")
    validate_subparser.set_defaults(func=validate)

    test_subparser = subparsers.add_parser("test", help="Test flag state")
    test_subparser.set_defaults(func=output_test_flag)
    test_subparser.add_argument("-n", "--name", help="Name of the flag", required=True)

    schedule_subparser = subparsers.add_parser("schedule", help="Schedule flags creation and deletion")
    schedule_subparser.set_defaults(func=schedule)

    display_subparser = subparsers.add_parser("display", help="Display schedule of the day")
    display_subparser.add_argument("-n", "--name", help="Name of the flag")
    display_subparser.set_defaults(func=display)

    unschedule_subparser = subparsers.add_parser("unschedule", help="Clean at jobs related to flags managment")
    unschedule_subparser.set_defaults(func=unschedule)
    unschedule_subparser.add_argument("-n", "--name", help="Name of the flag")

    clean_subparser = subparsers.add_parser("clean", help="Delete configured flags")
    clean_subparser.set_defaults(func=clean)

    args = parser.parse_args()

    logger.addHandler(handler)
    logger.setLevel(logging.WARNING) # initial level for configuration validation code

    config = ConfigParser()
    # Close the configuration file once it has been parsed.
    with open(args.config) as config_fh:
        config.read_file(config_fh)
    valid_config(config)

    if args.log_level:
        # Accept any case, and report unknown names as a usage error rather
        # than an AttributeError traceback.
        log_level = getattr(logging, args.log_level.upper(), None)
        if not isinstance(log_level, int):
            parser.error(f"invalid log level: {args.log_level}")
        logger.setLevel(log_level)
    elif config.has_option("Global", "loglevel"):
        loglevel = config.get("Global", "loglevel").upper()
        if loglevel in ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"):
            logger.setLevel(getattr(logging, loglevel))
    else:
        logger.setLevel(logging.WARNING) # default level

    args.config = config

    args.func(args)
