scipy 哪个求解器抛出此警告?

a64a0gku  于 2023-03-18  发布在  其他
关注(0)|答案(1)|浏览(146)

此警告弹出10次。行号在456和305之间变化:

C:\Users\foo\Anaconda3\lib\site-packages\scipy\optimize\_linesearch.py:456: LineSearchWarning: The line search algorithm did not converge
  warn('The line search algorithm did not converge', LineSearchWarning)

我正在用这些参数进行网格搜索:

logistic_regression_grid = {
    "class_weight": ["balanced"], 
    "max_iter":     [100000],
    "solver":       ["lbfgs", "liblinear", "newton-cg", "newton-cholesky", "sag", "saga"],
    "random_state": [0]
}

所以,问题是哪个求解器抛出了警告?有可能确定吗?

vfhzx4xs

vfhzx4xs1#

我使用了虹膜集,并设置了max_iter=10,以有意地引发收敛警告。由于您只对解算器感兴趣,我在解算器上循环,而不使用网格搜索,并且我能够使用warnings库和sklearn.exceptions包打印解算器不收敛。以下是我的代码:

import warnings
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.exceptions import ConvergenceWarning
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Your logistic regression grid

logistic_regression_grid = {
    "class_weight": ["balanced"], 
    "max_iter":     [100000],
    "solver":       ["lbfgs", "liblinear", "newton-cg", "sag", "saga"],
    "random_state": [0]
}

# Load the Iris dataset

iris = load_iris()
X, y = iris.data, iris.target

# Split the data into training and testing sets

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Loop over the solvers and capture warnings

for solver in logistic_regression_grid["solver"]:
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")

        # Fit logistic regression model with the current solver

        model = LogisticRegression(class_weight="balanced", max_iter=10, solver=solver, random_state=0)
        model.fit(X_train, y_train)

        # Check if any warning was generated

        if any(issubclass(warning.category, ConvergenceWarning) for warning in w):
            print(f"Solver '{solver}' did not converge.")

下面是我得到的输出:

Solver 'lbfgs' did not converge.
Solver 'newton-cg' did not converge.
Solver 'sag' did not converge.
Solver 'saga' did not converge.

相关问题