python-3.x 使用__init__参数自动填充类属性

wz3gfoph  于 2023-01-18  发布在  Python
关注(0)|答案(2)|浏览(159)

我有一个类,它可以用很多参数初始化,并且它可以作为一个添加方法不断增长。有没有一种方法可以自动地将int方法中的所有位置参数添加到对象的属性中?例如;

class trainer:
    
    def __int__(self, model="unet", encoder_name="resnet18", encoder_weights="imagenet",
              in_channels=3, num_classes=1, loss="jaccard",
              ignore_index=0, learning_rate=1e4, learning_rate_schedule_patience=10,
              ignore_zeros=True):

        # authomatically add the initial properties
        self.model = model
        self.encoder_name = encoder_name
        self.encoder_weights = encoder_weights
        self.in_channels = in_channels
        self.num_classes = num_classes
        .
        .
        .
        self.ignore_zeros = ignore_zeros
yfwxisqw

yfwxisqw1#

这是与您的__init__对应的dataclass
基本上,您只需声明一个带有注解属性的类,可能还有一个默认值,@dataclass装饰器将为您生成样板代码,如__init____repr__,我建议您阅读更多的documentation

**PS.**类名通常是PascalCase(或CapWords),所以我为您做了更改。

from dataclasses import dataclass

@dataclass
class Trainer:
    model: str = "unet"
    encoder_name: str = "resnet18"
    encoder_weights: str = "imagenet"
    in_channels: int = 3
    num_classes: int = 1
    loss: str = "jaccard"
    ignore_index: int = 0
    learning_rate: float = 1e4
    learning_rate_schedule_patience: int = 10
    ignore_zeros: bool = True
lsmepo6l

lsmepo6l2#

另一种方法是循环传递给__init__的参数,然后使用setattrself保存变量。
这里的优点是它默认为您的默认kwargs,但更新为任何输入kwargs(在本例中为num_classes)-

class trainer():
    
    def __init__(self, model="unet", encoder_name="resnet18", encoder_weights="imagenet",
              in_channels=3, num_classes=1, loss="jaccard",
              ignore_index=0, learning_rate=1e4, learning_rate_schedule_patience=10,
              ignore_zeros=True):
        
        # Loop through params and setattr v to self.k
        for k,v in locals().items():
            if k!='self':
                setattr(self, k, v)
                
#checking the self dictionary                
trainer(num_classes=10).__dict__
{'model': 'unet',
 'encoder_name': 'resnet18',
 'encoder_weights': 'imagenet',
 'in_channels': 3,
 'num_classes': 10,               #<--------------
 'loss': 'jaccard',
 'ignore_index': 0,
 'learning_rate': 10000.0,
 'learning_rate_schedule_patience': 10,
 'ignore_zeros': True}

相关问题