"""Event module."""
import spiceypy as spice
import spiceypy.utils.support_types as stypes
import numpy as np
# import csv
# Placeholder relative time returned by the 'rel' time formats of Event when
# no mission reference event is supplied.
NA_REL_TIME: str = '+00:00:00'
class Event:
    """A class representing an event.

    The Event class is a wrapper class around the SPICE Window type that is at the
    core of the SPICE Geometry Finder sub-system. It simplifies operations for
    setting and getting Event properties.

    An Event (or Window) is made of one or several intervals, and an interval can be
    an instant (start time equal to stop time).

    Attributes:
        time_window (spiceypy.utils.support_types.SpiceCell): SPICE Window data type
        interval_times (list): event intervals as [start, stop] pairs, in ephemeris time (et)
        n_intervals (int): number of time intervals within the event
        start_time (float): start time of the event, in ephemeris time (et)
        stop_time (float): stop time of the event, in ephemeris time (et)
        duration (float): duration of the event, in seconds
        reference_time (float): reference time of the event, in ephemeris time (et).
            ``set_time_window`` sets it to the midpoint between start and stop
            times; ``set_times``, ``set_interval_times`` and ``__init__`` (for
            events with a non-zero duration) set it to the start time.
        data: optional mapping of extra source event data (e.g. read from ESA event files)
        id (str): optional id, for events handling within MOS
    """

    # SPICE ``timout`` picture used for all UTC string conversions.
    _UTC_PICTURE = 'YYYY-MM-DD HR:MN:SC.###'

    def __init__(self, time_window=None, start_time=None, stop_time=None, format='', data=None, id=''):
        """Inits Event object.

        Args:
            time_window (spiceypy.utils.support_types.SpiceCell): SPICE window
                defining the event; takes precedence over start/stop times.
            start_time: start time of the event, as an ET float or a UTC string
                depending on ``format``.
            stop_time: stop time of the event; defaults to ``start_time``
                (instantaneous event) when omitted.
            format (str): format of ``start_time``/``stop_time``, 'et' (default)
                or 'utc'.
            data: optional extra source event data (e.g. read from ESA event files).
            id (str): optional id, for events handling within MOS.
        """
        self.interval_times = []  # event intervals times, in ephemeris time (et)
        self.n_intervals = 0  # number of time intervals within the event
        self.start_time = 0  # start time of the event, in ephemeris time (et)
        self.stop_time = 0  # stop time of the event, in ephemeris time (et)
        self.duration = 0  # duration of the event, in seconds
        self.reference_time = 0  # reference time, in ephemeris time (et)
        self.time_window = None  # SPICE Window data type
        self.data = data  # carrying extra data read from ESA event files
        self.id = id  # optional id, for events handling within MOS.

        # Default format for input start and stop times.
        if not format:
            format = 'et'

        if time_window is not None:
            self.set_time_window(time_window)
        elif start_time is not None:
            if stop_time is None:
                stop_time = start_time
            # An invalid format is reported by set_times(), which then leaves
            # the event unset instead of raising.
            self.set_times(start_time, stop_time, format=format)

        # Events with a duration are referenced to their start time; this
        # overrides the window midpoint computed by set_time_window().
        if self.duration > 0:
            self.reference_time = self.start_time

    def __repr__(self):
        header = (
            f"<{self.__class__.__name__}> "
            f"UTC start time: {self.get_start_time(format='utc')} | "
            f"UTC stop time: {self.get_stop_time(format='utc')} | "
            f"UTC reference time: {self.get_reference_time(format='utc')} | "
            f"Duration (sec): {self.duration:.3f} | "
            f"Nb of intervals: {self.n_intervals}"
        )
        interval_lines = [f' - {interval}' for interval in self.get_interval_times(format='utc')]
        return '\n'.join([header, *interval_lines])

    def __str__(self):
        # NOTE(review): without an id the text starts with an empty line, while
        # with an id the id line takes its place; kept as in the original.
        s = f' Event ID: {self.id}\n' if self.id else '\n'
        s += f" UTC start time: {self.get_start_time(format='utc')}"
        if self.duration > 0:
            s += f"\n UTC stop time: {self.get_stop_time(format='utc')}"
            if self.reference_time != self.start_time:
                s += f"\n UTC reference time: {self.get_reference_time(format='utc')}"
        if self.duration > 60:
            s += f"\n Duration (HH:MN:SS): {self.get_duration(format='str')}"
        else:
            s += f"\n Duration (sec): {self.duration:.3f}"
        if self.n_intervals > 1:
            s += f'\n Nb of intervals: {self.n_intervals}'
            for utc_interval_time in self.get_interval_times(format='utc'):
                s += f'\n - {utc_interval_time}'
        if self.data:
            s += '\n Source event data:'
            for key in self.data:
                s += f'\n - {key}: {self.data[key]}'
        return s

    def get_dict(self):
        """Return the event UTC start/stop times, number of intervals and UTC interval times."""
        return {
            'start_time': self.get_start_time(format='utc'),
            'stop_time': self.get_stop_time(format='utc'),
            'n_intervals': self.n_intervals,
            'interval_times': self.get_interval_times(format='utc'),
        }

    def set_time_window(self, time_window):
        """Set the event from a SPICE window.

        Start/stop times span all intervals of the window and the reference
        time is set to their midpoint. An empty window resets the derived
        attributes (interval times, start/stop/reference times, duration) to 0.

        Args:
            time_window (spiceypy.utils.support_types.SpiceCell): SPICE window.
        """
        self.time_window = time_window
        self.n_intervals = spice.wncard(time_window)
        if self.n_intervals == 0:
            # Do not keep values derived from a previously set window.
            self.interval_times = []
            self.start_time = 0
            self.stop_time = 0
            self.duration = 0
            self.reference_time = 0
            return
        interval_times = []
        for i in range(self.n_intervals):
            left, right = spice.wnfetd(time_window, i)
            interval_times.append([left, right])
        self.start_time = np.min(interval_times)
        self.stop_time = np.max(interval_times)
        self.duration = self.stop_time - self.start_time
        self.reference_time = self.start_time + 0.5 * self.duration
        self.interval_times = interval_times

    def set_times(self, start_time, stop_time, format='et'):
        """Set the event as a single [start_time, stop_time] interval.

        Args:
            start_time: start time, as an ET float or a UTC string.
            stop_time: stop time, as an ET float or a UTC string.
            format (str): 'et' or 'utc'. Any other value prints an error
                message and leaves the event unchanged.
        """
        if format == 'utc':
            start_time = spice.str2et(start_time)
            stop_time = spice.str2et(stop_time)
        elif format != 'et':
            print(f'set_times: Invalid input time format: {format}')
            return
        # Build the window first so that the event is left untouched if SPICE
        # rejects the interval.
        time_window = stypes.SPICEDOUBLE_CELL(2)
        spice.wninsd(start_time, stop_time, time_window)
        self.start_time = start_time
        self.stop_time = stop_time
        self.time_window = time_window
        self.duration = self.stop_time - self.start_time
        self.reference_time = self.start_time
        self.n_intervals = 1
        self.interval_times = [[self.start_time, self.stop_time]]

    def set_interval_times(self, interval_times, format='et'):
        """Set the event from a sequence of [start, stop] intervals.

        Args:
            interval_times: sequence of [start, stop] pairs, as ET floats or
                UTC strings.
            format (str): 'et' or 'utc'. Any other value prints an error
                message and leaves the event unchanged.

        Raises:
            ValueError: if ``interval_times`` is empty.
        """
        if format == 'utc':
            interval_times = [[spice.str2et(interval[0]), spice.str2et(interval[1])]
                              for interval in interval_times]
        elif format != 'et':
            print(f'set_interval_times: Invalid input time format: {format}')
            return
        if len(interval_times) == 0:
            raise ValueError('set_interval_times: at least one interval is required.')

        self.n_intervals = len(interval_times)
        self.interval_times = interval_times
        self.start_time = np.min(self.interval_times)
        # NOTE(review): the stop time is extended by the spacing between the
        # first two interval starts, presumably because intervals are regularly
        # sampled and the last one is meant to cover a full step - confirm.
        interval_time_step = 0.0
        if self.n_intervals > 1:
            interval_time_step = self.interval_times[1][0] - self.interval_times[0][0]
        self.stop_time = np.max(self.interval_times) + interval_time_step
        self.duration = self.stop_time - self.start_time
        self.reference_time = self.start_time
        time_window = stypes.SPICEDOUBLE_CELL(2 * self.n_intervals)
        for interval_time in self.interval_times:
            spice.wninsd(interval_time[0], interval_time[1], time_window)
        self.time_window = time_window

    def get_time_window(self):
        """Return the SPICE window of the event."""
        return self.time_window

    def _format_time(self, et, fmt, mission_event, caller):
        """Convert an ephemeris time to the requested representation.

        Args:
            et (float): time to convert, in ephemeris time (et).
            fmt (str): 'et' (unchanged), 'utc' (YYYY-MM-DD HR:MN:SC.###),
                'YYYYMMDD', or 'rel' (relative to the reference time of
                ``mission_event``, as +/-HH:MM:SS.###).
            mission_event (Event): reference event for 'rel'; when None,
                NA_REL_TIME is returned for 'rel'.
            caller (str): public method name, used in the error message.

        Returns:
            The converted time, or None (after printing a message) if ``fmt``
            is not supported.
        """
        if fmt == 'et':
            return et
        if fmt == 'utc':
            return spice.timout(et, self._UTC_PICTURE)
        if fmt == 'YYYYMMDD':
            return spice.timout(et, 'YYYYMMDD')
        if fmt == 'rel':
            if mission_event is None:
                return NA_REL_TIME
            return self.to_string(et - mission_event.get_reference_time(format='et'))
        print(f'{caller}: Invalid input time format: {fmt}')
        return None

    def get_interval_times(self, format='et', mission_event=None):
        """Return the event intervals as [start, stop] pairs.

        Args:
            format (str): 'et' (the stored list itself), 'utc', 'YYYYMMDD' or
                'rel' (relative to ``mission_event`` reference time).
            mission_event (Event): reference event for 'rel'.

        Returns:
            list of [start, stop] pairs, or None (after printing a message) for
            an unsupported format.
        """
        if format == 'et':
            return self.interval_times
        if format not in ('utc', 'YYYYMMDD', 'rel'):
            print(f'get_interval_times: Invalid input time format: {format}')
            return None
        return [
            [self._format_time(interval[0], format, mission_event, 'get_interval_times'),
             self._format_time(interval[1], format, mission_event, 'get_interval_times')]
            for interval in self.interval_times
        ]

    def get_start_time(self, format='et', mission_event=None):
        """Return the start time in 'et', 'utc', 'YYYYMMDD' or 'rel' format (None if invalid)."""
        return self._format_time(self.start_time, format, mission_event, 'get_start_time')

    def get_stop_time(self, format='et', mission_event=None):
        """Return the stop time in 'et', 'utc', 'YYYYMMDD' or 'rel' format (None if invalid)."""
        return self._format_time(self.stop_time, format, mission_event, 'get_stop_time')

    def get_reference_time(self, format='et', mission_event=None):
        """Return the reference time in 'et', 'utc', 'YYYYMMDD' or 'rel' format (None if invalid)."""
        return self._format_time(self.reference_time, format, mission_event, 'get_reference_time')

    def get_duration(self, format=None):
        """Return the duration in seconds, or as an unsigned HH:MM:SS.### string if format='str'."""
        if format == 'str':
            # Durations are unsigned: drop the '+'/'-' prefix when present
            # (a zero duration has none, so a fixed [1:] slice would cut a digit).
            return self.to_string(self.duration).lstrip('+-')
        return self.duration

    def to_string(self, rel_time):
        """Convert a relative time in seconds to a string representation.

        The value is rounded to the nearest millisecond. For example::

            43801.51 -> +12:10:01.510
            -14172.2 -> -03:56:12.200
            0.0      ->  00:00:00.000

        Args:
            rel_time (float): relative time in seconds

        Returns:
            str: string representation of the relative time (+/-HH:MM:SS.###,
            no sign for exactly zero)
        """
        # Work in whole (rounded) milliseconds so float representation error,
        # e.g. 1.001 * 1000 == 1000.999..., does not drop a millisecond.
        total_ms = int(round(abs(rel_time) * 1000))
        # The sign comes from rel_time itself, so sub-second values are signed correctly.
        if rel_time > 0:
            sign_str = '+'
        elif rel_time < 0:
            sign_str = '-'
        else:
            sign_str = ''
        total_seconds, millisecs = divmod(total_ms, 1000)
        total_minutes, seconds = divmod(total_seconds, 60)
        hours, minutes = divmod(total_minutes, 60)
        return f'{sign_str}{hours:02}:{minutes:02}:{seconds:02}.{millisecs:03}'

    def to_seconds(self, rel_time_str):
        """Convert an input string representation of a relative time to seconds.

        For example::

            12:10:01.510 -> 43801.51
            -03:56:12    -> -14172.0

        Args:
            rel_time_str (str): string representation of a relative time
                (+/-HH:MM:SS.###)

        Returns:
            float: relative time in seconds, or None (after printing a warning)
            if the string does not have three ':'-separated fields.
        """
        values = rel_time_str.split(':')
        if len(values) != 3:
            print('WARNING: Invalid input relative time string representation (+/-HH:MM:SS.###).')
            return None
        hours_str, minutes_str, seconds_str = values
        # The sign is carried by the hours field; checking the text (rather than
        # the parsed value) keeps the sign of e.g. "-00:10:00".
        sign = -1.0 if hours_str.strip().startswith('-') else 1.0
        hours = abs(float(hours_str))
        minutes = float(minutes_str)
        seconds = float(seconds_str)
        return sign * (hours * 3600.0 + minutes * 60.0 + seconds)

    def summary(self):
        """Print the UTC start/stop times, duration and number of intervals."""
        print('Time Window : ')
        print(f"  - Start Time = {self.get_start_time(format='utc'):s}")
        print(f"  - Stop Time  = {self.get_stop_time(format='utc'):s}")
        print(f'  - Duration   = {self.duration:f}')
        print(f'Number of Intervals = {self.n_intervals:d}')
