import nltk, csv, random

def austats(corpusName):
    """Return word-frequency statistics for a Gutenberg corpus.

    The raw text is split on whitespace and every token is normalised with
    cleanWord.  Tokens that clean down to the empty string (pure punctuation
    such as "..." or "!") are skipped so they are not ranked as a "word".

    Returns a list of (word, relativeFrequency, rank) tuples, most frequent
    first.  relativeFrequency is the word's count divided by the total number
    of counted words; rank starts at 1.  Words with equal counts are ordered
    by word, descending (reverse sort of (count, word) pairs).
    """
    from collections import Counter

    corpus = nltk.corpus.gutenberg.raw(corpusName)
    counts = Counter()
    # str.split() with no argument splits on all whitespace, newlines included.
    for token in corpus.split():
        wd = cleanWord(token)
        if wd:
            counts[wd] += 1
    totalWords = float(sum(counts.values()))
    ranks = sorted(((cnt, wd) for wd, cnt in counts.items()), reverse=True)
    return [(wd, cnt / totalWords, rank)
            for rank, (cnt, wd) in enumerate(ranks, 1)]

def corpusNgrams(corpusName, n, loadWords=None):
    """Count every contiguous n-gram of words in a corpus.

    Parameters:
      corpusName -- corpus identifier passed to loadWords.
      n -- n-gram length; must be >= 1.
      loadWords -- optional callable mapping corpusName to a sequence of word
        tokens; defaults to nltk.corpus.gutenberg.words.

    Returns a dict mapping each n-gram (a tuple of n words) to the number of
    times it occurs in the corpus.  A corpus shorter than n words yields {}.

    Raises ValueError if n < 1.
    """
    from collections import deque

    if n < 1:
        raise ValueError('n must be at least 1, got %r' % (n,))
    if loadWords is None:
        loadWords = nltk.corpus.gutenberg.words
    result = {}
    window = deque(maxlen=n)  # the n most recent words; oldest drops off
    for word in loadWords(corpusName):
        window.append(word)
        # Count as soon as the window is full, so the very first n-gram of
        # the text is included.
        if len(window) == n:
            gram = tuple(window)
            result[gram] = result.get(gram, 0) + 1
    return result

# Gutenberg corpus names grouped by author; each is a list of names of the
# kind conditionalNgrams takes as its corpusNames argument.
janetext = [
    'austen-emma',
    'austen-persuasion',
    'austen-sense',
]
willtext = [
    'shakespeare-caesar',
    'shakespeare-hamlet',
    'shakespeare-macbeth',
]

def conditionalNgrams(corpusNames, n, loadWords=None):
    '''Build a table of next-word counts keyed by the preceding n-1 words.

    Parameters:
      corpusNames -- iterable of corpus names; each is loaded with loadWords.
      n -- n-gram length (context of n-1 words plus the next word); >= 1.
      loadWords -- optional callable mapping a corpus name to a sequence of
        word tokens; defaults to nltk.corpus.gutenberg.words.

    Returns a dict mapping each context string (the n-1 preceding words
    joined by single spaces; '' when n == 1) to a dict from next word to
    count.  Contexts never span two corpora.  Prints a progress line per
    corpus.

    Raises ValueError if n < 1.
    '''
    from collections import deque

    if n < 1:
        raise ValueError('n must be at least 1, got %r' % (n,))
    if loadWords is None:
        loadWords = nltk.corpus.gutenberg.words
    result = {}
    for corpusName in corpusNames:
        # Single-argument form prints the same text under Python 2 and 3.
        print('processing %s' % corpusName)
        # Reset per corpus so a context never straddles two texts.
        context = deque(maxlen=n - 1)
        for word in loadWords(corpusName):
            # Count as soon as a full context exists, so the first
            # transition of each corpus is included.
            if len(context) == n - 1:
                counts = result.setdefault(' '.join(context), {})
                counts[word] = counts.get(word, 0) + 1
            context.append(word)
    return result


def pickone(cdict):
    """Choose a key of cdict at random, weighted by its integer value.

    A target in 1..sum(values) is drawn with random.randint, and the first
    key whose running total reaches the target is returned (keys are visited
    in the dict's iteration order).  An empty dict makes randint raise
    ValueError.
    """
    target = random.randint(1, sum(cdict.values()))
    running = 0
    for key, weight in cdict.items():
        running += weight
        if running >= target:
            return key


def janerateText(gramDict, tLen):
    '''Generate a random list of words from a conditionalNgrams-style table.

    Parameters:
      gramDict -- dict mapping a context string (space-joined preceding
        words) to a dict of next-word counts.
      tLen -- minimum number of words to generate after the opening context.
        After that, generation continues until the last word is contained
        in '.!?' (a substring test, so e.g. '.', '!', '?' all stop it).

    The opening context is chosen with pickone, weighted by how many
    continuations were recorded for it, among contexts starting with an
    ASCII capital letter; if there are none (e.g. the single '' context of
    an n == 1 table), every context is eligible.  Generation stops early
    if the current context has no entry in gramDict instead of raising
    KeyError.

    Returns the generated words (opening context included) as a list.
    Raises ValueError if gramDict is empty.
    '''
    if not gramDict:
        raise ValueError('gramDict is empty')
    capitals = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    firstProbs = {gram: sum(nexts.values())
                  for gram, nexts in gramDict.items()
                  if gram and gram[0] in capitals}
    if not firstProbs:
        firstProbs = {gram: sum(nexts.values())
                      for gram, nexts in gramDict.items()}
    gram = pickone(firstProbs).split()
    ans = list(gram)
    while tLen > 0 or not ans or ans[-1] not in '.!?':
        nexts = gramDict.get(' '.join(gram))
        if not nexts:
            # Dead end: this context was never followed by a word (e.g. it
            # only occurred at the very end of a corpus).
            break
        tLen -= 1
        w = pickone(nexts)
        ans.append(w)
        if gram:
            # Slide the fixed-size context; an empty context (n == 1) stays
            # empty.
            gram = gram[1:] + [w]
    return ans


