import random

# Unweighted FSMs are (startstate, list-of-finals, list-of-arcs) tuples whose
# arcs are (origin, terminus, symbol)-tuples. Weighted FSMs add a fourth arc
# field: the probability of taking that arc.

# Accepts 'e', or one or more syllables each spelled 'cv.' or 'cvc.'.
CV_C = (
    0,
    [3, 6],
    [
        (0, 6, 'e'),
        (0, 1, 'c'),
        (1, 2, 'v'),
        (2, 3, '.'),
        (2, 4, 'c'),
        (4, 3, '.'),
        (3, 1, 'c'),
    ],
)

# Weighted machine for 'a', then any number of 'l's, then one or more 'm's,
# then a final 'a'. The probabilities leaving each non-final state sum to 1.
almaTheMachine = (
    0,
    [3],
    [
        (0, 1, 'a', 1.0),
        (1, 1, 'l', 0.2),
        (1, 2, 'm', 0.8),
        (2, 2, 'm', 0.7),
        (2, 3, 'a', 0.3),
    ],
)

# Nondeterministic: from 'N', reading 'a' may lead to either 'S' or 'D'.
# Accepts any sequence (including the empty one) of 'ba' and 'bab'.
timmyTheMachine = (
    'S',
    ['S'],
    [
        ('S', 'N', 'b'),
        ('N', 'S', 'a'),
        ('N', 'D', 'a'),
        ('D', 'S', 'b'),
    ],
)

def recognize(fsm, sentence):
    """Return True if the (possibly nondeterministic) FSM accepts sentence.

    fsm = (startstate, list-of-finals, list-of-arcs), where an arc is an
    (origin, terminus, symbol)-tuple. We track the set of every state the
    machine could be in after each symbol; the sentence is accepted when,
    after the last symbol, at least one of those states is final.
    """
    live = set([fsm[0]])
    for symbol in sentence:
        # Follow every arc that leaves a live state on this symbol.
        live = set(terminus
                   for (origin, terminus, label) in fsm[2]
                   if origin in live and label == symbol)
        if not live:
            # No live state has an arc for this symbol: the machine is stuck.
            return False
    # Whole sentence consumed; accept iff some reachable state is final.
    return bool(live & set(fsm[1]))

def generate(fsm):
    """Produce one random sentence accepted by an unweighted FSM.

    fsm = (startstate, list-of-finals, list-of-arcs), where an arc is an
    (origin, terminus, symbol)-tuple. At each state one option is chosen
    uniformly with random.choice from: every outgoing arc, plus "stop here"
    if the state is final. The symbols read along the way are joined into
    the returned string.

    Raises IndexError if the walk reaches a non-final state with no
    outgoing arcs (there is no way to continue or to stop).
    """
    # Private "stop here" marker. Unlike a magic string such as 'XYZZY',
    # it can never collide with a real state name or arc.
    stop = object()
    sentence = []
    currentState = fsm[0]
    while True:
        choices = [(origin, terminus, x)
                   for (origin, terminus, x) in fsm[2]
                   if origin == currentState]
        if currentState in fsm[1]:
            choices.append(stop)
        if not choices:
            raise IndexError('dead end: non-final state %r has no outgoing arcs'
                             % (currentState,))
        arc = random.choice(choices)
        if arc is stop:
            break
        currentState = arc[1]
        sentence.append(arc[2])

    return ''.join(sentence)

def pickone(arcs, tolerance=1e-9):
    """Pick one weighted arc at random, in proportion to its probability.

    arcs is a list of (origin, terminus, symbol, probability)-tuples whose
    probabilities are expected to sum to 1. A uniform draw from
    random.random() (range [0.0, 1.0)) is walked down the cumulative
    probabilities, so arc i is returned when the draw falls in
    [p_0 + ... + p_{i-1}, p_0 + ... + p_i).

    tolerance: how much probability mass may go unaccounted for because of
    floating-point round-off before the probabilities are declared bad.

    Raises ValueError('Bad probabilities! ...') if the draw lands beyond
    the total probability (e.g. the probabilities sum to well under 1, or
    arcs is empty).
    """
    x = random.random()
    for arc in arcs:
        x -= arc[3]
        # Strict '<' so an arc with probability 0.0 is never chosen, even
        # when random.random() returns exactly 0.0.
        if x < 0.0:
            return arc
    # Probabilities that sum to 1.0 on paper can leave x a hair above zero
    # after float round-off; give that sliver to the last arc that actually
    # has probability mass instead of failing spuriously.
    if x <= tolerance:
        for arc in reversed(arcs):
            if arc[3] > 0.0:
                return arc
    raise ValueError('Bad probabilities! %r' % (arcs,))

def generateProb(fsm):
    """Randomly generate one sentence from a weighted FSM (a la alma).

    fsm = (startstate, list-of-finals, list-of-arcs), where each arc is an
    (origin, terminus, symbol, probability)-tuple. At every step one of the
    current state's outgoing arcs is drawn with pickone() according to its
    probability. Generation stops only on reaching a final state that has
    no outgoing arcs.

    Returns the generated string, or the literal 'CRASHED DERIVATION' when
    the walk ends up in a non-final state with no outgoing arcs.
    """
    state = fsm[0]
    symbols = []
    while True:
        outgoing = [(origin, terminus, x, prob)
                    for (origin, terminus, x, prob) in fsm[2]
                    if origin == state]
        if not outgoing:
            # Nowhere left to go: success only if we stopped in a final state.
            if state in fsm[1]:
                return ''.join(symbols)
            return 'CRASHED DERIVATION'
        chosen = pickone(outgoing)
        symbols.append(chosen[2])
        state = chosen[1]

