pytorch 如果我有很多块,如何正确定义tf.variable

e7arh2l6  于 2022-12-18  发布在  其他
关注(0)|答案(1)|浏览(106)

我刚开始从pytorch到tensorflow的转换,在设计残差块时遇到了一些问题,我有一个残差组,包含了很多残差块,每个ack块包含了两个自定义层,我很困扰如何定义每层call()函数中需要作为操作一部分的变量。

我尝试使用self.W = tf.Vaiable()来定义变量,但是这样,当我初始化剩余单元组时,self.W将不断被覆盖,当我尝试使用self.W在每一层的调用函数中提取该参数时,我得到None。
在pytorch中,我可以简单地使用register_parameters定义init中的变量,并使用self.W在forward函数中提取它。
任何熟悉tensorflow 的人都可以帮助我吗?谢谢。

b4lqfgs4

b4lqfgs41#

您可以使用以下代码定义变量

class M(tf.Module):
  def __call__(self, x):
    self.v = tf.Variable(x)
    return self.v

谢谢你。

相关问题