#!/usr/bin/python
# Runs a test manifest that exchanges ACS and bundles with a bundle agent.
#
#	Author: Andrew Jenkins
#				University of Colorado at Boulder
#	Copyright (c) 2008-2011, Regents of the University of Colorado.
#	This work was supported by NASA contracts NNJ05HE10G, NNC06CB40C, and
#	NNC07CB47C.		
#
# Help text handed to OptionParser further down; optparse substitutes the
# script name for %prog when it prints the usage message.
usage = """\
usage: %prog [options] <manifest file name>
    Tests whether a bundle protocol agent correctly handles Aggregate
    Custody Signals by sending it bundles and ACS and verifying the
    ACS received from the bundle agent and peeking in its custody
    database.

    The test steps to perform are specified in a "manifest" file which
    is written in python.  The manifest file should use these functions
    for interacting with the bundle agent:
         send, expectacs, sendexpectacs, sleep, receive, receiveall

    [options] are flag options.
    <manifest file name> is the name of a manifest file."""

from optparse import OptionParser
import shlex
import sys
from bpacs import BpAcs, dtntime, unserialize_acs
from bpacsdb import BpAcsDb
from bplist import bplist
from acslist import acslist
from bundle import Bundle
from sdnv import sdnv_decode, sdnv_encode
from cteb import Cteb, CTE_BLOCK
import struct
import bp
import pdb
import time
import socket
import traceback

# The manifest is the filename of this test's procedure, written in python.
# parse_args() fills it in from the first positional command-line argument.
manifest = None

# By default, we use this as the destination EID in bundles we send.
destEid = "ipn:3.5"

# NOTE(review): nothing in the code visible here reads has_failed; the exit
# status handed to sys.exit() at the end of the script is tester.failed.
has_failed = 0

# Command-line parser; its options are registered inside parse_args().
parser = OptionParser(usage)

def parse_hostport(addrportstring, defaultHost = "localhost", defaultPort = 4556):
    """Turn a "[host:]port" string into a (host, port) tuple.

    addrportstring -- "host:port", a bare "port", or a value without a
                      split() method (such as None), in which case both
                      defaults are used.
    defaultHost    -- host returned when the string names no host.
    defaultPort    -- port returned when no string is given.

    Returns (host, port) with the port converted to int.  Raises ValueError
    when the port part is not an integer, which includes any string that
    contains more than one ":".
    """
    host = defaultHost
    port = str(defaultPort)

    try:
        pieces = addrportstring.split(":")
    except AttributeError:
        # Nothing usable was supplied; keep both defaults.
        pieces = None

    if pieces is not None:
        if len(pieces) == 2:
            host, port = pieces
        else:
            # No ":" (or too many): treat the whole string as the port.
            port = addrportstring

    return (host, int(port))
    


def parse_args():
    """Parse the command line and open the sockets used by the test.

    Registers the -d/--destination, -s/--source and -e/--errexit options on
    the module-level parser, stores the first positional argument in the
    global ``manifest``, and creates two UDP sockets:

    sender   -- connected to the --destination [host:]port
                (default localhost:4556).
    receiver -- bound to the --source [host:]port (default port 4557 on
                all interfaces).

    Prints the help text and exits with status 1 when no manifest file
    name was given.
    """
    global parser, manifest, receiver, sender

    parser.add_option("-d", "--destination", dest="dest",
        help="Send bundles to the BPA-under-test at [host:]port.")
    parser.add_option("-s", "--source", dest="source",
        help="Receive bundles from the BPA-under-test at [host:]port.")
    parser.add_option("-e", "--errexit", action="store_true", dest="errexit",
        help="Exit immediately on error", default=False)

    (options, args) = parser.parse_args()

    # The manifest file name is mandatory.
    if len(args) == 0:
        parser.print_help()
        sys.exit(1)
    manifest = args[0]

    # Socket used to send bundles to the bundle protocol agent.
    send_addr = parse_hostport(options.dest, "localhost", 4556)
    sender = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    sender.connect(send_addr)

    # Socket on which bundles from the bundle protocol agent arrive.
    # Binding to ("", port) means "INADDR_ANY:port".
    recv_addr = parse_hostport(options.source, "", 4557)
    receiver = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    receiver.bind(recv_addr)

# Parse the command line and create the sender/receiver sockets before any
# test state is set up.
parse_args()

