import random

# Unweighted acceptor laid out as (start state, final states, arcs);
# every arc is an (origin, destination, symbol) triple.
# NOTE(review): 'c'/'v'/'.' presumably stand for consonant, vowel and
# syllable boundary -- confirm with the author.
CV_C = (
    0,
    [3, 6],
    [
        (0, 6, 'e'),
        (0, 1, 'c'),
        (1, 2, 'v'),
        (2, 3, '.'),
        (2, 4, 'c'),
        (4, 3, '.'),
        (3, 1, 'c'),
    ],
)

# Weighted acceptor laid out as (start state, final states, arcs);
# every arc is an (origin, destination, symbol, probability) quadruple.
almaTheMachine = (
    0,
    [3],
    [
        (0, 1, 'a', 1.0),
        (1, 1, 'l', 0.2),
        (1, 2, 'm', 0.8),
        (2, 2, 'm', 0.7),
        (2, 3, 'a', 0.3),
    ],
)

# Unweighted acceptor with string-named states, laid out as
# (start state, final states, arcs).  It is nondeterministic: state 'N'
# has two different arcs labelled 'a'.
timmyTheMachine = (
    'S',
    ['S'],
    [
        ('S', 'N', 'b'),
        ('N', 'S', 'a'),
        ('N', 'D', 'a'),
        ('D', 'S', 'b'),
    ],
)

######
# Finite state stress transducers

# '0' = unstressed syllable, '1' = 2ndary stress, '2' = primary stress

# Alternating (secondary) stress counted from the left edge.
# Layout: (start state, final states, arcs); every arc is
# (origin, destination, {input symbol: output symbol}).
# The two states take turns, so the 1st, 3rd, 5th ... '0' comes out as '1'.
LRAlt = (
    0,
    [0, 1],
    [
        (0, 1, {'0': '1'}),  # 0 -> 1 (0:1)
        (1, 0, {'0': '0'}),  # 1 -> 0 (0:0)
    ],
)

# Promote the leftmost secondary stress ('1') to primary stress ('2').
# State 0 copies '0's until the first '1', which is rewritten as '2';
# state 1 then copies every remaining symbol unchanged.
LeftMain = (
    0,
    [0, 1],
    [
        (0, 0, {'0': '0'}),            # 0 -> 0 (0:0)
        (0, 1, {'1': '2'}),            # 0 -> 1 (1:2)
        (1, 1, {'0': '0', '1': '1'}),  # 1 -> 1 (0:0, 1:1)
    ],
)

# Alternating secondary stress counted from the right edge.
# Nondeterministic: from the start state a '0' may come out as either '1'
# or '0'.  Every arc into final state 2 emits '1', and no arc leads back
# into state 0, so an accepted non-empty word always ends in a '1'.
RLAlt = (
    0,
    [0, 2],
    [
        (0, 2, {'0': '1'}),  # 0 -> 2 (0:1)
        (0, 1, {'0': '0'}),  # 0 -> 1 (0:0)
        (1, 2, {'0': '1'}),  # 1 -> 2 (0:1)
        (2, 1, {'0': '0'}),  # 2 -> 1 (0:0)
    ],
)

# Promote the rightmost secondary stress ('1') to primary stress ('2').
# Nondeterministic: on each '1' the machine either keeps it (moving to or
# staying in state 1) or promotes it to '2' (moving to state 2).  State 2
# only accepts '0', so a promoted '1' must be the last '1' in the word.
RightMain = (
    0,
    [0, 2],
    [
        (0, 0, {'0': '0'}),            # 0 -> 0 (0:0)
        (0, 2, {'1': '2'}),            # 0 -> 2 (1:2)
        (0, 1, {'1': '1'}),            # 0 -> 1 (1:1)
        (1, 1, {'0': '0', '1': '1'}),  # 1 -> 1 (0:0, 1:1)
        (1, 2, {'1': '2'}),            # 1 -> 2 (1:2)
        (2, 2, {'0': '0'}),            # 2 -> 2 (0:0)
    ],
)

