Latent Dirichlet Allocation with Mallet
03/10/2011
We recently had a PhD candidate from UCI speak to the AI club at Google Irvine about her research on Latent Dirichlet Allocation (LDA). LDA is a topic model that groups words into topics, where each article is a mixture of topics. I was interested in playing around with this, so I downloaded Mallet and wrote some quick code to train my own LDA model.
package com.benmccann.topicmodel;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.TreeSet;

import cc.mallet.pipe.CharSequence2TokenSequence;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.TokenSequence2FeatureSequence;
import cc.mallet.pipe.TokenSequenceLowercase;
import cc.mallet.pipe.TokenSequenceRemoveStopwords;
import cc.mallet.pipe.iterator.ArrayIterator;
import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.types.Alphabet;
import cc.mallet.types.IDSorter;
import cc.mallet.types.InstanceList;

import com.google.inject.Guice;
import com.google.inject.Inject;
import com.google.inject.Injector;

public class Lda {

  @Inject private TextProvider textProvider;

  InstanceList createInstanceList(List<String> texts) throws IOException {
    // Pipes: tokenize, lowercase, remove stopwords, map to feature sequences
    ArrayList<Pipe> pipes = new ArrayList<Pipe>();
    pipes.add(new CharSequence2TokenSequence());
    pipes.add(new TokenSequenceLowercase());
    pipes.add(new TokenSequenceRemoveStopwords());
    pipes.add(new TokenSequence2FeatureSequence());
    InstanceList instanceList = new InstanceList(new SerialPipes(pipes));
    instanceList.addThruPipe(new ArrayIterator(texts));
    return instanceList;
  }

  private ParallelTopicModel createNewModel() throws IOException {
    List<String> texts = textProvider.getTexts();
    InstanceList instanceList = createInstanceList(texts);
    // Rough heuristic: one topic per five documents. Guard against tiny
    // corpora, where size() / 5 would be zero and crash the sampler.
    int numTopics = Math.max(2, instanceList.size() / 5);
    ParallelTopicModel model = new ParallelTopicModel(numTopics);
    model.addInstances(instanceList);
    model.estimate();
    return model;
  }

  ParallelTopicModel getOrCreateModel() throws Exception {
    return getOrCreateModel("model");
  }

  private ParallelTopicModel getOrCreateModel(String directoryPath)
      throws Exception {
    File directory = new File(directoryPath);
    if (!directory.exists()) {
      directory.mkdir();
    }
    File file = new File(directory, "mallet-lda.model");
    ParallelTopicModel model;
    if (!file.exists()) {
      model = createNewModel();
      model.write(file);
    } else {
      model = ParallelTopicModel.read(file);
    }
    return model;
  }

  public void printTopics() throws Exception {
    ParallelTopicModel model = getOrCreateModel();
    Alphabet alphabet = model.getAlphabet();
    for (TreeSet<IDSorter> set : model.getSortedWords()) {
      System.out.print("TOPIC: ");
      for (IDSorter s : set) {
        System.out.print(alphabet.lookupObject(s.getID()) + ", ");
      }
      System.out.println();
    }
  }

  public static void main(String[] args) throws Exception {
    Injector injector = Guice.createInjector();
    Lda lda = injector.getInstance(Lda.class);
    lda.printTopics();
  }
}
One of the things I found interesting was that you have to specify the number of topics up front. This is where the ‘art’ of machine learning comes in. With some held-out data, this parameter could be tuned to perform better than my random guesses.
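As a rough illustration of tuning that parameter, the choice can be framed as a small grid search over candidate topic counts, scoring each trained model. The sketch below keeps the search itself in pure Java and abstracts the scorer: in Mallet you could plug in a scorer that trains a ParallelTopicModel with k topics and returns model.modelLogLikelihood(); the toy scorer in main is just a stand-in for that, and TopicCountSearch is an illustrative name, not Mallet API.

```java
import java.util.function.IntToDoubleFunction;

public class TopicCountSearch {

  // Return the candidate topic count with the highest score. In practice the
  // scorer would train a model with k topics and score it on held-out data;
  // it is abstracted here so the search itself is easy to test.
  static int bestTopicCount(int[] candidates, IntToDoubleFunction scorer) {
    int best = candidates[0];
    double bestScore = Double.NEGATIVE_INFINITY;
    for (int k : candidates) {
      double score = scorer.applyAsDouble(k);
      if (score > bestScore) {
        bestScore = score;
        best = k;
      }
    }
    return best;
  }

  public static void main(String[] args) {
    // Toy scorer that peaks at k = 20 stands in for held-out likelihood.
    int[] candidates = {5, 10, 20, 50};
    int best = bestTopicCount(candidates, k -> -Math.abs(k - 20));
    System.out.println(best); // 20
  }
}
```

Held-out likelihood is only one criterion; in practice people also eyeball the top words per topic for coherence.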
What a perfect script. Thanks for sharing.
Hi Ben,
Thanks for sharing. I am new to this. Would you mind sharing "com.benmccann.topicmodel.TextProvider" so that I can run it?
Hello Ben.
Given a list of keywords extracted from a set of files, I want to find a set of topics using LDA. For this purpose, I have created a keyword x file matrix where each cell indicates the number of occurrences of a given keyword in a given file. How can I create an InstanceList from this matrix?
Thanks in advance.
Alvine.
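One possible approach to a count matrix like that (a sketch, not from the original post): since the token-based pipes in the listing above expect word sequences, each matrix row can be expanded back into repeated tokens, with a keyword of count n appearing n times, and the resulting strings fed through createInstanceList. MatrixToTokens and expandRow are illustrative names, not Mallet API.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class MatrixToTokens {

  // Expand one keyword-count row into a flat token list: a keyword that
  // occurs n times in a file is repeated n times, which is the bag-of-words
  // form that TokenSequence2FeatureSequence ultimately builds anyway.
  static List<String> expandRow(Map<String, Integer> counts) {
    List<String> tokens = new ArrayList<String>();
    for (Map.Entry<String, Integer> entry : counts.entrySet()) {
      for (int i = 0; i < entry.getValue(); i++) {
        tokens.add(entry.getKey());
      }
    }
    return tokens;
  }

  public static void main(String[] args) {
    Map<String, Integer> row = new LinkedHashMap<String, Integer>();
    row.put("forest", 3);
    row.put("tree", 1);
    System.out.println(String.join(" ", expandRow(row))); // forest forest forest tree
  }
}
```

Joining each expanded row with spaces gives one pseudo-document string per file, which matches the List&lt;String&gt; input the post's createInstanceList expects.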
Hello. I am trying to run LDA to generate some topics from documents. For this purpose, I have imported the Mallet jar in Eclipse and written some code in Java. But I need some help importing the data in my code. My data are represented as follows in a .txt file:
Document1 X forest=3.4 tree=5 wood=2.85 hammer=1 colour=1 leaf=1.5
Document2 X forest=10 tree=5 wood=2.75 hammer=1 colour=4 leaf=1
Document3 X forest=19 tree=0.90 wood=2 hammer=2 colour=9 leaf=4.3
Document4 X forest=4 tree=5 wood=10 hammer=1 colour=6 leaf=3
Each numeric value in the file is an indication of the number of occurrences of each feature (e.g., forest, tree) multiplied by a given penalty.
To generate instances from the file described above, I use the following code:
public static InstanceList prepareData(String dataPath)
    throws UnsupportedEncodingException, FileNotFoundException {
  // Begin by importing documents from text to feature sequences
  ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
  // Pipes: lowercase, tokenize, remove stopwords, map to features
  pipeList.add(new CharSequenceLowercase());
  pipeList.add(new CharSequence2TokenSequence(
      Pattern.compile("[\\p{L}([0-9]*\\.[0-9]+|[0-9]+)_\\=]+")));
  pipeList.add(new TokenSequenceRemoveStopwords(new File(enFilePath), "UTF-8",
      false, false, false));
  pipeList.add(new TokenSequence2FeatureSequence());
  pipeList.add(new PrintInputAndTarget());
  InstanceList instances = new InstanceList(new SerialPipes(pipeList));
  Reader fileReader = new InputStreamReader(
      new FileInputStream(new File(dataPath)), "UTF-8");
  instances.addThruPipe(new CsvIterator(fileReader,
      Pattern.compile("^(\\S*)[\\s,]*(\\S*)[\\s,]*(.*)$"),
      3, 2, 1)); // data, label, name fields
  return instances;
}
Is my code appropriate for generating the instances to run LDA from my .txt file? If not, how could I modify it? To run LDA, I modified the code available at http://mallet.cs.umass.edu/topics-devel.php so as to take as input the instances computed by the code described above. Apparently, that code uses the Gibbs sampling method. But does it use a Markov chain to converge? If not, how can I instruct the code to use one?
Thanks in advance.
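Before wiring a file like that into Mallet's CsvIterator, it can help to check that each line parses as expected with plain Java first. Below is a small sketch of a parser for the name/label/feature=value format shown above; FeatureLineParser is an illustrative name, and the fractional counts would still need converting to integer token repeats or dense feature vectors before LDA can use them.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class FeatureLineParser {

  // Parse one line of the "name label feature=value ..." format into a
  // feature map; the leading name and label fields are skipped.
  static Map<String, Double> parseFeatures(String line) {
    Map<String, Double> features = new LinkedHashMap<String, Double>();
    String[] fields = line.trim().split("\\s+");
    for (int i = 2; i < fields.length; i++) { // skip name and label
      String[] kv = fields[i].split("=", 2);
      if (kv.length == 2) {
        features.put(kv[0], Double.parseDouble(kv[1]));
      }
    }
    return features;
  }

  public static void main(String[] args) {
    Map<String, Double> f =
        parseFeatures("Document1 X forest=3.4 tree=5 wood=2.85");
    System.out.println(f.get("forest")); // 3.4
  }
}
```

A quick round-trip like this makes it much easier to tell whether a later Mallet import problem is in the regex or in the data itself.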
Hi,
Nice code. I just removed the Injector piece (not really needed) and made a small adjustment to the numTopics / 5 calculation, since it crashes if your list is smaller. I would combine this with a deeplearning4j classifier to have a complete use case: get the documents, classify them into their corresponding categories, and then extract the main topics to improve precision.
Thanks for the example.
Hi,
I need to extract noun phrases from a text document. Is this possible using Mallet?
Please suggest.