保存keras预处理图层

xyhw6mcr  于 2022-11-13  发布在  其他
关注(0)|答案(1)|浏览(179)

我有一个模型,我在模型本身之外做不同的预处理。
预处理的一部分是使用基于keras的类别编码器,其中:

cat_index = tf.keras.layers.StringLookup(vocabulary=cat_word_list)
cat_encoder = tf.keras.layers.CategoryEncoding(num_tokens=cat_index.vocabulary_size(), output_mode="one_hot")

我应用这个比用

encoded_cat = cat_encoder(cat_index(data['cat_val'])).numpy()
encoded_cat = pd.DataFrame(encoded_cat, columns=['cat_' + str(i) for i in range(len(encoded_cat[0]))]).astype('int64')

data = pd.merge(data, encoded_cat, left_index=True, right_index=True)
data.drop(columns=['cat_val'], inplace=True)

我的Pandas数据框。
现在我想存储我的模型,为了存储模型,我还必须存储2个预处理层cat_indexcat_encoder。不幸的是,我无法弄清楚如何将这些层存储在文件系统中。如果我尝试使用保存功能,我会得到
'CategoryEncoding'对象没有'保存'属性
如何将这样的预处理层存储到文件系统中,以便在推理过程中可以重用它?
我想到的一个解决方法是存储cat_word_list并重新创建层,但我希望有一种更基于keras的方法。

gdrx4gfi

gdrx4gfi1#

使用get_config layers方法获取配置:
返回层的配置。
一个层配置是一个Python字典(可序列化),包含一个层的配置。相同的层可以在以后从这个配置中重新示例化(没有它的训练权重)。
例如:

cat_word_list = ['cat', 'tiger', 'lion', 'dog']
cat_index = tf.keras.layers.StringLookup(vocabulary=cat_word_list)
cat_encoder = tf.keras.layers.CategoryEncoding(num_tokens=cat_index.vocabulary_size(), output_mode="one_hot")

cat_index_config = cat_index.get_config()
cat_encoder_config = cat_encoder.get_config()

其中应包含重新创建层所需的所有信息:

cat_index_config

输出量:

{'name': 'string_lookup',
 'trainable': True,
 'dtype': 'int64',
 'invert': False,
 'max_tokens': None,
 'num_oov_indices': 1,
 'oov_token': '[UNK]',
 'mask_token': None,
 'output_mode': 'int',
 'sparse': False,
 'pad_to_max_tokens': False,
 'vocabulary': ListWrapper(['cat', 'tiger', 'lion', 'dog']),
 'idf_weights': None,
 'encoding': 'utf-8'}

可以按如下方式重新创建层:

cat_index_2 = tf.keras.layers.StringLookup(**cat_index_config)
cat_encoder_2 = tf.keras.layers.CategoryEncoding(**cat_encoder_config)

各层具有相同的配置,例如

cat_index_2.get_config()

输出量:

{'name': 'string_lookup',
 'trainable': True,
 'dtype': 'int64',
 'invert': False,
 'max_tokens': None,
 'num_oov_indices': 1,
 'oov_token': '[UNK]',
 'mask_token': None,
 'output_mode': 'int',
 'sparse': False,
 'pad_to_max_tokens': False,
 'vocabulary': ListWrapper(['cat', 'tiger', 'lion', 'dog']),
 'idf_weights': None,
 'encoding': 'utf-8'}

相关问题