matlab Why doesn't my decision tree classifier work? The function says there are not enough input arguments

ppcbkaq5 posted on 2022-12-27 in Matlab

I've written a decision tree classifier in Matlab. As far as I can tell everything should work; the logic checks out. But when I try to call the fit method, it breaks on one of my functions, telling me I'm not passing the right input arguments, and I'm sure I am! I've been trying to solve this and similar errors involving functions and input arguments for a day or two now. I wondered whether it has something to do with calling them from inside the constructor, but calling them from the main script doesn't work either. Please help!

classdef my_ClassificationTree < handle
    
    properties
        X % training examples
        Y % training labels
        MinParentSize % minimum parent node size
        MaxNumSplits % maximum number of splits        
        Verbose % are we printing out debug as we go?
       % MinLeafSize
        CutPoint
        CutPredictorIndex
        Children
        numSplits
        root
    end
    
    methods
        
        % constructor: implementing the fitting phase
        
        function obj = my_ClassificationTree(X, Y, MinParentSize, MaxNumSplits, Verbose)
            obj.X = X;
            obj.Y = Y;
            obj.MinParentSize = MinParentSize;
            obj.MaxNumSplits = MaxNumSplits;
            obj.Verbose = Verbose;
%             obj.Children = zeros(1, 2);
%             obj.CutPoint = 0;
%             obj.CutPredictorIndex = 0;
           % obj.MinLeafSize = MinLeafSize;
            obj.numSplits = 0;
            obj.root = Node(1, size(obj.X,1));
            root = Node(1, size(obj.X,1));
            fit(obj,root);
        end
        
        function node = Node(sIndex,eIndex)
            node.startIndex = sIndex;
            node.endIndex = eIndex;
            node.leaf = false;
            node.Children = 0;
            node.size = eIndex - sIndex + 1;
            node.CutPoint = 0;
            node.CutPredictorIndex = 0;
            node.NodeClass = 0;
        end

        function fit(obj,node)            
            if node.size < obj.MinParentSize || obj.numSplits >= obj.MaxNumSplits
                 % Mark the node as a leaf node
                 node.Leaf = true;
                 % Calculate the majority class label for the examples at this node
                 labels = obj.Y(node.startIndex:node.endIndex); %gather all the labels for the data in the nodes range
                 node.NodeClass = mode(labels); %find the most frequent label and classify the node as such
                 return;
            end
            bestCutPoint = findBestCutPoint(node, obj.X, obj.Y);
            leftChild = Node(node.startIndex, bestCutPoint.CutIndex - 1);
            rightChild = Node(bestSplit.splitIndex, node.endIndex);
            obj.numSplits = obj.numSplits + 1;
            node.CutPoint = bestSplit.CutPoint;
            node.CutPredictorIndex = bestSplit.CutPredictorIndex;
            %Attach the child nodes to the parent node
            node.Children = [leftChild, rightChild];
            % Recursively build the tree for the left and right child nodes
            fit(obj, leftChild);
            fit(obj, rightChild);
        end        

        function bestCutPoint = findBestCutPoint(node, X, labels)
            bestCutPoint.CutPoint = 0;
            bestCutPoint.CutPredictorIndex = 0;
            bestCutPoint.CutIndex = 0;
            bestGDI = Inf; % Initialize the best GDI to a large value
            
            % Loop through all the features
            for i = 1:size(X, 2)
                % Loop through all the unique values of the feature
                values = unique(X(node.startIndex:node.endIndex, i));
                for j = 1:length(values)
                    % Calculate the weighted impurity of the two resulting
                    % cut
                    leftLabels = labels(node.startIndex:node.endIndex, 1);
                    rightLabels = labels(node.startIndex:node.endIndex, 1);
                    leftLabels = leftLabels(X(node.startIndex:node.endIndex, i) < values(j));
                    rightLabels = rightLabels(X(node.startIndex:node.endIndex, i) >= values(j));
                    leftGDI = weightedGDI(leftLabels, labels);
                    rightGDI = weightedGDI(rightLabels, labels);
                    % Calculate the weighted impurity of the split
                    cutGDI = leftGDI + rightGDI;
                    % Update the best split if the current split has a lower GDI
                    if cutGDI < bestGDI
                        bestGDI = cutGDI;
                        bestCutPoint.CutPoint = values(j);
                        bestCutPoint.CutPredictorIndex = i;
                        bestCutPoint.CutIndex = find(X(:, i) == values(j), 1, 'first');
                    end
                end
            end
        end

% the prediction phase:
        function predictions = predict(obj, test_examples)
            
            % get ready to store our predicted class labels:
            predictions = categorical;
            
             % Iterate over each example in X
            for i = 1:size(test_examples, 1)
                % Set the current node to be the root node
                currentNode = obj.root;
                % While the current node is not a leaf node
                while ~currentNode.leaf 
                    % Check the value of the predictor feature specified by the CutPredictorIndex property of the current node
                    value = test_examples(i, currentNode.CutPredictorIndex);
                    % If the value is less than the CutPoint of the current node, set the current node to be the left child of the current node
                    if value < currentNode.CutPoint
                        currentNode = currentNode.Children(1);
                    % If the value is greater than or equal to the CutPoint of the current node, set the current node to be the right child of the current node
                    else
                        currentNode = currentNode.Children(2);
                    end
                end
                % Once the current node is a leaf node, add the NodeClass of the current node to the predictions vector
                predictions(i) = currentNode.NodeClass;
            end
        end
        
        % add any other methods you want on the lines below...

    end
    
end
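Note that weightedGDI is called above but not shown in the post. A weighted Gini diversity index along the following lines would match the way it is called; this is a hypothetical sketch, not the asker's actual code, and it assumes weightedGDI is a standalone function on the path rather than a class method:

function wgdi = weightedGDI(subsetLabels, allLabels)
    % Hypothetical helper: Gini diversity index of a subset of labels,
    % weighted by the subset's share of all examples at the node.
    if isempty(subsetLabels)
        wgdi = 0;
        return;
    end
    classes = unique(subsetLabels);
    p = zeros(numel(classes), 1);
    for c = 1:numel(classes)
        % proportion of the subset belonging to class c
        p(c) = sum(subsetLabels == classes(c)) / numel(subsetLabels);
    end
    gdi = 1 - sum(p.^2);                                     % Gini impurity
    wgdi = (numel(subsetLabels) / numel(allLabels)) * gdi;   % weight by subset size
end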

Here is the function that calls my_ClassificationTree:

function m = my_fitctree(train_examples, train_labels, varargin)

    % take an extra name-value pair allowing us to turn debug on:
    p = inputParser;
    addParameter(p, 'Verbose', false);
    %addParameter(p, 'MinLeafSize', false);
    % take an extra name-value pair allowing us to set the minimum
    % parent size (10 by default):
    addParameter(p, 'MinParentSize', 10);
    % take an extra name-value pair allowing us to set the maximum
    % number of splits (number of training examples-1 by default):
    addParameter(p, 'MaxNumSplits', size(train_examples,1) - 1);

    p.parse(varargin{:});
    
    % use the supplied parameters to create a new my_ClassificationTree
    % object:
    
    m = my_ClassificationTree(train_examples, train_labels, ...
        p.Results.MinParentSize, p.Results.MaxNumSplits, p.Results.Verbose);
            
end

Here is the code in the main script:

mym2_dt = my_fitctree(train_examples, train_labels, 'MinParentSize', 10)

These are the errors (not reproduced in the post).
I expected it to build and populate a decision tree, but it breaks in the findBestCutPoint function and I can't fix it.

vltsax25 (answer 1)

The first argument of a class method (other than the constructor) should be an instance of the class (e.g. obj). Your Node and findBestCutPoint definitions should take obj as their first argument.
Also, calling a class method from inside another method should use the obj.theMethod(...) syntax, which doesn't appear to be the case in your code.
For example, the call to Node should be:

obj.root = obj.Node(1, size(obj.X,1));

and Node should be defined as:

function node = Node(obj,sIndex,eIndex)

The same goes for findBestCutPoint. Note that with this call syntax the reference to the class instance is passed implicitly, so you don't actually include it in the argument list.
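Putting that together, a minimal sketch of the corrected signatures and call sites might look like this (it also assumes the stray bestSplit references in fit were meant to be bestCutPoint, since bestSplit is never defined):

% Methods other than the constructor take obj as their first argument:
function node = Node(obj, sIndex, eIndex)
    node.startIndex = sIndex;
    node.endIndex = eIndex;
    node.leaf = false;
    node.Children = 0;
    node.size = eIndex - sIndex + 1;
    node.CutPoint = 0;
    node.CutPredictorIndex = 0;
    node.NodeClass = 0;
end

function bestCutPoint = findBestCutPoint(obj, node, X, labels)
    % ... body as posted ...
end

% Call sites use obj.theMethod(...), which passes obj implicitly:
obj.root = obj.Node(1, size(obj.X, 1));   % in the constructor; the duplicate local 'root' is unnecessary
fit(obj, obj.root);

% inside fit:
bestCutPoint = obj.findBestCutPoint(node, obj.X, obj.Y);
leftChild  = obj.Node(node.startIndex, bestCutPoint.CutIndex - 1);
rightChild = obj.Node(bestCutPoint.CutIndex, node.endIndex);   % was bestSplit.splitIndex
node.CutPoint = bestCutPoint.CutPoint;                         % was bestSplit.CutPoint
node.CutPredictorIndex = bestCutPoint.CutPredictorIndex;       % was bestSplit.CutPredictorIndex

One further thing worth checking once the argument error is gone: Node returns a plain struct, and MATLAB structs are passed by value, so the changes fit makes to a node (marking it a leaf, attaching children) won't be visible through obj.root unless the nodes are handle objects or are written back; also, fit sets node.Leaf while the field (and predict) use the lowercase node.leaf.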