# class EventFile:
#     """Class to read and write event files.
#     """
#     def __init__(self, event_file=None):
#         """ Inits EventFile object.
# 
#         Args:
#             event_file:
#         """
#         self.event_file = ''
#         self.events = []
#         if event_file is not None:
#             self.event_file = event_file
#             self.read()
# 
#     def read(self, event_file=None):
#         if event_file is not None:
#             self.event_file: event_file
# 
#         self.events = []
#         with open(self.event_file, 'r') as file:
#             csv_file = csv.DictReader(file)
#             for row in csv_file:
#                 self.events.append(dict(row))
# 
#         with open(self.event_file, 'r') as file:
#             csv_file = csv.DictReader(file)
#             for row in csv_file:
#                 self.events.append(dict(row))
# 
#     def __repr__(self):
#         return '<%s %r>' % (self.__class__.__name__, self.__dict__)
# 
#     def get_events_ids(self):
#         return self.events  # actually dicts not objects... !
# 
#     def get_event(self, event_name) -> Event:
#         event = None
#         exist = False
#         for event_dict in self.events:
#             if event_dict['id'] == event_name:
#                 event = Event(
#                     start_time=event_dict['start_time (utc)'],
#                     stop_time=event_dict['stop_time (utc)'],
#                     format='utc'
#                 )
#                 if event.duration != 0:  # set reference time to start_time for segment (!= flyby CA)
#                     event.reference_time = event.start_time
#                 exist = True
#         return event, exist