Usage of the weka.core.Instances.testCV() method, with code examples


This article collects code examples of the weka.core.Instances.testCV() method in Java and shows how it is used in practice. The examples are taken from selected open-source projects on GitHub, Stack Overflow, Maven and similar platforms, so they carry real reference value. Details of Instances.testCV() are as follows:
Package path: weka.core.Instances
Class name: Instances
Method name: testCV

About Instances.testCV

Creates the test set for one fold of a cross-validation on the dataset.
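
To make the method concrete before going through the harvested examples, here is a minimal sketch of the usual pattern: testCV(numFolds, fold) is paired with trainCV to produce one train/test split per fold, after a copy of the dataset has been randomized (and, for a nominal class, stratified). The classifier (J48), the seed and the ARFF path are illustrative choices, not taken from any of the projects below.

import java.util.Random;

import weka.classifiers.Evaluation;
import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class TestCVSketch {
  public static void main(String[] args) throws Exception {
    // "iris.arff" is just a placeholder; any ARFF file with a class attribute will do
    Instances data = DataSource.read("iris.arff");
    data.setClassIndex(data.numAttributes() - 1);

    int numFolds = 10;
    Random rand = new Random(1);
    Instances randData = new Instances(data);   // work on a copy, leave the original untouched
    randData.randomize(rand);
    if (randData.classAttribute().isNominal()) {
      randData.stratify(numFolds);              // keep class proportions similar across folds
    }

    for (int fold = 0; fold < numFolds; fold++) {
      Instances train = randData.trainCV(numFolds, fold, rand); // everything except fold 'fold'
      Instances test  = randData.testCV(numFolds, fold);        // only fold 'fold'

      J48 tree = new J48();
      tree.buildClassifier(train);

      Evaluation eval = new Evaluation(train);
      eval.evaluateModel(tree, test);
      System.out.printf("Fold %d: %.2f%% correct%n", fold, eval.pctCorrect());
    }
  }
}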

Code examples

Code example source: com.googlecode.obvious/obviousx-weka

@Override
public Instances testCV(int arg0, int arg1) {
 // simply delegates to the underlying weka.core.Instances implementation
 return super.testCV(arg0, arg1);
}

Code example source: droidefense/engine

public Instances[][] crossValidationSplit(Instances data, int numberOfFolds) {
    // split[0][i] holds the training set for fold i, split[1][i] the matching test set
    Instances[][] split = new Instances[2][numberOfFolds];
    for (int i = 0; i < numberOfFolds; i++) {
      split[0][i] = data.trainCV(numberOfFolds, i);
      split[1][i] = data.testCV(numberOfFolds, i);
    }
    return split;
}
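
Note that this helper, unlike the Weka examples further down, does not randomize or stratify the data first, so callers are expected to shuffle the Instances beforehand. A typical way to consume the returned structure (the surrounding code is illustrative, not from the droidefense codebase):

Instances[][] split = crossValidationSplit(data, 10);
for (int i = 0; i < split[0].length; i++) {
  Instances train = split[0][i];
  Instances test = split[1][i];
  // build a classifier on 'train' and evaluate it on 'test' here
}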

Code example source: nz.ac.waikato.cms.weka/weka-stable (the Waikato/weka-trunk copy is identical)

/**
 * Generate a bunch of predictions ready for processing, by performing a
 * cross-validation on the supplied dataset.
 * 
 * @param classifier the Classifier to evaluate
 * @param data the dataset
 * @param numFolds the number of folds in the cross-validation.
 * @exception Exception if an error occurs
 */
public ArrayList<Prediction> getCVPredictions(Classifier classifier,
 Instances data, int numFolds) throws Exception {
 ArrayList<Prediction> predictions = new ArrayList<Prediction>();
 Instances runInstances = new Instances(data);
 Random random = new Random(m_Seed);
 runInstances.randomize(random);
 if (runInstances.classAttribute().isNominal() && (numFolds > 1)) {
  runInstances.stratify(numFolds);
 }
 for (int fold = 0; fold < numFolds; fold++) {
  Instances train = runInstances.trainCV(numFolds, fold, random);
  Instances test = runInstances.testCV(numFolds, fold);
  ArrayList<Prediction> foldPred = getTrainTestPredictions(classifier,
   train, test);
  predictions.addAll(foldPred);
 }
 return predictions;
}
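
One detail worth pointing out in this example: the three-argument trainCV(numFolds, fold, random) shuffles the resulting training split using the supplied Random, while testCV has no such overload, so the test fold keeps the order of the already randomized and stratified runInstances.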

Code example source: net.sf.meka.thirdparty/mulan

private MultipleEvaluation innerCrossValidate(MultiLabelLearner learner, MultiLabelInstances data, boolean hasMeasures, List<Measure> measures, int someFolds) {
    Evaluation[] evaluation = new Evaluation[someFolds];

    Instances workingSet = new Instances(data.getDataSet());
    workingSet.randomize(new Random(seed));
    for (int i = 0; i < someFolds; i++) {
      System.out.println("Fold " + (i + 1) + "/" + someFolds);
      try {
        Instances train = workingSet.trainCV(someFolds, i);
        Instances test = workingSet.testCV(someFolds, i);
        MultiLabelInstances mlTrain = new MultiLabelInstances(train, data.getLabelsMetaData());
        MultiLabelInstances mlTest = new MultiLabelInstances(test, data.getLabelsMetaData());
        MultiLabelLearner clone = learner.makeCopy();
        clone.build(mlTrain);
        if (hasMeasures)
          evaluation[i] = evaluate(clone, mlTest, measures);
        else
          evaluation[i] = evaluate(clone, mlTest);
      } catch (Exception ex) {
        Logger.getLogger(Evaluator.class.getName()).log(Level.SEVERE, null, ex);
      }
    }
    MultipleEvaluation me = new MultipleEvaluation(evaluation, data);
    me.calculateStatistics();
    return me;
  }

Code example source: Waikato/weka-trunk (the nz.ac.waikato.cms.weka/weka-stable copy is identical)

/**
 * Method for building a pruneable classifier tree.
 *
 * @param data the data to build the tree from 
 * @throws Exception if tree can't be built successfully
 */
public void buildClassifier(Instances data) 
   throws Exception {
 // remove instances with missing class
 data = new Instances(data);
 data.deleteWithMissingClass();
 
 Random random = new Random(m_seed);
 data.stratify(numSets);
 buildTree(data.trainCV(numSets, numSets - 1, random),
    data.testCV(numSets, numSets - 1), !m_cleanup);
 if (pruneTheTree) {
  prune();
 }
 if (m_cleanup) {
  cleanup(new Instances(data, 0));
 }
}
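
Here the data is stratified into numSets folds and the last fold (index numSets - 1) is held out: trainCV grows the tree on the remaining folds, while the matching testCV fold serves as the hold-out set against which the tree is subsequently pruned when pruneTheTree is set.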

Code example source: net.sf.meka.thirdparty/mulan

/**
 * Automatically selects a threshold based on training set performance
 * evaluated using cross-validation
 *
 * @param trainingData the training data on which the internal cross-validation is run
 * @param measure performance is evaluated based on this parameter
 * @param folds number of cross-validation folds
 * @throws InvalidDataFormatException
 * @throws Exception
 */
private void autoTuneThreshold(MultiLabelInstances trainingData, BipartitionMeasureBase measure, int folds) throws InvalidDataFormatException, Exception {
  if (folds < 2) {
    throw new IllegalArgumentException("folds should be more than 1");
  }
  double[] totalDiff = new double[numLabels + 1];
  LabelsMetaData labelsMetaData = trainingData.getLabelsMetaData();
  MultiLabelLearner tempLearner = foldLearner.makeCopy();
  for (int f = 0; f < folds; f++) {
    Instances train = trainingData.getDataSet().trainCV(folds, f);
    MultiLabelInstances trainMulti = new MultiLabelInstances(train, labelsMetaData);
    Instances test = trainingData.getDataSet().testCV(folds, f);
    MultiLabelInstances testMulti = new MultiLabelInstances(test, labelsMetaData);
    tempLearner.build(trainMulti);
    double[] diff = computeThreshold(tempLearner, testMulti, measure);
    for (int k = 0; k < diff.length; k++) {
      totalDiff[k] += diff[k];
    }
  }
  t = Utils.minIndex(totalDiff);
}
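
At the end of the loop, totalDiff has accumulated, over all folds, the per-candidate differences returned by computeThreshold for each of the numLabels + 1 candidate threshold positions; Utils.minIndex then picks the position with the smallest accumulated value, which becomes the tuned threshold t.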

Code example source: nz.ac.waikato.cms.weka/weka-stable

Instances test = allData.testCV(m_numFoldsBoosting, i);

Code example source: nz.ac.waikato.cms.weka/meka

/**
 * CVModel - Split D into train/test folds, and then train and evaluate on each one.
 * @param    h         a multi-dim. classifier
 * @param    D           data
 * @param    numFolds    the number of cross-validation folds
 * @param    top         Threshold OPtion (pertains to multi-label data only)
 * @return    an array of 'numFolds' Results
 */
public static Result[] cvModel(MultilabelClassifier h, Instances D, int numFolds, String top) throws Exception {
  Result r[] = new Result[numFolds];
  for(int i = 0; i < numFolds; i++) {
    Instances D_train = D.trainCV(numFolds,i);
    Instances D_test = D.testCV(numFolds,i);
    if (h.getDebug()) System.out.println(":- Fold ["+i+"/"+numFolds+"] -: "+MLUtils.getDatasetName(D)+"\tL="+D.classIndex()+"\tD(t:T)=("+D_train.numInstances()+":"+D_test.numInstances()+")\tLC(t:T)="+Utils.roundDouble(MLUtils.labelCardinality(D_train,D.classIndex()),2)+":"+Utils.roundDouble(MLUtils.labelCardinality(D_test,D.classIndex()),2)+")");
    r[i] = evaluateModel(h, D_train, D_test, top);
  }
  return r;
}

Code example source: nz.ac.waikato.cms.weka/weka-stable

// excerpt: selects either the test or the training split of the chosen fold; the "- 1"
// converts m_Fold, which appears to be 1-based in the filter options, to the 0-based index expected by testCV/trainCV
instances = getInputFormat().testCV(m_NumFolds, m_Fold - 1);
} else {
 instances = getInputFormat().trainCV(m_NumFolds, m_Fold - 1);

Code example source: net.sf.meka.thirdparty/mulan

Instances temp = transformed.testCV(folds, i);
Instances test = new Instances(data.getDataSet(), 0);
for (int j=0; j<temp.numInstances(); j++) {

Code example source: nz.ac.waikato.cms.weka/weka-stable

Instances test = newData.testCV(m_NumFolds, j);
 for (int i = 0; i < test.numInstances(); i++) {
metaData.add(metaInstance(test.instance(i)));

Code example source: net.sf.meka.thirdparty/mulan

protected void buildInternal(MultiLabelInstances trainingData) throws Exception {
  baseLearner.build(trainingData);
  if (folds == 0) {
    threshold = computeThreshold(baseLearner, trainingData, measure);
  } else {
    LabelsMetaData labelsMetaData = trainingData.getLabelsMetaData();
    double[] thresholds = new double[folds];
    for (int f = 0; f < folds; f++) {
      Instances train = trainingData.getDataSet().trainCV(folds, f);
      MultiLabelInstances trainMulti = new MultiLabelInstances(train, labelsMetaData);
      Instances test = trainingData.getDataSet().testCV(folds, f);
      MultiLabelInstances testMulti = new MultiLabelInstances(test, labelsMetaData);
      MultiLabelLearner tempLearner = foldLearner.makeCopy();
      tempLearner.build(trainMulti);
      thresholds[f] = computeThreshold(tempLearner, testMulti, measure);
    }
    threshold = Utils.mean(thresholds);
  }
}
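
In other words, when folds == 0 the threshold is computed directly on the data the base learner was trained on, whereas with folds > 0 each fold trains a fresh copy of foldLearner and the final threshold is the mean of the per-fold estimates, which avoids estimating the threshold on the same data the model was fitted to.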

Code example source: net.sf.meka.thirdparty/mulan

Instances test = trainingSet.getDataSet().testCV(kFoldsCV, i);
MultiLabelInstances mlTest = new MultiLabelInstances(test, trainingSet.getLabelsMetaData());
MultiLabelLearner learner = baseLearner.makeCopy();

Code example source: Waikato/weka-trunk

Instances test = newData.testCV(m_NumFolds, j);
if (baseClassifiersImplementMoreEfficientBatchPrediction()) {
 metaData.addAll(metaInstances(test));

Code example source: olehmberg/winter

// excerpt from a cross-validation loop: 'train' is the matching training split,
// typically obtained earlier in the loop via data.trainCV(numFolds, i, random)
Classifier copiedClassifier = AbstractClassifier.makeCopy(classifier);
copiedClassifier.buildClassifier(train);
Instances test = data.testCV(numFolds, i);
evaluateModel(copiedClassifier, test, forPredictionsPrinting);

Code example source: nz.ac.waikato.cms.weka/weka-stable (the Waikato/weka-trunk copy is identical)

Instances test = trainingSet.testCV(5, j);

Code example source: nz.ac.waikato.cms.weka/grading

Instances test = newData.testCV(m_NumFolds, j);
