python-3.x 如何在枚举Flag子类中遍历 *distinct* 标志?

r8uurelv  于 2023-10-21  发布在  Python
关注(0)|答案(3)|浏览(87)

我有一个enum来表示数据集的不同子集,* 以及这些子集的组合 *:

from enum import Flag, auto

class DataSubset(Flag):
    TRAIN = auto()
    TEST = auto()
    VALIDATION = auto()
    EXCLUDED = auto()

    TRAIN_TEST = TRAIN | TEST
    ALL_INCLUDED = TRAIN_TEST | VALIDATION
    ALL = ALL_INCLUDED | EXCLUDED

有没有一种方法可以只遍历不同的标志,而不是命名的组合?即:

[DataSubset.TRAIN, DataSubset.TEST, DataSubset.VALIDATION, DataSubset.EXCLUDED]

我们的目标是能够做到这样的事情:

def get_subsets(subset):
    return [sub for sub in DataSubset.distinct_flags if sub in subset]

然后:

>>> get_subsets(DataSubset.TRAIN)
[DataSubset.TRAIN]
>>> get_subsets(DataSubset.TRAIN_TEST)
[DataSubset.TRAIN, DataSubset.TEST]
>>> get_subsets(DataSubset.ALL)
[DataSubset.TRAIN, DataSubset.TEST, DataSubset.VALIDATION, DataSubset.EXCLUDED]
jdg4fx2g

jdg4fx2g1#

有点傻的解决方案,但你可以使用位旋转黑客测试整数是2的幂,只找到一个位标志。如果你的标志是现有标志的别名,而不是它们的组合,这将包括它们,但它会过滤掉任何没有精确设置一位的标志:

def distinct_flags(enm):
    return [x for x in enm if (x.value & (x.value - 1)) == 0]

当使用它时,会得到以下结果(因为我在IPython中运行它,所以稍微漂亮一些):

>>> distinct_flags(DataSubset)
[<DataSubset.TRAIN: 1>,
 <DataSubset.TEST: 2>,
 <DataSubset.VALIDATION: 4>,
 <DataSubset.EXCLUDED: 8>]

您只需围绕该功能构建get_subsets函数,或者将这两种功能(过滤到单个标志和包含在提供的子集中的标志)合并到现有代码中的if条件中。

ffx8fchx

ffx8fchx2#

Python 3.11可以将Flag值的成员从框中删除,并小心地排除别名。因此,它是没有必要找到一个不同的定义迭代的解决方案。

>>> list(DataSubset.TRAIN)
[<DataSubset.TRAIN: 1>]
>>> list(DataSubset.TRAIN_TEST)
[<DataSubset.TRAIN: 1>, <DataSubset.TEST: 2>]
>>> list(DataSubset.ALL)
[<DataSubset.TRAIN: 1>, <DataSubset.TEST: 2>, <DataSubset.VALIDATION: 4>, <DataSubset.EXCLUDED: 8>]
ibps3vxo

ibps3vxo3#

我可以通过创建一个新的元类,合并this answer on getting powers of 2this answer on changing iterating behavior来实现这一点

from enum import EnumMeta, Flag, auto

class DistinctFlag(EnumMeta):
    def __iter__(cls):
        for x in super().__iter__():
            if (x.value & (x.value-1))==0 and x.value != 0:
                yield x

                
class DataSubset(Flag, metaclass=DistinctFlag):
    """Enum to describe distinct subsets of a modeling dataset"""
    BURN_IN = auto()
    TRAIN = auto()
    TEST = auto()
    HOLDOUT = auto()
    EXCLUDED = auto()
    
    TRAIN_TEST = TRAIN | TEST
    OBS = BURN_IN | TRAIN_TEST
    ALL_INCLUDED = OBS | HOLDOUT
    ALL = ALL_INCLUDED | EXCLUDED

那么:

>>> print(list(DataSubset)
[<DataSubset.BURN_IN: 1>,
 <DataSubset.TRAIN: 2>,
 <DataSubset.TEST: 4>,
 <DataSubset.HOLDOUT: 8>,
 <DataSubset.EXCLUDED: 16>]

相关问题