# ACS we still expect to receive, indexed as expected_acs[succeeded][reason].
# Each entry is a BpAcs to which expectacs() adds the expected custody IDs;
# receive() deletes the entry once a matching ACS has arrived.
expected_acs = { True : { }, False : { } }
# Bundles the agent under test is expected to hold custody of, keyed by
# (source EID, DTN timestamp, sequence number).  Only the keys are used;
# expectcustody() stores None as the value.
expected_custody = { }

# The custody database for the test-acs "Bundle Agent."
my_acsdb = BpAcsDb()
# The custody database of the bundle agent under test.
ut_acsdb = BpAcsDb()

class TestFailureError(Exception):
    """Raised when a manifest step observes something the test did not expect."""


class BundleNotReceivedError(TestFailureError):
    """Raised when no bundle arrives on the receiver socket before the timeout."""


class ManifestSyntaxError(Exception):
    """Raised when a manifest passes a malformed argument to a command."""

class Tester:
    """Implements the commands that a test manifest may call.

    Every name in ``_COMMAND_NAMES`` is published through ``self.commands``,
    a dict mapping the command name to the bound method.  The methods work
    on module-level state: the ``sender`` and ``receiver`` sockets, the
    ``my_acsdb`` / ``ut_acsdb`` custody-ID databases, the ``expected_acs``
    and ``expected_custody`` tables, and ``destEid``.
    """

    # Names of the methods exposed to the manifest.
    _COMMAND_NAMES = ( "send", "expectacs", "sendexpectacs", "expectcustody",
                       "sendexpectcustody", "sendexpectacsandcustody",
                       "sleep", "receive", "receiveall", "receiveuntil",
                       "checkcustody", "clearcustody", "checkacslist",
                       "assertutcid", "emptycustody", "sendacs", "dtntime" )

    def __init__(self):
        # Number of bundles passed to send() so far.
        self.sentbundles = 0
        # Exit status of the run; set to 1 by the script when a step fails.
        self.failed = 0
        self.commands = dict((name, getattr(self, name))
                             for name in self._COMMAND_NAMES)

    def try_split_seq(self, seq):
        """Convert a sequence specification into a tuple of ints.

        seq -- "a.b-c" (giving (a, b, c)), "a.b" (giving (a, b)), a plain
               integer string "a" (giving (a,)), or a non-string value,
               which is returned unchanged as a 1-tuple.

        Raises ManifestSyntaxError when a string component is not an int.
        """
        try:
            parts = seq.split(".")
            # The leading int must be wrapped in a list before it can be
            # concatenated with the list of trailing ints.
            return tuple([int(parts[0])] + [int(p) for p in parts[1].split("-")])
        except IndexError:
            # No "." in the string: it is a single sequence number.
            return tuple([int(seq)])
        except AttributeError:
            # Not a string (no split()): pass the value through untouched.
            return tuple([seq])
        except ValueError:
            raise ManifestSyntaxError("""sequence number "%s" not an int or int.int-int""" % seq)

    def send_bytes_to_ba(self, bytes):
        """Send raw bytes to the bundle agent over the ``sender`` socket.

        A failure to send is reported on stdout but never raised.
        """
        # Try to send the bytes up to twice.  Sometimes, python will give
        # you a socket error the first time without even trying to send,
        # like if it receives an "ICMP unreachable" error.  Even though this
        # usually means the test isn't going to work, ignore it and send
        # anyways.
        try:
            sender.send(bytes)
        except socket.error:
            try:
                sender.send(bytes)
            except socket.error as se:
                print("Couldn't send to %s: %s" % (sender.getpeername(), se))

    def send(self, custodian, source, timestamp, seq, custodial = None, cid = None, ttl = None):
        """Build a test bundle and send it to the bundle agent.

        custodian -- current-custodian EID; also written into the CTEB.
        source    -- source EID of the bundle.
        timestamp -- creation time, converted with dtntime().
        seq       -- creation sequence number (converted with int()).
        custodial -- custody transfer is requested when this is None or True.
        cid       -- custody ID for the CTEB; when None one is obtained from
                     my_acsdb for (source, timestamp, seq).
        ttl       -- bundle lifetime; a false value selects 86400*5*366.

        The bundle is addressed to the module-level ``destEid``.
        """
        flags = bp.BUNDLE_SINGLETON_DESTINATION
        if custodial is None or custodial == True:
            flags |= bp.BUNDLE_CUSTODY_XFER_REQUESTED
        if not ttl:
            ttl = 86400*5*366
        timestamp = dtntime(timestamp)
        seq = int(seq)
        self.sentbundles += 1

        bundlePayload = "Test bundle seq %d" % (seq)
        bundle = Bundle(flags,
                        bp.COS_BULK, 0, source, destEid,
                        "dtn:none", custodian, "dtn:none", timestamp, seq, ttl, bundlePayload)
        bundleBytes = bp.encode(bundle)

        if cid is None:
            cid = my_acsdb.getCustodyId((source, timestamp, seq))

        # Extension block of type 10 holding the SDNV-encoded custody ID
        # followed by the custodian EID.
        # NOTE(review): 10 is presumably the same value as cteb.CTE_BLOCK -- confirm.
        cteb_payload = sdnv_encode(cid) + custodian
        bundleBytes += bp.encode_block_preamble(10, bp.BLOCK_FLAG_REPLICATE, [], len(cteb_payload))
        bundleBytes += cteb_payload

        # The payload block comes last.
        bundleBytes += bp.encode_block_preamble(bp.PAYLOAD_BLOCK, bp.BLOCK_FLAG_LAST_BLOCK,
                                                [], len(bundlePayload))
        bundleBytes += bundlePayload

        self.send_bytes_to_ba(bundleBytes)

    def expectacs(self, source, timestamp, seq, successful = True, reason = None, cid = None):
        """Record that an ACS covering the given bundle is expected.

        The custody ID (looked up in my_acsdb when ``cid`` is None) is added
        to the BpAcs stored at expected_acs[successful][reason], which is
        created on first use.
        """
        if cid is None:
            # NOTE(review): send() keys this lookup with int(seq) while seq
            # is used unconverted here; the keys only agree when the
            # manifest passes seq as an int -- confirm that is intended.
            cid = my_acsdb.getCustodyId((source, dtntime(timestamp), seq))
        pending = expected_acs[successful]
        if reason not in pending:
            pending[reason] = BpAcs(succeeded = successful,
                                    reason = reason)
        pending[reason].add(cid)

    def sendexpectacs(self, custodian, source, timestamp, seq, successful = True, reason = None, cid = None):
        """send() a bundle, then expectacs() for it."""
        self.send(custodian, source, timestamp, seq, cid = cid)
        self.expectacs(source, timestamp, seq, successful, reason, cid = cid)

    def expectcustody(self, source, timestamp, seq):
        """Record that the agent under test should hold custody of a bundle."""
        expected_custody[(source, dtntime(timestamp), seq)] = None

    def sendexpectcustody(self, custodian, source, timestamp, seq, cid = None):
        """send() a bundle, then expectcustody() for it."""
        self.send(custodian, source, timestamp, seq, cid = cid)
        self.expectcustody(source, timestamp, seq)

    def sendexpectacsandcustody(self, custodian, source, timestamp, seq, cid = None, ttl = None):
        """send() a bundle and expect both a successful ACS and custody."""
        self.send(custodian, source, timestamp, seq, cid = cid, ttl = ttl)
        self.expectacs(source, timestamp, seq, True, None, cid = cid)
        self.expectcustody(source, timestamp, seq)

    def sleep(self, sleepduration):
        """Pause for ``sleepduration`` seconds (anything float() accepts)."""
        time.sleep(float(sleepduration))

    def receive(self, timeout = None, sizelimit = None):
        """Receive one bundle and check it against the expected ACS.

        timeout   -- seconds to wait; values float() cannot convert
                     (including None) mean "wait forever".
        sizelimit -- when not None, the largest ACS payload length allowed.

        Non-administrative bundles only have their CTEB custody ID recorded
        in ut_acsdb.  Administrative records whose type nibble is not 0x4
        are ignored.  Any other record is parsed as an ACS and must equal
        the expected ACS for its (succeeded, reason) pair, which is then
        removed from expected_acs.

        Raises BundleNotReceivedError on timeout and TestFailureError for an
        oversized, unparseable, unexpected or mismatching ACS.
        """
        try:
            timeout = float(timeout)
        except (TypeError, ValueError):
            timeout = None

        # Decode bundle.
        try:
            receiver.settimeout(timeout)
            bundleBytes = receiver.recv(64*1024)
        except AttributeError:
            raise TestFailureError("Receiver socket not created")
        except socket.timeout:
            raise BundleNotReceivedError("Didn't receive a bundle within %f seconds on %s" % (timeout, receiver.getsockname()))
        (bundle, bundle_length, remainder) = bp.decode(bundleBytes)

        # If the bundle isn't administrative, extract the CTEB emitted by the BPA-under-test.
        if (bundle.bundle_flags & bp.BUNDLE_IS_ADMIN) == 0:
            # has_key() is kept because the type of bundle.blocks is defined
            # outside this file.
            if bundle.blocks.has_key(CTE_BLOCK):
                cteb = Cteb(bundle)
                ut_acsdb.setCustodyId(bp.get_bundle_id(bundle), cteb.custodyId)
            return

        # From here on the bundle is administrative.  The record type is the
        # high nibble of the first payload byte; only type 0x4 is treated as
        # an ACS, everything else (e.g. a normal custody signal) is ignored.
        payload = bundle.blocks[bp.PAYLOAD_BLOCK][0]["payload"]
        if struct.unpack("!B", payload[0])[0] & 0xf0 != 0x40:
            return

        if sizelimit is not None:
            payload_len = len(payload)
            if payload_len > sizelimit:
                # Report the limit that was exceeded as well as the actual size.
                raise TestFailureError("Payload was bigger than %d (was %d)" % (sizelimit, payload_len))

        acs = unserialize_acs(payload)
        if acs is None:
            raise TestFailureError("Received an unparseable ACS : %s" % bundle)

        # Look up the bundle with the corresponding (success, reason) we're expecting.
        succeeded = bool(acs.succeeded)
        try:
            toCompare = expected_acs[succeeded][acs.reason]
        except KeyError:
            raise TestFailureError("Received an unexpected ACS for (success: %s, reason: %s): %s; %s" % (succeeded, acs.reason, acs, bundle))

        # Before we compare, make our ACS have the same signal time as expected_acs.
        acs.signalTime = toCompare.signalTime

        # Compare received acs to the one we're expecting; error if not equal.
        if (acs != toCompare):
            raise TestFailureError("Expected %s, got %s" % (toCompare, acs))

        # Reset the ACS we're expecting for this (success, reason) to empty.
        del expected_acs[succeeded][acs.reason]

    def receiveall(self, timeout = None, sizelimit = None):
        """Call receive() until no expected ACS remain.

        Raises TestFailureError when a receive() times out while entries are
        still present in expected_acs.
        """
        while expected_acs[True] or expected_acs[False]:
            try:
                self.receive(timeout, sizelimit)
            except BundleNotReceivedError:
                # %s rather than %d: receive() accepts fractional and string
                # timeouts, which %d would truncate or reject.
                raise TestFailureError("Didn't receive all expected bundles within %s, remaining expected ACS: %s" % (timeout, expected_acs))

    def receiveuntil(self, timeout):
        """Keep receiving bundles until ``timeout`` seconds have elapsed."""
        deadline = time.time() + timeout
        while True:
            remaining = deadline - time.time()
            # Stop before handing settimeout() a zero or negative value,
            # which would not behave as a timed wait.
            if remaining <= 0:
                return
            try:
                self.receive(remaining)
            except BundleNotReceivedError:
                # Receiving until timeout has passed, so its OK if no bundle.
                pass

    def checkcustody(self):
        """Verify bplist() reports exactly the bundles in expected_custody.

        Raises TestFailureError when a listed bundle is not expected or an
        expected bundle is not listed.
        """
        my_expected_custody = expected_custody.copy()
        bundles_in_custody = bplist()
        if bundles_in_custody is None:
            if len(my_expected_custody) == 0:
                return
            raise TestFailureError("No bundles in custody, but expected %s" % my_expected_custody)

        # Tick off each listed bundle; whatever is left over was missing.
        for i in bundles_in_custody:
            try:
                del my_expected_custody[(i["sourceEid"], i["timestamp"], i["count"])]
            except KeyError:
                raise TestFailureError("ION has custody of unexpected bundle %s" % i)
        if len(my_expected_custody) > 0:
            raise TestFailureError("ION didn't have custody of bundles %s" % my_expected_custody)

    def clearcustody(self, custodianEid):
        """Send an ACS for every expected-custody bundle, then forget them.

        Raises TestFailureError when expected_custody is empty.
        """
        if len(expected_custody) == 0:
            raise TestFailureError("No custody to clear; try expectcustody")

        self.sendacs(custodianEid, list(expected_custody.keys()))
        self.emptycustody()

    def emptycustody(self):
        """Reset expected_custody to an empty dict."""
        global expected_custody
        expected_custody = { }

    def sendacs(self, custodianEid, bids):
        """Send an ACS covering the given bundle IDs to ``custodianEid``.

        bids -- iterable of bundle IDs; each must already have a custody ID
                in ut_acsdb (recorded by receive() from a CTEB).

        Raises TestFailureError when a bundle ID has no custody ID.  Sleeps
        one second after sending.
        """
        acs = BpAcs()
        for bid in bids:
            cid = ut_acsdb.getCustodyId(bid, dontCreate = True)
            if cid is None:
                # The bundle ID may itself be a tuple, so wrap it before
                # formatting it into the message.
                raise TestFailureError("Can't clear custody of %s: didn't get identifying CTEB." % (bid,))
            acs.add(cid)

        bundlePayload = acs.serialize()
        bundle = Bundle(bp.BUNDLE_SINGLETON_DESTINATION | bp.BUNDLE_IS_ADMIN,
                        bp.COS_NORMAL, 0, "ipn:120.0", custodianEid,
                        "dtn:none", "dtn:none", "dtn:none", time.time() - bp.TIMEVAL_CONVERSION,
                        None, 86400*5*366, bundlePayload)
        bundleBytes = bp.encode(bundle)
        bundleBytes += bp.encode_block_preamble(bp.PAYLOAD_BLOCK, bp.BLOCK_FLAG_LAST_BLOCK,
                                                [], len(bundlePayload))
        bundleBytes += bundlePayload

        self.send_bytes_to_ba(bundleBytes)

        # Give ION time to process the bundle.
        self.sleep(1)

    def checkacslist(self, verbose = False, expectedCount = None):
        """Check the result of acslist() for mismatches and, optionally, size.

        verbose       -- print the returned CBIDs unless this equals False.
        expectedCount -- when not None, the required number of CBIDs.

        Raises TestFailureError on any mismatch or a wrong count.
        """
        (cbids, mismatches) = acslist()
        if verbose != False:
            print(repr(cbids))
        if len(mismatches) != 0:
            raise TestFailureError("ACS ID database mismatches: %s" % repr(mismatches))
        if expectedCount is not None and len(cbids) != expectedCount:
            raise TestFailureError("There are %d CBIDs but test expected %d" % (len(cbids), expectedCount))

    def assertutcid(self, lowestutcid = None, highestutcid = None):
        """Assert the custody IDs recorded in ut_acsdb lie within bounds.

        lowestutcid  -- smallest acceptable custody ID, or None for no check.
        highestutcid -- largest acceptable custody ID, or None for no check.

        Raises TestFailureError when a bound is violated, or when a bound
        was given but no custody IDs have been recorded.
        """
        if lowestutcid is None and highestutcid is None:
            return
        cids = list(ut_acsdb.cid.keys())
        if len(cids) == 0:
            # min()/max() of an empty list would raise a bare ValueError.
            raise TestFailureError("No CIDs recorded for the bundle agent under test")
        if lowestutcid is not None and min(cids) < lowestutcid:
            raise TestFailureError("Minimum CID is %d (< %d)" % (min(cids), lowestutcid))
        if highestutcid is not None and max(cids) > highestutcid:
            raise TestFailureError("Maximum CID is %d (> %d)" % (max(cids), highestutcid))

    def dtntime(self, time_as_str):
        """Expose the module-level dtntime() conversion to the manifest."""
        return dtntime(time_as_str)


# Run the manifest with the Tester commands as its global namespace.
tester = Tester()

try:
    execfile(manifest, tester.commands)
except TestFailureError as failure:
    # The traceback's first frame is this script; the next one is the
    # manifest line that issued the failing command.
    manifest_frame = sys.exc_info()[2].tb_next
    print("""FAIL (line %d of %s): %s""" % (manifest_frame.tb_lineno, manifest, failure))
    tester.failed = 1
    if parser.values.errexit:
        sys.exit(tester.failed)

# Exit status is 0 on success and 1 after a test failure.
sys.exit(tester.failed)
