Machine Learning and Poems

"She played out to me
to me!!!
in those two years
till sometimes dropping
from the devil
foe, I triumphed!!!
and singing the least some one!
big words that cover
"
                                               - By ThinkPad Lenovo X220, 2012-?

A poem is a work of art. But it can also be seen as a work of science! This post is about teaching a computer how to create English poems in fewer than 600 lines of code. Well, poems in general are abstract entities, so if the computer is going to learn how to write poems, it will have to learn that poems are, most of the time, deeply abstract and deeply complex to understand :) To a degree, that makes our lives easier.

Ok, so how do we go about teaching a computer how to write somewhat valid (albeit abstract!) poems? We'll use Machine Learning.
Machine Learning consists of building a model that represents what you want the computer (the machine) to learn, and then asking the computer to tell us what it learned from that model. There are a number of Machine Learning techniques that can be used; the one I decided to implement here is Markov Chains. A Markov Chain is nothing but a Finite State Machine where each connection between two nodes carries a probability depicting how strongly or weakly connected those two states are. The probability is simply based on how many times that connection has been "seen" in the training data. Yes, sorry: the training data. In order to create the Markov Chain model, we need to supply the program with a set of valid poems. We'll use the web to grab that set.
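In symbols (the notation is ours, not the original post's): writing $c(a, b)$ for the number of times word $b$ was seen right after word $a$ in the training poems, the probability attached to the connection $a \to b$ is estimated as

$$
\hat{P}(b \mid a) = \frac{c(a, b)}{\sum_{x} c(a, x)}
$$

where the sum runs over every word $x$ ever seen right after $a$. With the "history" counters listed further down, for example, $\hat{P}(\text{is} \mid \text{history}) = 2/6 = 1/3$.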

So the approach will be 3-fold:
1) Crawl the web and download a hell of a lot of poems
2) Build the Markov Chain based on #1
3) Traverse the Markov Chain built in #2, and voilà, create the perfect-ish poem!

Ok, step #1. A quick Bing search for poems gives us a nice first web hit to work with: http://www.poemhunter.com/. There are hundreds of poems there, whether one understands them or not. So the crawler will work like this:
a) Go here: http://www.poemhunter.com/
b) Download the content
c) Go to the link that says "top 500 poems". Actually, they have 600.
d) Go link by link, parse the HTML, download the content
e) Push to a local file. Do a lot of cleanup/massage/parsing of the data to remove weirdness. Regular expressions would be more appropriate here, but that's ok.
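Since regular expressions would be the more natural tool for the cleanup in step e), here is a minimal sketch of that idea, assuming the raw HTML of a single poem is already in hand. `PoemCleanup.Clean` is a hypothetical name, not part of the original program, and unlike the original (which drops punctuation outright) it maps every non-letter to a space:

```csharp
using System;
using System.Text.RegularExpressions;

// Hypothetical helper (not in the original code): Regex-based cleanup of one poem.
static class PoemCleanup
{
    public static string Clean(string rawHtml)
    {
        // Replace every HTML tag (e.g. <br/>) and entity (e.g. &nbsp;) with a space
        string text = Regex.Replace(rawHtml, @"<[^>]*>|&[^;\s]*;", " ");
        // Lowercase, then collapse every run of characters other than a-z and the
        // apostrophe into a single space
        text = Regex.Replace(text.ToLowerInvariant(), "[^a-z']+", " ");
        // Only the leading/trailing spaces can remain; trim them
        return text.Trim();
    }
}
```

For example, `PoemCleanup.Clean("<p>Hello,&nbsp;World!<br/>It's   late.</p>")` yields `hello world it's late`.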

Great, #1 is done! Now we have 600 poems to work with, such as:

"
I came  as  tomorrow swaddled in innocence to your  warm  womb mother without  your  choice or mine destined to up date with time our human tree but before  love grew into flesh  and words what is  unfinished creation a precipitation of blood became  my transcendence.
"

Great little poem (I guess...). Anyhow, on to step #2: building the Markov Chain. We will look at all the word connections in the poems. Every time we see a connection, we increase the counter for that connection. The higher the counter, the more likely that connection will happen in our perfect poem. For example, if we look for the word "history" in the list of poems we downloaded, we'll find the following words right after it, each with a counter of how many times that connection has been seen:

history:
  we: 1
  in: 1
  is: 2
  with: 1
  of: 1
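A minimal sketch of this counting step, assuming the cleaned poems have already been split into a sequence of words. It uses generic `Dictionary` types rather than the `Hashtable`s of the original code, and `TransitionCounter.Count` is an illustrative name; unlike the original, it does not skip repeated words or the `ENDPOEMML` separator:

```csharp
using System;
using System.Collections.Generic;

// Illustrative sketch: counts[from][to] = how many times "to" directly followed "from".
static class TransitionCounter
{
    public static Dictionary<string, Dictionary<string, int>> Count(IEnumerable<string> words)
    {
        var counts = new Dictionary<string, Dictionary<string, int>>();
        string previous = null;
        foreach (string word in words)
        {
            if (previous != null)
            {
                // Find (or create) the follower table of the previous word
                if (!counts.TryGetValue(previous, out var next))
                {
                    next = new Dictionary<string, int>();
                    counts[previous] = next;
                }
                next.TryGetValue(word, out int seen); // seen is 0 when the pair is new
                next[word] = seen + 1;
            }
            previous = word;
        }
        return counts;
    }
}
```

With the words of `"history is old history is new history of us"`, the table for `history` ends up as `is: 2, of: 1`.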

Hence, if we see the word "history" in our generated poem, what's the word that will come right after it? Well, the probabilities are as follows (modulo some rounding errors):

history:
  we: 16.66%
  in: 16.66%
  is: 33.33%
  with: 16.66%
  of: 16.66%
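Turning a word's counters into these probabilities is just a division by that word's total count. A small sketch (`TransitionProbabilities.Normalize` is a hypothetical name; the original program never computes probabilities explicitly, since the bucket trick in step #3 makes that unnecessary):

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helper: probability of each follower = its count / sum of all counts.
static class TransitionProbabilities
{
    public static Dictionary<string, double> Normalize(IDictionary<string, int> followerCounts)
    {
        double total = followerCounts.Values.Sum();
        return followerCounts.ToDictionary(kv => kv.Key, kv => kv.Value / total);
    }
}
```

Feeding it the "history" counters gives `is` a probability of 2/6 and every other follower 1/6.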

Awesome, so that's what we do for step #2. Implementation-wise, we'll use a combination of hash tables to store this data. Notice that one limitation of Markov Chains is that the information is non-contextual: it does not matter how we got to a certain word. The probability of going from that word to the next is independent of the "past"; only the current state (word) matters. This is known as the Markov property. That will come in handy if we believe that the poems we're generating have a "deep philosophical touch" to them :)
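Written as a formula, with $w_t$ denoting the $t$-th word of the generated poem (our notation), the assumption that only the current word matters is

$$
P(w_{t+1} \mid w_1, w_2, \dots, w_t) = P(w_{t+1} \mid w_t)
$$

so the probability of a whole generated poem, given its seed word $w_1$, is just the product of the single-step probabilities: $P(w_2, \dots, w_n \mid w_1) = \prod_{t=1}^{n-1} P(w_{t+1} \mid w_t)$.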

Now for the grand finale!! Step #3! That becomes straightforward...ish... First things first: build a bucket system, which is easy: for each word, create the buckets by cumulatively adding up its connection counters. Back to the word "history" above, this is what you do:

You go from this:

history:
  we: 1
  in: 1
  is: 2
  with: 1
  of: 1

To this:

history:
  we: 1
  in: 2
  is: 4
  with: 5
  of: 6
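The cumulative step is a running sum over the counters, in whatever order the words happen to be stored. A sketch with a hypothetical `CumulativeBuckets.UpperBounds` helper (the original code keeps these sums in a `LinkedList<int>` per word instead of an array):

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper: running totals in order, e.g. counts {1,1,2,1,1} -> {1,2,4,5,6}.
static class CumulativeBuckets
{
    public static int[] UpperBounds(IList<int> counts)
    {
        var bounds = new int[counts.Count];
        int running = 0;
        for (int i = 0; i < counts.Count; i++)
        {
            running += counts[i];
            bounds[i] = running; // the last bound is the total range
        }
        return bounds;
    }
}
```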

Ok, now we have a range, [1..6]. You roll a die (well, a random number generator), get a number in [1..6], and then pick the word whose bucket contains it. The buckets are:

history:
  we: 1..1
  in: 2..2
  is: 3..4
  with: 5..5
  of: 6..6
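Finding the bucket for a given roll amounts to finding the first upper bound that is greater than or equal to the roll. Since the bounds are sorted, a binary search does it. This is a hypothetical `BucketPicker.Pick` helper, swapped in for the linear scan the original code uses, and it assumes 1-based rolls as in the text:

```csharp
using System;

// Hypothetical helper: words[i] owns the rolls (upperBounds[i-1], upperBounds[i]].
static class BucketPicker
{
    public static string Pick(string[] words, int[] upperBounds, int roll)
    {
        if (upperBounds.Length == 0 || roll < 1 || roll > upperBounds[upperBounds.Length - 1])
        {
            throw new ArgumentOutOfRangeException(nameof(roll));
        }
        // BinarySearch returns the index on an exact hit; otherwise it returns the
        // bitwise complement of the index of the first bound larger than roll
        int index = Array.BinarySearch(upperBounds, roll);
        if (index < 0)
        {
            index = ~index;
        }
        return words[index];
    }
}
```

With the "history" buckets, rolls 3 and 4 both land on `is`. A uniformly random roll in that range can come from `new Random().Next(1, total + 1)`, since the upper bound of `Random.Next(int, int)` is exclusive.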

As you can see, "is" has the highest chance of being picked. Ok, now we need the "seed" word, the initial one: let the user give you the seed. Finally, traverse the chain stochastically, deciding how many lines your poem will have and how many words per line. Once you output the poem, add the proper decoration with punctuation marks and further line breaks, and that's it!!!
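Putting step #3 together, the traversal can be sketched as follows, assuming the counters are stored as nested dictionaries (word → follower → count). `PoemWalker.Generate` and its parameters are illustrative names; like the original program, it jumps back to the seed word when it hits a word with no recorded followers, and it leaves punctuation and capitalization to you:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Illustrative sketch: a random walk over the chain, one line at a time.
static class PoemWalker
{
    public static List<string> Generate(
        Dictionary<string, Dictionary<string, int>> chain,
        string seed, int lines, int wordsPerLine, Random rng)
    {
        var poem = new List<string>();
        string current = seed;
        for (int line = 0; line < lines; line++)
        {
            var words = new List<string>();
            for (int i = 0; i < wordsPerLine; i++)
            {
                words.Add(current);
                if (!chain.TryGetValue(current, out var followers) || followers.Count == 0)
                {
                    current = seed; // dead end: restart from the seed word
                    continue;
                }
                // Roll in 1..total and walk the buckets until the roll is used up
                int roll = rng.Next(1, followers.Values.Sum() + 1);
                foreach (var kv in followers)
                {
                    roll -= kv.Value;
                    if (roll <= 0)
                    {
                        current = kv.Key;
                        break;
                    }
                }
            }
            poem.Add(string.Join(" ", words));
        }
        return poem;
    }
}
```

Subtracting counts from the roll until it drops to zero or below is equivalent to looking up the cumulative buckets, without storing them.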

Some pearls that my laptop generated :)

Seed: "end"

"End in my passion
in a chorus cheek
and loblolly
and a horse
my days, my mind
from heaven knows
it was cold"


Seed: "he"

"He kissed me
to a wind blew
the breath is wrong page
homeless, tempests
to the air it, to be dried
still, can't trust"


Seed: "sadness"

"Sadness, your face
with the things that lets the boatmen row
of such rate
should end of death
has become in the words: one"


Seed: "history"

"History is in the rain
at may
he only got his arms
and fulfillment
a shred of the glittering bubble.
Reputation for all. Hollow"


Seed: "world"

"World, after all
the stars o'erhead were wed
where oil had passed
around a ribboned coat
when it had nothing else
to strive to"


Code, in < 600 lines of C#:

Program.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.IO;
namespace MachineLearningAndPoems
{
    class Program
    {
        static void Main(string[] args)
        {
            if (args.Length != 2)
            {
                Console.WriteLine("MachineLearningAndPoems.exe <poem file> <first word of the poem>");
                return;
            }
            string poemFile = args[0];
            string firstWordOfPoem = args[1].ToLower();
            Poems poems = new Poems("http://www.poemhunter.com",
                                    "http://www.poemhunter.com/p/m/l.asp?a=0&l=top500&order=title&p=[TOKEN]",
                                    poemFile);
            poems.Read();
            poems.BuildMarkovChain(firstWordOfPoem);
            poems.TraverseChain(firstWordOfPoem);
        }
    }
}

Poems.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.IO;
using System.Collections;
namespace MachineLearningAndPoems
{
    class Poems
    {
        private string baseURL = "";
        private string seedURL = "";
        private string poemsFile = "";
        private Hashtable htMarkovChain = null;
        class PoemsMarkovChainConnectedWords
        {
            public Hashtable htConnectedWords;
            public LinkedList<string> cummulativeWeightWord;
            public LinkedList<int> cummulativeWeightValue;
            public int totalRange;
            public PoemsMarkovChainConnectedWords()
            {
                this.htConnectedWords = new Hashtable();
                this.cummulativeWeightWord = new LinkedList<string>();
                this.cummulativeWeightValue = new LinkedList<int>();
                this.totalRange = 0;
            }
        }
        private Poems() { }
        public Poems(string baseURL,
                     string seedURL,
                     string poemsFile)
        {
            this.baseURL = baseURL;
            this.seedURL = seedURL;
            this.poemsFile = poemsFile;
            this.htMarkovChain = new Hashtable();
        }
        public void Read()
        {
            int poemIndex = 1;
            int poemPage = 1;
            bool downloadPage = true;
            string statusCode = "";
            string content = "";
            int nPoems = 600;
            if (File.Exists(this.poemsFile))
            {
                return;
            }
            FileInfo fiPoem = new FileInfo(this.poemsFile);
            StreamWriter swPoem = fiPoem.CreateText();
            while (poemIndex <= nPoems)
            {
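                // Tracks whether this iteration fetched a fresh listing page
                bool pageJustDownloaded = false;
                if (downloadPage)
                {
                    string url = seedURL.Replace("[TOKEN]", poemPage.ToString());
                    content = Util.GetURL(url,
                                          out statusCode) ?? "";
                    poemPage++;
                    downloadPage = false;
                    pageJustDownloaded = true;
                }
                string anchorRank = "\"rank\">" + poemIndex.ToString();
                int anchorRankIndex = content.IndexOf(anchorRank);
                // Advance to the next rank only if this one was found, or if even a
                // freshly downloaded page lacks it (so a failed download can't loop
                // forever). A rank missing from an older page is retried on the next
                // page rather than skipped.
                if (anchorRankIndex >= 0 || pageJustDownloaded)
                {
                    poemIndex++;
                }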
                if (anchorRankIndex >= 0)
                {
                    string poemURLInitToken = "href=\"";
                    string poemURLEndToken = "\"";
                    string poemURL = "";
                    int poemURLIndex = 0;
                    poemURLIndex = Util.GetValueByTokens(anchorRankIndex,
                                                         content,
                                                         poemURLInitToken,
                                                         poemURLEndToken,
                                                         out poemURL,
                                                         true);
                    if (poemURLIndex >= 0 &&
                        poemURL.Length > 0)
                    {
                        if (this.baseURL.EndsWith("/"))
                        {
                            poemURL = this.baseURL + poemURL.Substring(1);
                        }
                        else
                        {
                            poemURL = this.baseURL + poemURL.Substring(0);
                        }
                        string poemContent = Util.GetURL(poemURL,
                                                         out statusCode);
                        Console.WriteLine("Poem #{0}: {1}", poemIndex - 1, poemURL);
                        string poem = RetrievePoemFromRawPage(poemContent);
                        swPoem.WriteLine(poem);
                        swPoem.WriteLine(" ENDPOEMML ");
                        swPoem.Flush();
                    }
                }
                else
                {
                    downloadPage = true;
                }
            }
            if (swPoem != null)
            {
                swPoem.Close();
            }
        }
        private string RetrievePoemFromRawPage(string pageSource)
        {
            if (String.IsNullOrEmpty(pageSource))
            {
                return null;
            }
            string anchor = "class=\"KonaBody\"";
            int anchorIndex = 0;
            anchorIndex = pageSource.IndexOf(anchor);
            if (anchorIndex < 0)
            {
                return null;
            }
            string poemTokenInit = "<p>";
            string poemTokenEnd = "</p>";
            string poem = "";
            int poemIndex = 0;
            poemIndex = Util.GetValueByTokens(anchorIndex,
                                              pageSource,
                                              poemTokenInit,
                                              poemTokenEnd,
                                              out poem,
                                              true);
            if (poemIndex >= 0 &&
                poem.Length > 0)
            {
                poem = Util.RemoveHTMLTags(poem);
                poem = Util.RemoveEncodingTags(poem);
                poem = poem.ToLower();
                string retVal = "";
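                for (int i = 0; i < poem.Length; i++)
                {
                    // Keep lowercase letters and apostrophes; turn any whitespace
                    // (including line breaks inside the poem) into a space so that
                    // the last word of one verse doesn't fuse with the next one
                    if ((poem[i] >= 'a' && poem[i] <= 'z') ||
                        poem[i] == '\'')
                    {
                        retVal += poem[i].ToString();
                    }
                    else if (char.IsWhiteSpace(poem[i]))
                    {
                        retVal += " ";
                    }
                }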
                return retVal;
            }
            else
            {
                return null;
            }
        }
        private void BuildMarkovChainBuckets(string firstWordOfPoem)
        {
            Hashtable htMarkovChainCopy = (Hashtable)this.htMarkovChain.Clone();
            foreach (string word in htMarkovChainCopy.Keys)
            {
                PoemsMarkovChainConnectedWords poemMarkovChainConnectedWord = (PoemsMarkovChainConnectedWords)htMarkovChain[word];
                poemMarkovChainConnectedWord.totalRange = 0;
                if (word.Equals(firstWordOfPoem))
                {
                    Console.WriteLine("{0}:", word);
                }
                foreach (string connectedWord in poemMarkovChainConnectedWord.htConnectedWords.Keys)
                {
                    int cummulativeWeight = (int)poemMarkovChainConnectedWord.htConnectedWords[connectedWord] + poemMarkovChainConnectedWord.totalRange;
                    poemMarkovChainConnectedWord.cummulativeWeightWord.AddLast(connectedWord);
                    poemMarkovChainConnectedWord.cummulativeWeightValue.AddLast(cummulativeWeight);
                    poemMarkovChainConnectedWord.totalRange = cummulativeWeight;
                    if (word.Equals(firstWordOfPoem))
                    {
                        Console.WriteLine("  {0}: {1}", connectedWord, cummulativeWeight);
                    }
                }
                htMarkovChain[word] = poemMarkovChainConnectedWord;
            }
        }
        public void TraverseChain(string initialWord)
        {
            if (String.IsNullOrEmpty(initialWord) ||
                !htMarkovChain.ContainsKey(initialWord))
            {
                return;
            }
            int numberPhrases = 6;
            int wordsPerPhrase = 4;
            Console.WriteLine();
            Console.WriteLine("Markov Chain Auto-Generated Poem:");
            Console.WriteLine("\"");
            string currentWord = initialWord;
            for (int n = 0; n < numberPhrases; n++)
            {
                for (int i = 0; i < wordsPerPhrase; i++)
                {
                    Console.Write(" {0}", currentWord);
                    if (!htMarkovChain.ContainsKey(currentWord))
                    {
                        Console.Write(",");
                        currentWord = initialWord;
                    }
                    else
                    {
                        PoemsMarkovChainConnectedWords poemMarkovChainConnectedWord = (PoemsMarkovChainConnectedWords)htMarkovChain[currentWord];
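                        // Draw a number in [0, totalRange) and pick the first word whose
                        // cumulative weight exceeds it (the bucket lookup from step #3)
                        int coinToss = Util.MyRandom(0, poemMarkovChainConnectedWord.totalRange);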
                        for (int wordIndex = 0; wordIndex < poemMarkovChainConnectedWord.cummulativeWeightWord.Count; wordIndex++)
                        {
                            if (coinToss < poemMarkovChainConnectedWord.cummulativeWeightValue.ElementAt<int>(wordIndex))
                            {
                                currentWord = poemMarkovChainConnectedWord.cummulativeWeightWord.ElementAt<string>(wordIndex);
                                break;
                            }
                        }
                    }
                }
                Console.WriteLine();
            }
            Console.WriteLine("\"");
        }
        public void BuildMarkovChain(string firstWordOfPoem)
        {
            FileInfo fi = new FileInfo(this.poemsFile);
            StreamReader sr = fi.OpenText();
            string allPoems = sr.ReadToEnd();
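            // Split on any whitespace: the file also contains the line breaks written
            // by WriteLine, which would otherwise stick to the neighboring words
            string[] poemWords = allPoems.Split(new char[] { ' ', '\t', '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries);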
            for (int i = 0; i < poemWords.Length - 1; i++)
            {
                string fromWord = poemWords[i];
                string toWord = poemWords[i + 1];
                //Some basic cleanup/manipulation of the tokens{
                if (fromWord.Equals(toWord))
                {
                    continue;
                }
                if (fromWord.Equals("ENDPOEMML") ||
                    toWord.Equals("ENDPOEMML") ||
                    fromWord.Trim().Length == 0 ||
                    toWord.Trim().Length == 0 ||
                    fromWord.Equals("ii") ||
                    toWord.Equals("ii") ||
                    fromWord.Equals("iii") ||
                    toWord.Equals("iii"))
                {
                    continue;
                }
                if (fromWord.Length > 1 &&
                   (fromWord.ToLower()[0] < 'a' || fromWord.ToLower()[0] > 'z'))
                {
                    fromWord = fromWord.Substring(1);
                }
                if (toWord.Length > 1 &&
                   (toWord.ToLower()[0] < 'a' || toWord.ToLower()[0] > 'z'))
                {
                    toWord = toWord.Substring(1);
                }
                if (fromWord.Length == 1 &&
                    (fromWord[0] < 'a' || fromWord[0] > 'z'))
                {
                    continue;
                }
                if (toWord.Length == 1 &&
                    (toWord[0] < 'a' || toWord[0] > 'z'))
                {
                    continue;
                }
                //Some basic cleanup/manipulation of the tokens}
                if (!htMarkovChain.ContainsKey(fromWord))
                {
                    PoemsMarkovChainConnectedWords poemMarkovChainConnectedWord = new PoemsMarkovChainConnectedWords();
                    poemMarkovChainConnectedWord.htConnectedWords.Add(toWord, 1);
                    htMarkovChain.Add(fromWord, poemMarkovChainConnectedWord);
                }
                else
                {
                    PoemsMarkovChainConnectedWords poemMarkovChainConnectedWord = (PoemsMarkovChainConnectedWords)htMarkovChain[fromWord];
                    if (!poemMarkovChainConnectedWord.htConnectedWords.ContainsKey(toWord))
                    {
                        poemMarkovChainConnectedWord.htConnectedWords.Add(toWord, 1);
                    }
                    else
                    {
                        poemMarkovChainConnectedWord.htConnectedWords[toWord] = (int)poemMarkovChainConnectedWord.htConnectedWords[toWord] + 1;
                    }
                    htMarkovChain[fromWord] = poemMarkovChainConnectedWord;
                }
            }
            this.BuildMarkovChainBuckets(firstWordOfPoem);
            if (sr != null)
            {
                sr.Close();
            }
        }
    }
}
Util.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Web;
using System.Net;
using System.IO;
namespace MachineLearningAndPoems
{
    class Util
    {
        public static int GetValueByTokens(int initPos, string source, string initToken, string endToken, out string retValue, bool ignoreErrors)
        {
            int newPos = -1;
            retValue = "";
            if ((source == null) || (initToken == null) || (endToken == null))
            {
                return newPos;
            }
            if (initPos >= 0)
            {
                int iPos = source.IndexOf(initToken, initPos);
                if (iPos >= 0)
                {
                    iPos += initToken.Length;
                    int ePos = source.IndexOf(endToken, iPos);
                    if ((ePos >= 0) && (ePos - iPos > 0))
                    {
                        retValue = source.Substring(iPos, ePos - iPos);
                        newPos = ePos;
                    }
                    else if (!ignoreErrors)
                    {
                        Console.WriteLine("  Warning: could not retrieve token in between '" + initToken + "' and '" + endToken + "' - error 1");
                    }
                }
                else if (!ignoreErrors)
                {
                    Console.WriteLine("  Warning: could not retrieve token in between '" + initToken + "' and '" + endToken + "' - error 2");
                }
            }
            return newPos;
        }
        public static string RemoveHTMLTags(string html)
        {
            if (html == null)
            {
                return null;
            }
            string retVal = "";
            bool insideHTMLTag = false;
            for (int i = 0; i < html.Length; i++)
            {
                if (html[i] == '<')
                {
                    insideHTMLTag = true;
                    retVal += " "; //add a space
                }
                if (!insideHTMLTag)
                {
                    retVal += html[i].ToString();
                }
                if (html[i] == '>')
                {
                    insideHTMLTag = false;
                }
            }
            return retVal;
        }
        public static string RemoveEncodingTags(string html)
        {
            if (html == null)
            {
                return null;
            }
            string retVal = "";
            bool insideEncodingTag = false;
            for (int i = 0; i < html.Length; i++)
            {
                if (html[i] == '&')
                {
                    insideEncodingTag = true;
                    retVal += " "; //add a space
                }
                if (!insideEncodingTag)
                {
                    retVal += html[i].ToString();
                }
                if (html[i] == ';')
                {
                    insideEncodingTag = false;
                }
            }
            return retVal;
        }
        public static int MyRandom(int minInclusive, int maxExclusive)
        {
            long len = maxExclusive - minInclusive;
            if (len <= 0)
            {
                return -1;
            }
            string guid = Guid.NewGuid().ToString().ToUpper();
            long sum = 0;
            foreach (char c in guid)
            {
                if (c != '-')
                {
                    //1073676287 is a large prime number
                    sum = (1073676287 * sum + ((c >= '0' && c <= '9') ? (int)(c - '0') : (int)(c - 'A' + 10))) % len;
                }
            }
            return (int)(sum + minInclusive);
        }
        public static string GetURL(string urlInput, out string statusCode)
        {
            string url = urlInput;
            string contents = "";
            statusCode = "";
            if (null == url ||
                url.Length == 0 ||
                !url.ToLower().StartsWith("http"))
            {
                return null;
            }
            url = url.Replace("&amp;", "&");
            try
            {
                HttpWebRequest httpGet = (HttpWebRequest)WebRequest.Create(url);
                HttpWebResponse response = (HttpWebResponse)httpGet.GetResponse();
                if (response != null)
                {
                    statusCode = response.StatusDescription;
                    if (statusCode.Equals("not found", StringComparison.CurrentCultureIgnoreCase))
                    {
                        if (response.StatusCode != HttpStatusCode.NotFound)
                        {
                            statusCode += " but not 404";
                        }
                    }
                    Stream objStream = response.GetResponseStream();
                    StreamReader objReader = null;
                    objReader = new StreamReader(objStream);
                    try
                    {
                        contents = objReader.ReadToEnd();
                        if (statusCode.Equals("not found", StringComparison.CurrentCultureIgnoreCase))
                        {
                            if (!contents.ToLower().Contains("not found"))
                            {
                                statusCode += " but not page not found";
                            }
                        }
                    }
                    catch (System.IO.IOException ioException)
                    {
                        Console.WriteLine("Error: IO Exception: {0}", ioException.Message);
                        statusCode = "IO Exception";
                        contents = "";
                    }
                }
            }
            catch (System.Net.WebException webException)
            {
                HttpWebResponse errorWebResponse = (HttpWebResponse)webException.Response;
                if (errorWebResponse != null)
                {
                    statusCode = errorWebResponse.StatusDescription;
                }
                else
                {
                    statusCode = "unknown";
                }
            }
            return contents;
        }
    }
}

Comments

  1. This problem is so awesome that it's actually one of the famous "Programming Pearls" by Jon Bentley (http://www.cs.bell-labs.com/cm/cs/pearls/sec153.html), probably one of the most admired books among software engineers (along with SICP, of course :)

    I love Markov Chains because, like everything genius, they are very simple and yet extremely useful. In the days of my youth we used them to model stock prices, create smarter spell checkers, etc. They are also a great place to start with machine learning and appreciate its beauty.

    Thanks Marcelo, it was great to recall good old days :)
