Keras 神经网络分类特征嵌入转移学习:输入与权重的关联

Keras 神经网络分类特征嵌入转移学习:输入与权重的关联,keras,neural-network,keras-layer,categorical-data,Keras,Neural Network,Keras Layer,Categorical Data,我在Python3中使用带有Tensorflow后端的KerasAPI 我正在创建一个多类分类多层感知器模型,其真正目的是创建分类特征嵌入向量作为模型的一部分,以便它们可以应用于其他机器学习问题。下面是model.summary(),我要提取的权重位于第2层(dx\u cat\u embedding): 我使用了dx\u cat\u weights=model.layers[2]。get\u weights()[0]来检索dx类别的权重,并运行len(dx\u cat\u weights)我可以

我在Python3中使用带有Tensorflow后端的KerasAPI

我正在创建一个多类分类多层感知器模型,其真正目的是创建分类特征嵌入向量作为模型的一部分,以便它们可以应用于其他机器学习问题。下面是
model.summary()
,我要提取的权重位于第2层(
dx\u cat\u embedding
):

我使用了
dx\u cat\u weights=model.layers[2]。get\u weights()[0]
来检索dx类别的权重,并运行
len(dx\u cat\u weights)
我可以验证输出的大小是否与唯一输入的大小相同,而
len(dx\u cat\u weights[0])
验证每个实例是否具有与其关联的25个权重的正确向量

在此模型中,
dx_cat
输入是类别值的标签编码表示。我是否可以假定
model.layers[2].get\u weights()[0]
根据标签编码的类别值对权重输出进行排序?e、 例如,
model.layers[2].get_weights()[0][0]
是否对应编码为
0的类别值标签和
model.layers[2].get_weights()[0][1693]
是否对应编码为
1693的类别值标签


我最终试图创建一个由
dx\u cat
值及其权重组成的人行横道,这样团队中的其他数据科学家就可以为他们自己的项目映射这些预先训练过的
dx\u cat
权重(这些权重并不总是神经网络)。

您可以创建一个子模型,给定一个int返回相应的嵌入@bshelt141这是你要找的吗?
Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
dx_cat (InputLayer)             [(None, 1)]          0                                            
__________________________________________________________________________________________________
memid (InputLayer)              [(None, 1)]          0                                            
__________________________________________________________________________________________________
dx_cat_embedding (Embedding)    (None, 1, 25)        42325       dx_cat[0][0]                     
__________________________________________________________________________________________________
memid_embedding (Embedding)     (None, 1, 50)        67943200    memid[0][0]                      
__________________________________________________________________________________________________
cont_variables (InputLayer)     [(None, 22)]         0                                            
__________________________________________________________________________________________________
flatten (Flatten)               (None, 25)           0           dx_cat_embedding[0][0]           
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 50)           0           memid_embedding[0][0]            
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 97)           0           cont_variables[0][0]             
                                                                 flatten[0][0]                    
                                                                 flatten_1[0][0]                  
__________________________________________________________________________________________________
dense (Dense)                   (None, 32)           3136        concatenate[0][0]                
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 16)           528         dense[0][0]                      
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 4)            68          dense_1[0][0]                    
__________________________________________________________________________________________________
prediction (Dense)              (None, 1499)         7495        dense_2[0][0]                    
==================================================================================================
Total params: 67,996,752
Trainable params: 67,996,752
Non-trainable params: 0
__________________________________________________________________________________________________