def randomGram(ngrams):
    """Pick one n-gram at random, weighted by its frequency.

    ngrams -- either a dict mapping gram -> frequency or an iterable of
      (gram, frequency) pairs.

    Returns the chosen gram, or None when ngrams is empty or every frequency
    is zero.
    """
    pairs = list(ngrams.items()) if hasattr(ngrams, 'items') else list(ngrams)
    if not pairs:
        return None
    num = random.random() * sum(freq for _, freq in pairs)
    for gram, freq in pairs:
        num -= freq
        if num < 0.0:
            return gram
    # With fractional weights, the running subtraction rounds differently
    # from sum(), so num can end at (or just above) zero.  The draw still
    # belonged to the last bucket with positive weight.
    positive = [gram for gram, freq in pairs if freq > 0]
    return positive[-1] if positive else None

def generateText(ngrams, length):
    """Generate random text by chaining overlapping n-grams.

    Parameters:
      ngrams -- dict mapping n-gram tuples (all of the same length n) to
        counts, in the shape corpusNgrams returns.
      length -- number of words to produce.

    Starts from a random n-gram chosen with randomGram (weighted by count).
    It then repeatedly picks an n-gram whose first n-1 words equal the last
    n-1 generated words and appends that gram's final word.  It stops early
    at a dead end, where no n-gram continues the current words.

    Returns the words joined by single spaces, at most `length` of them;
    '' when length <= 0 or ngrams is empty.
    """
    if length <= 0 or not ngrams:
        return ''
    # Index grams by their (n-1)-word prefix once, instead of rescanning the
    # whole table for every generated word.
    byPrefix = {}
    for gram, freq in ngrams.items():
        byPrefix.setdefault(tuple(gram[:-1]), []).append((gram, freq))
    seed = randomGram(list(ngrams.items()))
    if seed is None:
        return ''
    n = len(seed)
    words = list(seed)
    while len(words) < length:
        # Last n-1 words; for n == 1 this is the empty prefix ().
        prefix = tuple(words[len(words) - (n - 1):])
        candidates = byPrefix.get(prefix)
        if not candidates:
            break
        words.append(randomGram(candidates)[-1])
    return ' '.join(words[:length])

def cleanWord(word):
    """Normalise a token: lower-case it, then trim surrounding punctuation.

    Only leading and trailing characters from a fixed punctuation set are
    removed; interior characters (such as the apostrophe in "don't") are
    kept.  A token made entirely of those characters becomes ''.
    """
    lowered = word.lower()
    return lowered.strip("',.;:\"/\\[]{}!?#$%^&*()")

def writeStats(fname, stats):
    """Write word statistics to a CSV file with a header row.

    Parameters:
      fname -- path of the file to create (overwritten if it exists).
      stats -- iterable of rows, for example the (word, frequency, rank)
        tuples that austats returns; each row becomes one CSV record.

    The header is 'Word', 'Frequency', 'Rank'.  The file is closed even if
    writing a row raises.
    """
    with open(fname, 'w') as f:
        wri = csv.writer(f)
        wri.writerow(['Word', 'Frequency', 'Rank'])
        wri.writerows(stats)

def otheraustats(words=None):
    """Rank tokens by raw occurrence count.

    words -- iterable of tokens to count.  When omitted, a module-level
      global named ``words`` is used (handy from an interactive session that
      defines it); NameError is raised if no such global exists.

    Returns a list of (word, count, rank) tuples, most frequent first, with
    rank starting at 1.  Tokens with equal counts are ordered by token,
    descending (reverse sort of (count, word) pairs).
    """
    from collections import Counter

    if words is None:
        try:
            words = globals()['words']
        except KeyError:
            raise NameError("otheraustats() needs a word sequence: pass "
                            "words=... or define a global 'words'")
    counts = Counter(words)
    ranks = sorted(((cnt, wd) for wd, cnt in counts.items()), reverse=True)
    return [(wd, cnt, rank) for rank, (cnt, wd) in enumerate(ranks, 1)]
    
from pylab import *

def graphRankVsFreq(corpus):
    """Plot relative frequency against rank for a corpus's top 100 words.

    corpus -- Gutenberg corpus name passed to austats.

    Draws a line plot on linear axes with axis labels, a title and a grid,
    then calls show().
    """
    top = austats(corpus)[:100]
    xs = [row[2] for row in top]
    ys = [row[1] for row in top]
    plot(xs, ys, linewidth=1.0)
    xlabel('Rank')
    ylabel('Frequency')
    title('Rank vs Frequency')
    grid(True)
    show()

def barChart(corpus, top=10):
    """Draw a bar chart of the most frequent words in a corpus.

    Parameters:
      corpus -- Gutenberg corpus name passed to austats.
      top -- how many of the highest-ranked words to show (default 10).

    Each bar's height is the word's relative frequency from austats, and the
    bar is labelled with the word on the x axis.  This function does not call
    show().
    """
    stats = austats(corpus)[:top]
    freqs = [f for _, f, _ in stats]
    words = [w for w, _, _ in stats]
    # Size the axis from the data: a small corpus may have fewer than `top`
    # distinct words, and bar()/xticks() need matching lengths.
    positions = list(range(len(stats)))
    bar(positions, freqs)
    xticks(positions, words)
