Is there a function in TensorFlow similar to PyTorch's load_state_dict()?


As described above, I would like to know whether TensorFlow has a function similar to PyTorch's load_state_dict(). To illustrate the scenario, consider the following code:

# Suppose we have two correctly initialized neural networks: net2 and net1
# Using Pytorch
net2.load_state_dict(net1.state_dict())

Does anyone know?

The following code may help you achieve the same thing in TensorFlow:

Save the model

import tensorflow as tf  # TensorFlow 1.x graph-mode API

w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# The `save` method calls `export_meta_graph` implicitly,
# so you also get the saved graph file: my-model.meta
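At its core, the checkpointing step above persists a name-to-value mapping for the collected variables. A framework-free sketch of that round trip, using pickle in place of tf.train.Saver (the file name and values here are illustrative, not part of the original answer):

```python
import os
import pickle
import tempfile

# The variable collection is conceptually a name -> value mapping,
# saved to disk and restored later in a fresh session.
vars_to_save = {"w1": [0.1] * 10, "w2": [0.2] * 20}

path = os.path.join(tempfile.mkdtemp(), "my-model.pkl")
with open(path, "wb") as f:
    pickle.dump(vars_to_save, f)   # analogous to saver.save(sess, 'my-model')

with open(path, "rb") as f:
    restored = pickle.load(f)      # analogous to new_saver.restore(...)

for name, vals in restored.items():
    print(name, len(vals))
```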


Restore the model

sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
    v_ = sess.run(v)
    print(v_)
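For completeness: in TensorFlow 2 with Keras models, the closest one-liner to the PyTorch snippet in the question is net2.set_weights(net1.get_weights()) (model names illustrative). The name-keyed copy semantics behind state_dict()/load_state_dict() can also be sketched without any framework; the class and parameter names below are made up for illustration:

```python
# Minimal sketch of PyTorch-style state_dict()/load_state_dict() semantics:
# parameters are stored by name, and loading copies values for matching names.

class TinyNet:
    def __init__(self, w1, w2):
        # Parameters keyed by name, mirroring tf.Variable(name=...) above.
        self.params = {"w1": list(w1), "w2": list(w2)}

    def state_dict(self):
        # Return copies so the caller cannot mutate our parameters.
        return {name: list(vals) for name, vals in self.params.items()}

    def load_state_dict(self, state):
        # Copy values for matching names; reject unknown names,
        # similar to PyTorch's strict mode.
        for name, vals in state.items():
            if name not in self.params:
                raise KeyError(f"unexpected key: {name}")
            self.params[name] = list(vals)

net1 = TinyNet([1.0, 2.0], [3.0])
net2 = TinyNet([0.0, 0.0], [0.0])
net2.load_state_dict(net1.state_dict())
print(net2.params["w1"])  # [1.0, 2.0]
```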

That is how I solved my problem (a few months ago)! Thanks anyway for your answer!