def transduceDetComp(inputWord, fsm):
    """Run inputWord through a deterministic, complete transducer.

    fsm is (start state, final states, arcs), where every arc is
    (origin, destination, mapping) and mapping sends an input symbol to
    its output symbol.

    The machine is assumed to be complete (an arc for every symbol out of
    every state) with every state final, so the list of finals is never
    consulted.  If that assumption does not hold, a symbol with no
    matching arc is skipped: nothing is emitted and the state is kept.

    Returns the output string.
    """
    state = fsm[0]
    emitted = []
    for symbol in inputWord:
        # Lazily look for the first arc (in list order) that leaves the
        # current state and knows how to rewrite this symbol.
        step = next(
            ((dest, mapping[symbol])
             for (origin, dest, mapping) in fsm[2]
             if origin == state and symbol in mapping),
            None,
        )
        if step is not None:
            state = step[0]
            emitted.append(step[1])
    return ''.join(emitted)

def transduceNondet(inputWord, fsm):
    """Return every output a (possibly nondeterministic) transducer gives.

    fsm is (start state, final states, arcs), where every arc is
    (origin, destination, mapping) and mapping sends an input symbol to
    its output symbol.  The machine may be nondeterministic or incomplete.

    All paths that consume the whole of inputWord are followed in
    parallel; only those ending in a final state contribute.  The result
    is a list of output strings, one per surviving path (so it may hold
    duplicates), and is empty when no path is accepted.
    """
    # A route is a list of (state reached, symbol emitted on the way in);
    # the start state is entered without emitting anything.
    routes = [[(fsm[0], '')]]
    for symbol in inputWord:
        routes = [
            route + [(dest, mapping[symbol])]
            for route in routes
            for (origin, dest, mapping) in fsm[2]
            if origin == route[-1][0] and symbol in mapping
        ]

    # Keep only routes that stop in a final state and spell out what
    # each of them emitted.
    return [
        ''.join([emitted for (_state, emitted) in route])
        for route in routes
        if route[-1][0] in fsm[1]
    ]

def doPhonology(inputWord, rules):
    """Placeholder -- not implemented yet; always returns None.

    NOTE(review): presumably meant to apply the transducers in `rules`
    to `inputWord`; confirm the intended contract before implementing.
    """
    return None

def inputFor(outputWord, fsm):
    """Reverse transduction: list the inputs that can yield outputWord.

    fsm is (start state, final states, arcs), where every arc is
    (origin, destination, mapping) and mapping sends an input symbol to
    its output symbol.  The machine may be nondeterministic or incomplete.

    The arcs are read backwards: for each symbol of outputWord, every arc
    out of the current state whose mapping produces that symbol is
    followed, recording the input symbol that produces it.  Returns one
    input string per path that ends in a final state (possibly with
    duplicates); the list is empty when outputWord cannot be produced.
    """
    # A route is a list of (state reached, input symbol consumed on the
    # way in); the start state is entered without consuming anything.
    routes = [[(fsm[0], '')]]
    for symbol in outputWord:
        routes = [
            route + [(dest, consumed)]
            for route in routes
            for (origin, dest, mapping) in fsm[2]
            if origin == route[-1][0]
            for (consumed, produced) in mapping.items()
            if produced == symbol
        ]

    # Only routes that stop in a final state correspond to real inputs.
    return [
        ''.join([consumed for (_state, consumed) in route])
        for route in routes
        if route[-1][0] in fsm[1]
    ]

# HOMEWORK HERE
def composeDet(fsm1, fsm2):
    """Placeholder for the composition of two machines (homework).

    Intended behaviour, not implemented yet: return a new fsm that is the
    "composition" of fsm1 and fsm2, i.e. one that applies both fsm1 and
    fsm2 simultaneously, built as a new machine whose states correspond
    to pairs of states in fsm1 and fsm2.

    Currently always returns None.
    """
    return None

