过滤Tensorflow数据集中的NaN值

qlzsbp2j  于 2023-08-06  发布在  其他
关注(0)|答案(3)|浏览(106)

是否有一种简单的方法来过滤tensorflow.data.Dataset示例中包含nan值的所有条目?就像Pandas中的dropna方法一样?

简短的例子:

import numpy as np
import tensorflow as tf

X = tf.data.Dataset.from_tensor_slices([[1,2,3], [0,0,0], [np.nan,np.nan,np.nan], [3,4,5], [np.nan,3,4]])
y = tf.data.Dataset.from_tensor_slices([np.nan, 0, 1, 2, 3])
ds = tf.data.Dataset.zip((X,y))
ds = foo(ds)  # foo(x) = ?
for x in iter(ds): print(str(x))

字符串
foo(x)可以使用什么来获得以下输出:

(<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0>)
(<tf.Tensor: shape=(3,), dtype=float32, numpy=array([3., 4., 5.], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=2.0>)


如果你想自己试试,here is Google Colab notebook

tjvv9vkg

tjvv9vkg1#

我有一个与现有答案略有不同的方法。我不使用sum,而是使用tf.reduce_any

filter_nan = lambda x, y: not tf.reduce_any(tf.math.is_nan(x)) and not tf.math.is_nan(y)

ds = tf.data.Dataset.zip((X,y)).filter(filter_nan)

list(ds.as_numpy_iterator())

个字符

2skhul33

2skhul332#

怎么样:

def any_nan(t):
    return tf.reduce_sum(
        tf.cast(
            tf.math.is_nan(t),
            tf.int32,
        )
    ) > tf.constant(0)

>>> ds_filtered = ds.filter(lambda x, y: not any_nan(x) and not any_nan(y))
>>> for x in iter(ds_filtered): print(str(x))
(<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0>)
(<tf.Tensor: shape=(3,), dtype=float32, numpy=array([3., 4., 5.], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=2.0>)

字符串

kupeojn6

kupeojn63#

列数和机器学习模型性能之间存在一些权衡,最好尽可能多地保留特征列。我建议首先删除目标变量中具有NAN值的所有行,然后删除其他变量中具有剩余NAN值的列。这样,如果特征变量中缺少的条目与缺少的目标值一致,您可以保留更多的列,而不仅仅是删除所有缺少条目的内容。建议是一般性的,您可以选择保留一些缺少很少条目的功能列,而不是删除具有这些条目的行。
这里我假设在你的Tensor中,你有特征和目标变量作为列,y_indx是你的目标的列索引。你也可以使用一个合适的布尔掩码。
以下函数从TensorX中删除目标列中具有nan个值的行。它为删除的行返回一个布尔掩码,但是如果不想保留它,可以从最后一行跳过它。

def drop_na_rows(X, y_indx):
    not_nan = tf.math.logical_not(tf.math.is_nan(X[:, y_indx]))
    return X[not_nan, :], no_nan

字符串
下面的函数删除包含缺失值的列。请注意,对于cols_to_drop参数,您可以使用适合切片Tensor的布尔掩码或索引,但如果您不提供任何内容,则函数将返回列的布尔掩码。或者你也可以跳过。

def drop_na_cols(X, cols_to_drop=None):
    if cols_to_drop is None:
        cols_to_drop = tf.where(tf.reduce_sum(X, axis=1))
    return X[:, cols_to_drop], cols_to_drop

相关问题