MALLET topic modeling in Java

oipij1gg · posted 2021-06-27 in Java

What I find very difficult about machine learning tutorials/books/articles is that when a model is explained (even with code), the code stops right after training (and/or testing) the model. I cannot find a tutorial/book that starts from a dataset, trains a model, and then demonstrates how to actually use the model, for example for topic modeling. In the code below I have a dataset of news articles stored in folders by topic. Using MALLET I can create the model (and save it), but that is where it ends.
How do I use it now? I want to feed the model an article and get its topics as output. Please don't point me to the MALLET documentation, because it also doesn't give a complete example from the beginning through to using the model.
The example below is taken from the book Machine Learning in Java (Bostjan Kaluza), which provides code for creating a model and for saving/loading it. That is a good starting point for me, but what if I now want to use this trained model? Can someone give an example in Java? It doesn't have to be MALLET.

import cc.mallet.types.*;
import cc.mallet.pipe.*;
import cc.mallet.pipe.iterator.*;
import cc.mallet.topics.*;
import cc.mallet.util.Randoms;

import java.util.*;
import java.util.regex.*;
import java.io.*;

public class TopicModeling {

    public static void main(String[] args) throws Exception {

        String dataFolderPath = "data/bbc";
        String stopListFilePath = "data/stoplists/en.txt";

        ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
        pipeList.add(new Input2CharSequence("UTF-8"));
        Pattern tokenPattern = Pattern.compile("[\\p{L}\\p{N}_]+");
        pipeList.add(new CharSequence2TokenSequence(tokenPattern));
        pipeList.add(new TokenSequenceLowercase());
        pipeList.add(new TokenSequenceRemoveStopwords(new File(stopListFilePath), "utf-8", false, false, false));
        pipeList.add(new TokenSequence2FeatureSequence());
        pipeList.add(new Target2Label());
        SerialPipes pipeline = new SerialPipes(pipeList);

        FileIterator folderIterator = new FileIterator(
                    new File[] {new File(dataFolderPath)},
                    new TxtFilter(),
                    FileIterator.LAST_DIRECTORY);

        // Construct a new instance list, passing it the pipe
        //  we want to use to process instances.
        InstanceList instances = new InstanceList(pipeline);

        // Now process each instance provided by the iterator.
        instances.addThruPipe(folderIterator);

        // Create a model with 5 topics, alpha_t = 0.01, beta_w = 0.01
        //  Note that the first parameter is passed as the sum over topics, while
        //  the second is the parameter for a single dimension of the Dirichlet prior.
        int numTopics = 5;
        ParallelTopicModel model = new ParallelTopicModel(numTopics, 0.01, 0.01);

        model.addInstances(instances);

        // Use four parallel samplers, which each look at one quarter of the corpus
        //  and combine statistics after every iteration.
        model.setNumThreads(4);

        // Run the model for 50 iterations and stop (this is for testing only, 
        //  for real applications, use 1000 to 2000 iterations)
        model.setNumIterations(50);
        model.estimate();

        /*
         * Saving model
         */

        String modelPath = "myTopicModel";

        ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream (new File(modelPath+".model")));
        oos.writeObject(model);
        oos.close();     
        oos = new ObjectOutputStream(new FileOutputStream (new File(modelPath+".pipeline")));
        oos.writeObject(pipeline);
        oos.close();     

        System.out.println("Model saved.");

        /*
         * Loading the model
         */
//      In a separate loading program, declare these here instead:
//      ParallelTopicModel model;
//      SerialPipes pipeline;
        ObjectInputStream ois = new ObjectInputStream (new FileInputStream (new File(modelPath+".model")));
        model = (ParallelTopicModel) ois.readObject();
        ois.close();   
        ois = new ObjectInputStream (new FileInputStream (new File(modelPath+".pipeline")));
        pipeline = (SerialPipes) ois.readObject();
        ois.close();   

        System.out.println("Model loaded.");

        // Show the words and topics in the first instance

        // The data alphabet maps word IDs to strings
        Alphabet dataAlphabet = instances.getDataAlphabet();

        FeatureSequence tokens = (FeatureSequence) model.getData().get(0).instance.getData();
        LabelSequence topics = model.getData().get(0).topicSequence;

        Formatter out = new Formatter(new StringBuilder(), Locale.US);
        for (int position = 0; position < tokens.getLength(); position++) {
            out.format("%s-%d ", dataAlphabet.lookupObject(tokens.getIndexAtPosition(position)), topics.getIndexAtPosition(position));
        }
        System.out.println(out);

        // Estimate the topic distribution of the first instance, 
        //  given the current Gibbs state.
        double[] topicDistribution = model.getTopicProbabilities(0);

        // Get an array of sorted sets of word ID/count pairs
        ArrayList<TreeSet<IDSorter>> topicSortedWords = model.getSortedWords();

        // Show top 5 words in topics with proportions for the first document
        for (int topic = 0; topic < numTopics; topic++) {
            Iterator<IDSorter> iterator = topicSortedWords.get(topic).iterator();

            out = new Formatter(new StringBuilder(), Locale.US);
            out.format("%d\t%.3f\t", topic, topicDistribution[topic]);
            int rank = 0;
            while (iterator.hasNext() && rank < 5) {
                IDSorter idCountPair = iterator.next();
                out.format("%s (%.0f) ", dataAlphabet.lookupObject(idCountPair.getID()), idCountPair.getWeight());
                rank++;
            }
            System.out.println(out);
        }

        /*
         * Testing
         */

        System.out.println("Evaluation");

        // Split dataset
        InstanceList[] instanceSplit= instances.split(new Randoms(), new double[] {0.9, 0.1, 0.0});

        // Use the first 90% for training
        model.addInstances(instanceSplit[0]);
        model.setNumThreads(4);
        model.setNumIterations(50);
        model.estimate();

        // Get estimator
        MarginalProbEstimator estimator = model.getProbEstimator();
        double loglike = estimator.evaluateLeftToRight(instanceSplit[1], 10, false, null); // pass System.out instead of null to print per-document values
        System.out.println("Total log likelihood: "+loglike);
    }
}
/**This class illustrates how to build a simple file filter */
class TxtFilter implements FileFilter {

    /**Test whether the string representation of the file 
     *   ends with the correct extension. Note that {@link FileIterator}
     *   will only call this filter if the file is not a directory,
     *   so we do not need to test that it is a file.
     */
    public boolean accept(File file) {
        return file.toString().endsWith(".txt");
    }
}

7ajki6be (answer 1)

I also find it frustrating that ML packages sometimes forget about "production mode". That is, the most common use case for LDA is that you have a collection and you train on it. For inference on new documents you can always use the command line interface described in the documentation, but if you need a Java interface you may have to piece some examples together. The code you included already supports loading the saved model; you just need to use a TopicInferencer instead of a MarginalProbEstimator. Replace getProbEstimator() with getInferencer(). The source of TopicInferencer has examples of processing instances. You can use the pipeline object to import a document string into MALLET's instance format. It might look like this:

Instance instance = pipeline.pipe(new Instance(inputText, null, null, null));
double[] distribution = inferencer.getSampledDistribution(instance, 10, 0, 5);

(I haven't tested this.) The numbers are reasonable values for estimating the posterior, but they are also rough guesses.
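Putting the loading code from the question together with the inferencer, a minimal end-to-end sketch of the inference step could look like the code below. It assumes the myTopicModel.model and myTopicModel.pipeline files written by the question's code already exist; the class name, the sample text, and the sampling parameters (100 iterations, thinning 10, burn-in 10) are placeholder choices, not values prescribed by MALLET. It also goes through an InstanceList with addThruPipe, the same route the training code takes, rather than calling pipe() directly:

import cc.mallet.pipe.SerialPipes;
import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.topics.TopicInferencer;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;

import java.io.FileInputStream;
import java.io.ObjectInputStream;

public class TopicInference {

    public static void main(String[] args) throws Exception {

        // Load the model and the pipeline saved by the training program
        ParallelTopicModel model;
        SerialPipes pipeline;
        try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream("myTopicModel.model"))) {
            model = (ParallelTopicModel) ois.readObject();
        }
        try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream("myTopicModel.pipeline"))) {
            pipeline = (SerialPipes) ois.readObject();
        }

        // getInferencer() instead of getProbEstimator()
        TopicInferencer inferencer = model.getInferencer();

        // A new, unseen document as a plain string (placeholder text)
        String inputText = "The new phone ships with a faster chip and a better camera.";

        // Run the text through the same pipeline that was used at training time,
        // so it is tokenized and indexed against the same alphabets
        InstanceList testing = new InstanceList(pipeline);
        testing.addThruPipe(new Instance(inputText, null, "test instance", null));

        // Sample a per-document topic distribution:
        // 100 iterations, thinning 10, burn-in 10 (rough guesses, as noted above)
        double[] distribution = inferencer.getSampledDistribution(testing.get(0), 100, 10, 10);

        for (int topic = 0; topic < distribution.length; topic++) {
            System.out.printf("topic %d: %.4f%n", topic, distribution[topic]);
        }
    }
}

The topic with the highest value in distribution is then the model's best guess at the article's dominant topic; words that never appeared in the training data should simply be ignored by the inferencer.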
