这是我的python代码,我试图使用sklearn预测树上的水果数量,但遇到了下面的问题代码:
import cv2
from sklearn.ensemble import RandomForestClassifier
def count_fruits(image):
# Convert the image into grayscale
gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
# Find the contours in the image
contours, hierarchy = cv2.findContours(gray_image, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
# Count the number of fruit
count_fruit = 0
for contour in contours:
# If the counter is a closed contour, it is a fruit
if cv2.contourArea(contour) > 1000:
count_fruit += 1
return count_fruit
def train_models(image_list, fruit_count_list):
# Create a Random Forest Classifier
model = RandomForestClassifier()
# Train the model on the image list and fruit count list
model.fit(image_list, fruit_count_list)
return model
def predict_fruit_count(model, image):
# get the predicted fruit count
predict_fruit_count = model.predict(image)
return predict_fruit_count
if __name__ == '__main__':
# Get the image list
image_list = [cv2.imread('4.jpg'), cv2.imread('4.jpg')]
# Get the fruit count list
fruit_count_list = [count_fruits(image) for image in image_list]
# Train the model
model = train_models(image_list, fruit_count_list)
# Get the predicted fruit count
predicted_fruit_count = predict_fruit_count(model, cv2.imread('4.jpg'))
# Print the predicted fruit count
print('------------------->{0}'.format(predicted_fruit_count))
我正在使用python sklearn框架创建ML模型来计算树上的果实。我在使用python3运行时遇到错误
after making above code changes new error :
Traceback (most recent call last):
File "C:\Users\HP\OneDrive\Desktop\Work\ML\ml_fruit.py", line 40, in <module>
model = train_models(image_list, fruit_count_list)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\HP\OneDrive\Desktop\Work\ML\ml_fruit.py", line 26, in train_models
model.fit(image_list, fruit_count_list)
File "C:\Program Files\Python311\Lib\site-packages\sklearn\ensemble\_forest.py", line 345, in fit
X, y = self._validate_data(
^^^^^^^^^^^^^^^^^^^^
File "C:\Program Files\Python311\Lib\site-packages\sklearn\base.py", line 584, in _validate_data
X, y = check_X_y(X, y, **check_params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Program Files\Python311\Lib\site-packages\sklearn\utils\validation.py", line 1106, in check_X_y
X = check_array(
^^^^^^^^^^^^
File "C:\Program Files\Python311\Lib\site-packages\sklearn\utils\validation.py", line 915, in check_array
raise ValueError(
ValueError: Found array with dim 4. RandomForestClassifier expected <= 2.
1条答案
按热度按时间yvfmudvl1#
看起来你是把图像的阈值传递给findCountour函数,而不是传递一个实际的图像。