#!/usr/bin/env python
# Software License Agreement (BSD License)
#
# Copyright (c) 2008, Willow Garage, Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above
#    copyright notice, this list of conditions and the following
#    disclaimer in the documentation and/or other materials provided
#    with the distribution.
#  * Neither the name of Willow Garage, Inc. nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# Revision $Id$

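## Integration test node that subscribes to any topic and verifies
## that the publishing rate is within the specified bounds. The
## following parameters must be set:
##
##  * ~hz: expected rate in hz (0 means no messages are expected)
##  * ~hzerror: error bound for hz (only required when ~hz is non-zero)
##  * ~test_duration: time (in secs) to run test
##  * ~topic: topic to subscribe to
##
## The following parameters are optional:
##
##  * ~wait_time: time (in secs) to wait for the first message (default: 20)
##  * ~check_intervals: also check each inter-message interval (default: False)
##  * ~wall_clock: measure with wall clock time instead of ROS time (default: False)
##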

from __future__ import print_function

import sys
import threading
import time
import unittest

import rospy
import rostest

NAME = 'hztest'

from threading import Thread

class HzTest(unittest.TestCase):
    def __init__(self, *args):
        super(HzTest, self).__init__(*args)
        rospy.init_node(NAME)

        self.lock = threading.Lock()
        self.message_received = False

    def setUp(self):
        self.errors = []
        # Count of all messages received
        self.msg_count = 0
        # Time of first message received
        self.msg_t0 = -1.0
        # Time of last message received
        self.msg_tn = -1.0

    ## performs two tests of a node, first with /rostime off, then with /rostime on
    def test_hz(self):
        # Fetch parameters
        try:
            # expected publishing rate
            hz = float(rospy.get_param('~hz'))
            # length of test
            test_duration = float(rospy.get_param('~test_duration'))
            # topic to test
            topic = rospy.get_param('~topic')
            # time to wait before
            wait_time = rospy.get_param('~wait_time', 20.)            
        except KeyError as e:
            self.fail('hztest not initialized properly. Parameter [%s] not set. debug[%s] debug[%s]'%(str(e), rospy.get_caller_id(), rospy.resolve_name(e.args[0])))

        # We only require hzerror if hz is non-zero
        hzerror = 0.0
        if hz != 0.0:
            try:
                # margin of error allowed
                hzerror = float(rospy.get_param('~hzerror'))
            except KeyError as e:
                self.fail('hztest not initialized properly. Parameter [%s] not set. debug[%s] debug[%s]'%(str(e), rospy.get_caller_id(), rospy.resolve_name(e.args[0])))

        # We optionally check each inter-message interval
        try:
            self.check_intervals = bool(rospy.get_param('~check_intervals'))
        except KeyError:
            self.check_intervals = False

        # We optionally measure wall clock time
        try:
            self.wall_clock = bool(rospy.get_param('~wall_clock'))
        except KeyError:
            self.wall_clock = False

        print("""Hz: %s
Hz Error: %s
Topic: %s
Test Duration: %s"""%(hz, hzerror, topic, test_duration))
        
        self._test_hz(hz, hzerror, topic, test_duration, wait_time)        
            
    def _test_hz(self, hz, hzerror, topic, test_duration, wait_time): 
        self.assert_(hz >= 0.0, "bad parameter (hz)")
        self.assert_(hzerror >= 0.0, "bad parameter (hzerror)")
        self.assert_(test_duration > 0.0, "bad parameter (test_duration)")
        self.assert_(len(topic), "bad parameter (topic)")

        if hz == 0:
            self.min_rate = 0.0
            self.max_rate = 0.0
            self.min_interval = 0.0
            self.max_interval = 0.0
        else:
            self.min_rate = hz - hzerror
            self.max_rate = hz + hzerror
            self.min_interval = 1.0 / self.max_rate
            if self.min_rate <= 0.0:
                self.max_interval = 0.0
            else:
                self.max_interval = 1.0 / self.min_rate

        # Start actual test
        sub = rospy.Subscriber(topic, rospy.AnyMsg, self.callback)
        self.assert_(not self.errors, "bad initialization state (errors)")
        
        print("Waiting for messages")
        # we have to wait until the first message is received before measuring the rate
        # as time can advance too much before publisher is up
        
        # give the test wait_time seconds to start
        wallclock_timeout_t = time.time() + wait_time
        while not self.message_received and time.time() < wallclock_timeout_t:
            time.sleep(0.1)
        if hz > 0.:
            self.assert_(self.message_received, "no messages before timeout")
        else:
            self.failIf(self.message_received, "message received")
            
        print("Starting rate measurement")
        if self.wall_clock:
            timeout_t = time.time() + test_duration
            while time.time() < timeout_t:
                time.sleep(0.1)
        else:
            timeout_t = rospy.get_time() + test_duration
            while rospy.get_time() < timeout_t:
                rospy.sleep(0.1)
        print("Done waiting, validating results")
        sub.unregister()

        # Check that we got at least one message
        if hz > 0:
            self.assert_(self.msg_count > 0, "no messages received")
        else:
            self.assertEquals(0, self.msg_count)
        # Check whether inter-message intervals were violated (if we were
        # checking them)
        self.assert_(not self.errors, '\n'.join(self.errors))

        # If we have a non-zero rate target, make sure that we hit it on
        # average
        if hz > 0.0:
          self.assert_(self.msg_t0 >= 0.0, "no first message received")
          self.assert_(self.msg_tn >= 0.0, "no last message received")
          dt = self.msg_tn - self.msg_t0
          self.assert_(dt > 0.0, "only one message received")
          rate = ( self.msg_count - 1) / dt
          self.assert_(rate >= self.min_rate, 
                       "average rate (%.3fHz) exceeded minimum (%.3fHz)" %
                       (rate, self.min_rate))
          self.assert_(rate <= self.max_rate, 
                       "average rate (%.3fHz) exceeded maximum (%.3fHz)" %
                       (rate, self.max_rate))
        
    def callback(self, msg):
        # flag that message has been received
        self.message_received = True         
        try:
            self.lock.acquire()

            if self.wall_clock:
                curr = time.time()
            else:
                curr_rostime = rospy.get_rostime()
                
                if curr_rostime.is_zero():
                    return
                curr = curr_rostime.to_sec()
 
            if self.msg_t0 <= 0.0 or self.msg_t0 > curr:
                self.msg_t0 = curr
                self.msg_count = 1
                last = 0
            else:
                self.msg_count += 1
                last = self.msg_tn

            self.msg_tn = curr

            # If we're instructed to check each inter-message interval, do
            # so
            if self.check_intervals and last > 0:
                interval = curr - last
                if interval < self.min_interval:
                    print("CURR", str(curr), file=sys.stderr)
                    print("LAST", str(last), file=sys.stderr)
                    print("msg_count", str(self.msg_count), file=sys.stderr)
                    print("msg_tn", str(self.msg_tn), file=sys.stderr)
                    self.errors.append(
                        'min_interval exceeded: %s [actual] vs. %s [min]'%\
                        (interval, self.min_interval))
                # If max_interval is <= 0.0, then we have no max
                elif self.max_interval > 0.0 and interval > self.max_interval:
                    self.errors.append(
                        'max_interval exceeded: %s [actual] vs. %s [max]'%\
                        (interval, self.max_interval))

        finally:
            self.lock.release()
    
        
if __name__ == '__main__':
    # Workaround (a dirty hack): pause briefly before starting, because an
    # apparent startup race condition makes some hztests fail. It is most
    # evident in the tests of rosstage.
    time.sleep(0.75)
    try:
        rostest.run(
            'rostest',
            NAME,
            HzTest,
            sys.argv,
        )
    except KeyboardInterrupt:
        # An interrupt is not an error; fall through to the exit message.
        pass
    print("exiting")

        
