Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/tensorflow/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Tensorflow 如何实现VAE的高斯混合?_Tensorflow_Distribution_Gaussian_Tensorflow Probability_Mixture - Fatal编程技术网

Tensorflow 如何实现VAE的高斯混合?

Tensorflow 如何实现VAE的高斯混合?,tensorflow,distribution,gaussian,tensorflow-probability,mixture,Tensorflow,Distribution,Gaussian,Tensorflow Probability,Mixture,我觉得我真的不知道我在做什么,所以我会描述我认为我在做什么,我想做什么,以及失败的地方 给定一个普通的可变自动编码器: ... net = tf.layers.dense(net, units=code_size * 2, activation=None) mean = net[:, :code_size] std = net[:, code_size:] posterior = tfd.MultivariateNormalDiagWithSoftplusScale(mean, std) net

我觉得我真的不知道我在做什么,所以我会描述我认为我在做什么,我想做什么,以及失败的地方

给定一个普通的可变自动编码器:

...
net = tf.layers.dense(net, units=code_size * 2, activation=None)
mean = net[:, :code_size]
std = net[:, code_size:]
posterior = tfd.MultivariateNormalDiagWithSoftplusScale(mean, std)
net = posterior.sample()
net = tf.layers.dense(net, units=input_size, ...)
...
我想我正在做的是:让神经网络找到一个“均值”和“标准”值,并用它来创建一个正态分布(高斯分布)。 从该分布中采样并将其用于解码器。 换句话说:学习高斯分布的编码

现在我想对混合高斯函数做同样的处理

...
net = tf.layers.dense(net, units=code_size * 2 * code_size, activation=None)

means, stds = tf.split(net, 2, axis=-1)

means = tf.split(means, code_size, axis=-1)
stds = tf.split(stds, code_size, axis=-1)

components = [tfd.MultivariateNormalDiagWithSoftplusScale(means[i], stds[i]) for i in range(code_size)]
probs = [1.0 / code_size] * code_size

gauss_mix = tfd.Mixture(cat=tfd.Categorical(probs=probs), components=components)
net = gauss_mix.sample()
net = tf.layers.dense(net, units=input_size, ...)
...
这对我来说似乎比较直截了当,但它失败了,出现了以下错误:

形状()和(?)不兼容

这似乎来自没有批处理维度的
probs
(我没想到它会需要这个维度)

我认为
probs
定义了组件之间的概率

如果我定义了一个同样具有批处理维度的
probs
,我会得到以下神秘错误,我不知道它应该是什么意思:

维度-1796453376必须大于等于0

我通常会误解一些概念吗?

或者我需要做什么不同的事情?