Python Tensorflow-将Tensor扩展到3D的更好方法

Python Tensorflow-将Tensor扩展到3D的更好方法,python,tensorflow2.0,Python,Tensorflow2.0,我有0-3维的输入张量,并且总是希望输出到3D张量(用于tf.einsum函数,其中我不能使用广播),轴由内而外填充。有没有比下面的(丑陋的)条件更好的方法?我通读了tf.expand_dims、tf.reformate和tf.broadcast_to,但找不到任何允许基于不同维度输入张量的动态形状的内容 import tensorflow as tf def broadcast_cash_flows(x): shape = tf.shape(x) dimensions =

我有0-3维的输入张量,并且总是希望输出到3D张量(用于
tf.einsum
函数,其中我不能使用广播),轴由内而外填充。有没有比下面的(丑陋的)条件更好的方法?我通读了
tf.expand_dims
tf.reformate
tf.broadcast_to
,但找不到任何允许基于不同维度输入张量的动态形状的内容

import tensorflow as tf


def broadcast_cash_flows(x):
    shape = tf.shape(x)
    dimensions = len(shape)
    return tf.cond(dimensions == 0,
                   lambda: cf_0d(x),
                   lambda: tf.cond(dimensions == 1,
                                   lambda: cf_1d(x),
                                   lambda: tf.cond(dimensions == 2,
                                                   lambda: cf_2d(x),
                                                   lambda: x)))

def cf_0d(x):
    return tf.expand_dims(tf.expand_dims(tf.expand_dims(x,0),0),0)

def cf_1d(x):
    return tf.expand_dims(tf.expand_dims(x,0),0)

def cf_2d(x):
    return tf.expand_dims(x,0)


cf0 = tf.constant(2.0)
print(broadcast_cash_flows(cf0))

cf1 = tf.constant([2.0, 1.0, 3.0])
print(broadcast_cash_flows(cf1))

cf2 = tf.constant([[2.0, 1.0, 3.0],
                   [3.0, 2.0, 4.0]])
print(broadcast_cash_flows(cf2))

cf3 = tf.constant([[[2.0, 1.0, 3.0],
                    [3.0, 2.0, 4.0]],
                    [[2.0, 1.0, 3.0],
                    [3.0, 2.0, 4.0]]])
print(broadcast_cash_flows(cf3))

tf.expand_dims
在添加一个维度时非常方便

tf.newaxis
在一次操作中添加多个维度(而不是多次调用tf.expand_dims)时非常方便

修改的代码-

import tensorflow as tf

def broadcast_cash_flows(x):
    shape = tf.shape(x)
    dimensions = len(shape)
    if(dimensions == 0):
      return x[tf.newaxis,tf.newaxis,tf.newaxis]
    elif(dimensions == 1):
      return x[tf.newaxis,tf.newaxis,:]
    elif(dimensions == 2):
      return x[tf.newaxis,:,:]
    else:
      return x

cf0 = tf.constant(2.0)
print(broadcast_cash_flows(cf0))

cf1 = tf.constant([2.0, 1.0, 3.0])
print(broadcast_cash_flows(cf1))

cf2 = tf.constant([[2.0, 1.0, 3.0],
                   [3.0, 2.0, 4.0]])
print(broadcast_cash_flows(cf2))

cf3 = tf.constant([[[2.0, 1.0, 3.0],
                    [3.0, 2.0, 4.0]],
                    [[2.0, 1.0, 3.0],
                    [3.0, 2.0, 4.0]]])
print(cf3.shape)
print(broadcast_cash_flows(cf3))
输出-

tf.Tensor([[[2.]]], shape=(1, 1, 1), dtype=float32)
tf.Tensor([[[2. 1. 3.]]], shape=(1, 1, 3), dtype=float32)
tf.Tensor(
[[[2. 1. 3.]
  [3. 2. 4.]]], shape=(1, 2, 3), dtype=float32)
(2, 2, 3)
tf.Tensor(
[[[2. 1. 3.]
  [3. 2. 4.]]

 [[2. 1. 3.]
  [3. 2. 4.]]], shape=(2, 2, 3), dtype=float32)

tf.expand_dims
在添加一个维度时非常方便

tf.newaxis
在一次操作中添加多个维度(而不是多次调用tf.expand_dims)时非常方便

修改的代码-

import tensorflow as tf

def broadcast_cash_flows(x):
    shape = tf.shape(x)
    dimensions = len(shape)
    if(dimensions == 0):
      return x[tf.newaxis,tf.newaxis,tf.newaxis]
    elif(dimensions == 1):
      return x[tf.newaxis,tf.newaxis,:]
    elif(dimensions == 2):
      return x[tf.newaxis,:,:]
    else:
      return x

cf0 = tf.constant(2.0)
print(broadcast_cash_flows(cf0))

cf1 = tf.constant([2.0, 1.0, 3.0])
print(broadcast_cash_flows(cf1))

cf2 = tf.constant([[2.0, 1.0, 3.0],
                   [3.0, 2.0, 4.0]])
print(broadcast_cash_flows(cf2))

cf3 = tf.constant([[[2.0, 1.0, 3.0],
                    [3.0, 2.0, 4.0]],
                    [[2.0, 1.0, 3.0],
                    [3.0, 2.0, 4.0]]])
print(cf3.shape)
print(broadcast_cash_flows(cf3))
输出-

tf.Tensor([[[2.]]], shape=(1, 1, 1), dtype=float32)
tf.Tensor([[[2. 1. 3.]]], shape=(1, 1, 3), dtype=float32)
tf.Tensor(
[[[2. 1. 3.]
  [3. 2. 4.]]], shape=(1, 2, 3), dtype=float32)
(2, 2, 3)
tf.Tensor(
[[[2. 1. 3.]
  [3. 2. 4.]]

 [[2. 1. 3.]
  [3. 2. 4.]]], shape=(2, 2, 3), dtype=float32)