numpy 使用ML sklearn计算树上的水果

hrirmatl  于 2023-03-23  发布在  其他
关注(0)|答案(1)|浏览(138)

这是我的python代码,我试图使用sklearn预测树上的水果数量,但遇到了下面的问题代码:

import cv2
from sklearn.ensemble import RandomForestClassifier

def count_fruits(image):
    # Convert the image into grayscale
    gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
       
    # Find the contours in the image
    contours, hierarchy = cv2.findContours(gray_image, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    
    # Count the number of fruit
    count_fruit = 0
    for contour in contours:
        # If the counter is a closed contour, it is a fruit
        if cv2.contourArea(contour) > 1000:
            count_fruit += 1
    return count_fruit

def train_models(image_list, fruit_count_list):
    # Create a Random Forest Classifier
    model = RandomForestClassifier()
    # Train the model on the image list and fruit count list
    model.fit(image_list, fruit_count_list)
    return model

def predict_fruit_count(model, image):
    # get the predicted fruit count
    predict_fruit_count = model.predict(image)
    return predict_fruit_count

if __name__ == '__main__':
    # Get the image list
    image_list = [cv2.imread('4.jpg'), cv2.imread('4.jpg')]
    # Get the fruit count list
    fruit_count_list = [count_fruits(image) for image in image_list]
    # Train the model
    model = train_models(image_list, fruit_count_list)
    # Get the predicted fruit count
    predicted_fruit_count = predict_fruit_count(model, cv2.imread('4.jpg'))
    # Print the predicted fruit count
    print('------------------->{0}'.format(predicted_fruit_count))

我正在使用python sklearn框架创建ML模型来计算树上的果实。我在使用python3运行时遇到错误

after making above code changes new error :

Traceback (most recent call last):
  File "C:\Users\HP\OneDrive\Desktop\Work\ML\ml_fruit.py", line 40, in <module>
    model = train_models(image_list, fruit_count_list)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\HP\OneDrive\Desktop\Work\ML\ml_fruit.py", line 26, in train_models
    model.fit(image_list, fruit_count_list)
  File "C:\Program Files\Python311\Lib\site-packages\sklearn\ensemble\_forest.py", line 345, in fit
    X, y = self._validate_data(
           ^^^^^^^^^^^^^^^^^^^^
  File "C:\Program Files\Python311\Lib\site-packages\sklearn\base.py", line 584, in _validate_data
    X, y = check_X_y(X, y, **check_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Program Files\Python311\Lib\site-packages\sklearn\utils\validation.py", line 1106, in check_X_y
    X = check_array(
        ^^^^^^^^^^^^
  File "C:\Program Files\Python311\Lib\site-packages\sklearn\utils\validation.py", line 915, in check_array
    raise ValueError(
ValueError: Found array with dim 4. RandomForestClassifier expected <= 2.
yvfmudvl

yvfmudvl1#

看起来你是把图像的阈值传递给findCountour函数,而不是传递一个实际的图像。

gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # Threshold the image
    thresh_image = cv2.threshold(gray_image, 127, 255, cv2.THRESH_BINARY)
       
    # Find the contours in the image
    contours, hierarchy = cv2.findContours(gray_image, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

相关问题