What I find really difficult about machine learning tutorials/books/articles is that when a model is explained (even with code), the code only takes you up to training (and/or testing) the model. Then it stops. I can't find a tutorial/book that starts from an example (e.g. topic modeling), begins with a dataset, trains the model, and then demonstrates how to actually use the model. In the code below I have a dataset of news articles stored in folders by topic. Using Mallet I can create the model (and save it), but that's where it ends.
How do I use it now? I want to feed the model an article and get its topics as output. Please don't point me to the Mallet documentation, because it likewise doesn't provide a complete example from start to actually using the model.
The example below is taken from the book Machine Learning in Java (Bostjan Kaluza), which provides the code to create a model and to save/load it. That's a good starting point for me, but what if I now want to use this trained model? Can anyone give a Java example? It doesn't have to be Mallet.
import cc.mallet.types.*;
import cc.mallet.pipe.*;
import cc.mallet.pipe.iterator.*;
import cc.mallet.topics.*;
import cc.mallet.util.Randoms;
import java.util.*;
import java.util.regex.*;
import java.io.*;
public class TopicModeling {
    public static void main(String[] args) throws Exception {
        String dataFolderPath = "data/bbc";
        String stopListFilePath = "data/stoplists/en.txt";

        // Build the import pipeline: read text, tokenize, lowercase,
        // remove stopwords, and map tokens to feature indices.
        ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
        pipeList.add(new Input2CharSequence("UTF-8"));
        Pattern tokenPattern = Pattern.compile("[\\p{L}\\p{N}_]+");
        pipeList.add(new CharSequence2TokenSequence(tokenPattern));
        pipeList.add(new TokenSequenceLowercase());
        pipeList.add(new TokenSequenceRemoveStopwords(new File(stopListFilePath), "utf-8", false, false, false));
        pipeList.add(new TokenSequence2FeatureSequence());
        pipeList.add(new Target2Label());
        SerialPipes pipeline = new SerialPipes(pipeList);

        FileIterator folderIterator = new FileIterator(
                new File[] {new File(dataFolderPath)},
                new TxtFilter(),
                FileIterator.LAST_DIRECTORY);

        // Construct a new instance list, passing it the pipe
        // we want to use to process instances.
        InstanceList instances = new InstanceList(pipeline);

        // Now process each instance provided by the iterator.
        instances.addThruPipe(folderIterator);

        // Create a model with 5 topics, alphaSum = 0.01, beta_w = 0.01.
        // Note that the first parameter is passed as the sum over topics, while
        // the second is the parameter for a single dimension of the Dirichlet prior.
        int numTopics = 5;
        ParallelTopicModel model = new ParallelTopicModel(numTopics, 0.01, 0.01);
        model.addInstances(instances);

        // Use four parallel samplers, which each look at one quarter of the corpus
        // and combine statistics after every iteration.
        model.setNumThreads(4);

        // Run the model for 50 iterations and stop (this is for testing only;
        // for real applications, use 1000 to 2000 iterations).
        model.setNumIterations(50);
        model.estimate();

        /*
         * Saving the model
         */
        String modelPath = "myTopicModel";
        ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(new File(modelPath + ".model")));
        oos.writeObject(model);
        oos.close();
        oos = new ObjectOutputStream(new FileOutputStream(new File(modelPath + ".pipeline")));
        oos.writeObject(pipeline);
        oos.close();
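        // Saving the pipeline alongside the model matters: a new document must
        // be processed by the same pipes (and thus the same alphabet) before
        // the model can infer topics for it.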
System.out.println("Model saved.");
        /*
         * Loading the model
         */
        // ParallelTopicModel model;
        // SerialPipes pipeline;
        ObjectInputStream ois = new ObjectInputStream(new FileInputStream(new File(modelPath + ".model")));
        model = (ParallelTopicModel) ois.readObject();
        ois.close();
        ois = new ObjectInputStream(new FileInputStream(new File(modelPath + ".pipeline")));
        pipeline = (SerialPipes) ois.readObject();
        ois.close();
        System.out.println("Model loaded.");

        // Show the words and topics in the first instance.
        // The data alphabet maps word IDs to strings.
        Alphabet dataAlphabet = instances.getDataAlphabet();
        FeatureSequence tokens = (FeatureSequence) model.getData().get(0).instance.getData();
        LabelSequence topics = model.getData().get(0).topicSequence;

        Formatter out = new Formatter(new StringBuilder(), Locale.US);
        for (int position = 0; position < tokens.getLength(); position++) {
            out.format("%s-%d ", dataAlphabet.lookupObject(tokens.getIndexAtPosition(position)), topics.getIndexAtPosition(position));
        }
        System.out.println(out);

        // Estimate the topic distribution of the first instance,
        // given the current Gibbs state.
        double[] topicDistribution = model.getTopicProbabilities(0);

        // Get an array of sorted sets of word ID/count pairs.
        ArrayList<TreeSet<IDSorter>> topicSortedWords = model.getSortedWords();

        // Show the top 5 words in each topic, with the topic proportions for the first document.
        for (int topic = 0; topic < numTopics; topic++) {
            Iterator<IDSorter> iterator = topicSortedWords.get(topic).iterator();
            out = new Formatter(new StringBuilder(), Locale.US);
            out.format("%d\t%.3f\t", topic, topicDistribution[topic]);
            int rank = 0;
            while (iterator.hasNext() && rank < 5) {
                IDSorter idCountPair = iterator.next();
                out.format("%s (%.0f) ", dataAlphabet.lookupObject(idCountPair.getID()), idCountPair.getWeight());
                rank++;
            }
            System.out.println(out);
        }
        /*
         * Testing
         */
        System.out.println("Evaluation");

        // Split the dataset (90% training, 10% held out, 0% unused).
        InstanceList[] instanceSplit = instances.split(new Randoms(), new double[] {0.9, 0.1, 0.0});

        // Use the first 90% for training.
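        // Note: addInstances() appends to the instances the model was already
        // trained on above, so for a clean evaluation it is arguably better to
        // train a fresh ParallelTopicModel on instanceSplit[0] only.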
        model.addInstances(instanceSplit[0]);
        model.setNumThreads(4);
        model.setNumIterations(50);
        model.estimate();

        // Get the held-out probability estimator.
        MarginalProbEstimator estimator = model.getProbEstimator();
        double loglike = estimator.evaluateLeftToRight(instanceSplit[1], 10, false, null); // pass System.out instead of null for per-document output
        System.out.println("Total log likelihood: " + loglike);
    }
}
/** This class illustrates how to build a simple file filter. */
class TxtFilter implements FileFilter {

    /**
     * Test whether the string representation of the file
     * ends with the correct extension. Note that {@link FileIterator}
     * will only call this filter if the file is not a directory,
     * so we do not need to test that it is a file.
     */
    public boolean accept(File file) {
        return file.toString().endsWith(".txt");
    }
}
1 Answer
I've also been frustrated that ML packages sometimes forget about "production mode". That is, the most common use case for LDA is that you have a collection and you train on it. For inference on new documents you can always use the command line described in the documentation, but if you need a Java interface you may have to piece a few examples together. The code you included already supports loading a saved model; you just need a TopicInferencer rather than a MarginalProbEstimator, so replace getProbEstimator() with getInferencer(). The source of TopicInferencer has an example of processing instances, and you can use the pipeline object to import a document string into Mallet's instance format. It might look something like this (I haven't tested it):
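A minimal sketch of that approach, assuming the myTopicModel.model / myTopicModel.pipeline files saved by the question's code; the class name, the placeholder article text, and the 10/1/5 arguments to getSampledDistribution (sampling iterations, thinning, burn-in) are illustrative choices, not fixed values:

import cc.mallet.pipe.SerialPipes;
import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.topics.TopicInferencer;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import java.io.*;

public class TopicInference {
    public static void main(String[] args) throws Exception {
        // Load the pipeline and model saved by the training code above.
        ObjectInputStream ois = new ObjectInputStream(new FileInputStream(new File("myTopicModel.pipeline")));
        SerialPipes pipeline = (SerialPipes) ois.readObject();
        ois.close();
        ois = new ObjectInputStream(new FileInputStream(new File("myTopicModel.model")));
        ParallelTopicModel model = (ParallelTopicModel) ois.readObject();
        ois.close();

        // Run the new document through the same pipeline the model was trained
        // with, so it is tokenized against the same alphabet.
        String newArticle = "text of the article you want topics for";
        InstanceList testing = new InstanceList(pipeline);
        // Target is null here; if Target2Label in the pipeline objects to that,
        // pass one of the training labels as a dummy target instead.
        testing.addThruPipe(new Instance(newArticle, null, "test instance", null));

        // Use the TopicInferencer rather than the MarginalProbEstimator.
        TopicInferencer inferencer = model.getInferencer();

        // Sample a topic distribution for the document:
        // 10 Gibbs iterations, thinning of 1, burn-in of 5.
        double[] probabilities = inferencer.getSampledDistribution(testing.get(0), 10, 1, 5);
        for (int topic = 0; topic < probabilities.length; topic++) {
            System.out.printf("topic %d\t%.3f%n", topic, probabilities[topic]);
        }
    }
}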
These numbers are reasonable values for estimating the posterior probability, but they are also rough guesses.