def recognizeProb(wfsm, sentence):
    """Return the probability a deterministic weighted FSM gives sentence.

    wfsm = (startstate, list-of-finals, list-of-arcs), where each arc is an
    (origin, terminus, symbol, probability)-tuple (a la alma). For each
    symbol, the first arc that leaves the current state on that symbol is
    followed and its probability multiplied in. The result is only exact
    for machines with at most one such arc per (state, symbol); use
    recognizeProbNondet for nondeterministic machines.

    Returns 0.0 if some symbol cannot be read or the walk ends in a
    non-final state.
    """
    probSoFar = 1.0
    currentState = wfsm[0]
    for symbol in sentence:
        for (origin, dest, arcSymbol, prob) in wfsm[2]:
            if arcSymbol == symbol and currentState == origin:
                probSoFar *= prob
                currentState = dest
                break
        else:
            # No arc reads this symbol from the current state, so the
            # sentence is not in the language. Skipping the symbol instead
            # would make e.g. 'almaz' score exactly like 'alma'.
            return 0.0

    if currentState in wfsm[1]:
        return probSoFar
    return 0.0

def recognizeProbNondet(wfsm, sentence):
    """Return the total probability a (possibly nondeterministic) weighted
    FSM assigns to sentence.

    wfsm = (startstate, list-of-finals, list-of-arcs), where each arc is an
    (origin, terminus, symbol, probability)-tuple; states must be hashable.
    The result is the sum, over every path that reads sentence from the
    start state into a final state, of the product of that path's arc
    probabilities.

    Instead of listing every path (exponentially many for ambiguous
    machines), this keeps one running total per state (the forward
    algorithm). The cost is O(len(sentence) * len(arcs)).
    """
    # forward[state] = summed probability of all paths that have read the
    # prefix consumed so far and currently sit in `state`.
    forward = {wfsm[0]: 1.0}
    for symbol in sentence:
        nextForward = {}
        for (origin, dest, sym, prob) in wfsm[2]:
            if sym == symbol and origin in forward:
                nextForward[dest] = (nextForward.get(dest, 0.0)
                                     + forward[origin] * prob)
        forward = nextForward
        if not forward:
            # No path can read this symbol; nothing further can revive it.
            return 0.0

    # 0.0 start value keeps the result a float even if no final is reached.
    return sum((p for (state, p) in forward.items() if state in wfsm[1]), 0.0)

def stringsUpto(alphabet, length):
    """Return every string over alphabet of length 0..length, each once.

    Strings come shortest first. Within one length they are grouped by
    their last symbol, in alphabet order, e.g.

    >>> stringsUpto('ab', 2)
    ['', 'a', 'b', 'aa', 'ba', 'ab', 'bb']

    alphabet may be a string (each character is a symbol) or any iterable
    of string symbols.
    """
    symbols = list(alphabet)  # allow one-shot iterables; we reuse it per round
    result = ['']
    frontier = ['']  # the strings of the current maximum length
    for _ in range(length):
        # Extend only the longest strings. Extending the whole result list
        # would regenerate shorter strings that are already present.
        frontier = [prefix + symbol for symbol in symbols for prefix in frontier]
        result.extend(frontier)
    return result

# Graphviz DOT template used by printNetwork. Its three %s slots are filled,
# in order, with: the arc statements, the start state's ID (drawn with
# style=filled), and the final-state statements. Nodes default to circles;
# rankdir=LR lays the graph out left to right.
header = \
"""
digraph agent_network {
	rankdir=LR;
	size="8,5"
	node [shape = circle];
%s
%s [style=filled];
%s
}
"""

def printNetwork(fsm):
    """Return a Graphviz DOT description of fsm as a string.

    Despite the name, nothing is printed. fsm = (startstate,
    list-of-finals, list-of-arcs). Arcs may be unweighted
    (origin, terminus, symbol) or weighted
    (origin, terminus, symbol, probability) tuples. A weight, if present,
    is shown in parentheses after the symbol in the arc label. The start
    state is drawn filled and final states get a double outline
    (peripheries=2). The output is the module-level `header` template
    filled in.
    """
    arcs = []
    for arc in fsm[2]:
        src, dest, label = arc[0], arc[1], arc[2]
        if len(arc) > 3:
            label = '%s (%s)' % (label, arc[3])
        arcs.append('%s -> %s [label=%s];'
                    % (_dotId(src), _dotId(dest), _dotId(label)))
    finals = ['%s [peripheries=2];' % (_dotId(f),) for f in fsm[1]]
    return header % ('\n'.join(arcs), _dotId(fsm[0]), '\n'.join(finals))


def _dotId(value):
    """Render value as a double-quoted DOT ID (escapes backslashes and quotes).

    Quoting is required because symbols such as '.' are not valid bare DOT
    IDs.
    """
    text = '%s' % (value,)
    return '"%s"' % text.replace('\\', '\\\\').replace('"', '\\"')