def recognize(fsm, sentence):
    """Tell whether the acceptor fsm accepts sentence.

    fsm = (start state, list of finals, list of arcs), and an arc is an
    (origin, terminus, symbol) tuple.  The machine may be
    nondeterministic: the whole set of states reachable so far is
    tracked.

    Returns False as soon as a symbol cannot be consumed from any
    reachable state; otherwise returns True exactly when at least one
    state reachable after the last symbol is final.
    """
    reachable = set([fsm[0]])
    for symbol in sentence:
        stepped = set(
            terminus
            for (origin, terminus, label) in fsm[2]
            if origin in reachable and label == symbol
        )
        if not stepped:
            # Dead end: no reachable state has an arc for this symbol.
            return False
        reachable = stepped

    # The whole sentence was consumed; accept if any reachable state
    # is final.
    return bool(reachable & set(fsm[1]))

def generate(fsm):
    """Generate a random sentence by walking the acceptor fsm.

    fsm = (start state, list of finals, list of arcs), and an arc is an
    (origin, terminus, symbol) tuple.

    At every step one option is drawn uniformly at random from the arcs
    leaving the current state, plus -- when the current state is final --
    the extra option of stopping.  The walk therefore only ever stops in
    a final state, and may be arbitrarily long when the machine has
    cycles.

    Returns the concatenated symbols of the arcs taken, or the string
    'CRASHED DERIVATION' (the same marker generateProb uses) when the
    walk reaches a non-final state that has no outgoing arcs.
    """
    # A private object is used as the "stop here" option so that it can
    # never be confused with a real state name or arc.
    stop = object()
    sentence = []
    currentState = fsm[0]
    while True:
        options = [
            (origin, terminus, x)
            for (origin, terminus, x) in fsm[2]
            if origin == currentState
        ]
        if currentState in fsm[1]:
            options.append(stop)
        if not options:
            # Dead end: nowhere to go and not allowed to stop here.
            return 'CRASHED DERIVATION'
        choice = random.choice(options)
        if choice is stop:
            break
        currentState = choice[1]
        sentence.append(choice[2])

    return ''.join(sentence)

def pickone(arcs):
    """Draw one arc at random according to the arcs' probabilities.

    Each arc is indexable, with its probability at index 3 (the
    (origin, terminus, symbol, probability) layout used by the weighted
    machines).  A uniform random number in [0, 1) is drawn and the
    probabilities are subtracted from it in list order; the first arc
    that brings the remainder to zero or below is returned.

    Raises ValueError('Bad probabilities!') -- after printing the arcs
    for diagnosis -- when the remainder never reaches zero, e.g. when
    the list is empty or the probabilities sum to less than the draw.
    """
    x = random.random()
    for arc in arcs:
        x -= arc[3]
        if x <= 0.0:
            return arc
    # Function-call form: same output as the old print statement on
    # Python 2, and valid syntax on Python 3.
    print(arcs)
    raise ValueError('Bad probabilities!')

def generateProb(fsm):
    """Generate a random sentence from a weighted acceptor.

    fsm = (start state, list of finals, list of arcs), and an arc is an
    (origin, terminus, symbol, probability) tuple.

    From the current state the next arc is drawn with pickone().  The
    walk stops only on reaching a state without outgoing arcs: if that
    state is final the concatenated symbols are returned, otherwise the
    string 'CRASHED DERIVATION' is returned.
    """
    symbols = []
    state = fsm[0]
    while True:
        candidates = [
            (origin, terminus, x, prob)
            for (origin, terminus, x, prob) in fsm[2]
            if origin == state
        ]

        if not candidates:
            if state in fsm[1]:
                # Nowhere left to go and allowed to stop: done.
                break
            return 'CRASHED DERIVATION'

        chosen = pickone(candidates)
        state = chosen[1]
        symbols.append(chosen[2])

    return ''.join(symbols)

# weighted recognizer -- returns the probability of the given
# sentence according to the given weighted fsm (a la alma)
def recognizeProb(wfsm, sentence):
    """Return the probability a deterministic weighted fsm gives sentence.

    wfsm = (start state, list of finals, list of arcs), and an arc is an
    (origin, destination, symbol, probability) tuple.  The machine is
    treated as deterministic: for each symbol the first matching arc (in
    list order) out of the current state is taken.

    Returns the product of the probabilities of the arcs taken when the
    whole sentence is consumed and the walk ends in a final state, and
    0.0 otherwise -- including when some symbol has no arc out of the
    current state.
    """
    probSoFar = 1.0
    currentState = wfsm[0]
    for symbol in sentence:
        for (origin, dest, arcSymbol, prob) in wfsm[2]:
            if arcSymbol == symbol and currentState == origin:
                probSoFar *= prob
                currentState = dest
                break
        else:
            # No arc can consume this symbol, so the machine cannot
            # produce the sentence at all; skipping the symbol would
            # wrongly give it the probability of a shorter sentence.
            return 0.0

    if currentState in wfsm[1]:
        return probSoFar
    return 0.0

def recognizeProbNondet(wfsm, sentence):
    """Return the probability a nondeterministic weighted fsm gives sentence.

    wfsm = (start state, list of finals, list of arcs), and an arc is an
    (origin, destination, symbol, probability) tuple.

    Every path that consumes the whole sentence is followed.  The result
    is the sum, over the paths that end in a final state, of the product
    of the arc probabilities along the path; it is 0.0 when no such path
    exists.
    """
    # A walk is a list of (state reached, probability of the step,
    # symbol consumed); the start state is entered with probability 1.0.
    walks = [[(wfsm[0], 1.0, '')]]
    for symbol in sentence:
        walks = [
            walk + [(dest, weight, symbol)]
            for walk in walks
            for (origin, dest, label, weight) in wfsm[2]
            if origin == walk[-1][0] and label == symbol
        ]

    total = 0.0
    for walk in walks:
        if walk[-1][0] in wfsm[1]:
            product = 1.0
            for (_state, weight, _label) in walk:
                product *= weight
            total += product
    return total

def stringsUpto(alphabet, length):
    """List every string over alphabet of at most `length` symbols.

    alphabet is an iterable of symbol strings.  The result starts with
    the empty string, followed by all strings of one symbol, then all
    strings of two symbols, and so on; each string appears exactly once
    (provided alphabet itself has no repeated symbols).  A length of
    zero or less gives [''].
    """
    result = ['']
    # Strings of the greatest length built so far.  Only these are
    # extended in each round; extending the shorter ones again would
    # re-create strings that are already in the result.
    longest = ['']
    for _ in range(length):
        longest = [prefix + symbol
                   for symbol in alphabet
                   for prefix in longest]
        result.extend(longest)
    return result

# Graphviz DOT template used by printNetwork().  The three %s slots are
# filled, in order, with: the arc lines, the name of the start state
# (drawn filled), and the final-state lines.
header = \
"""
digraph agent_network {
	rankdir=LR;
	size="8,5"
	node [shape = circle];
%s
%s [style=filled];
%s
}
"""

def printNetwork(fsm):
    """Return a Graphviz DOT description of fsm as a string.

    fsm = (start state, list of finals, list of arcs), and an arc is an
    (origin, destination, label) triple.  The start state is drawn
    filled and each final state with a double outline.

    Arc labels are emitted as double-quoted strings (with embedded
    double quotes escaped) so that labels which are not plain DOT
    identifiers -- such as '.' or a transducer's mapping dict -- still
    give a valid graph.
    """
    arcs = []
    for src, dest, label in fsm[2]:
        # Inside a DOT quoted string only the double quote needs escaping.
        text = ('%s' % (label,)).replace('"', '\\"')
        arcs.append('%s -> %s [label="%s"];' % (src, dest, text))
    finals = ['%s [peripheries=2];' % (f,) for f in fsm[1]]
    graph = header % ('\n'.join(arcs), fsm[0], '\n'.join(finals))
    return graph
