Python 如何找出给定的内置tensorflow函数接受哪些数据类型的张量?

Python 如何找出给定的内置tensorflow函数接受哪些数据类型的张量?,python,tensorflow,Python,Tensorflow,我正在处理一个使用复数的tensorflow项目,所以我经常需要在复杂输入上应用内置函数。那么,我如何检查哪些tensorflow函数接受一个复杂的参数作为输入呢 比如说,, 当我尝试使用函数tf.math.scalar_mul()时,如下所示- ... self.scalar = tf.Variable(3, tf.int16) output = tf.math.scalar_mul(x, self.scalar) ... 它会产生以下错误- ValueError: Tensor conve

我正在处理一个使用复数的tensorflow项目,所以我经常需要在复杂输入上应用内置函数。那么,我如何检查哪些tensorflow函数接受一个复杂的参数作为输入呢

比如说,, 当我尝试使用函数tf.math.scalar_mul()时,如下所示-

...
self.scalar = tf.Variable(3, tf.int16)
output = tf.math.scalar_mul(x, self.scalar)
...
它会产生以下错误-

ValueError: Tensor conversion requested dtype int32 for Tensor with dtype complex64: 'Tensor("fourier__conv2d_5/mul:0", shape=(?, 28, 28, 17), dtype=complex64)'
我觉得这可能是由于tf.math.scalar_mul()不接受复杂的输入。我说的对吗?如果不对,可能是什么错误。(我尝试使用tf函数而不是基本的python函数,因为我认为在GPU上运行时它可能会带来好处)


提前感谢您的帮助。

您可以找到这一点,但结果将以操作和内核的形式给出,它们不能精确地映射到更高级别的Python函数。如果您不熟悉TensorFlow的体系结构,它是围绕“ops”的概念构建的,这只是对带有张量的操作的正式描述(例如,op“Add”接受两个值并输出第三个值)。张量流计算图由互连的op节点组成。Ops本身并不实现任何逻辑,它们只指定操作的名称和属性,包括可以应用该操作的数据类型。ops的实现由内核给出,内核是完成这项工作的实际代码片段。单个op可以有许多注册内核,这些内核使用不同的数据类型和/或不同的设备(CPU、GPU)进行操作

TensorFlow将所有这些信息作为不同的消息存储在“注册表”中。尽管它不是公共API的一部分,但实际上您可以查询这些注册表以获得满足某些条件的操作或内核列表。例如,您可以通过以下方式获得使用某些复杂类型的操作的所有操作:

import tensorflow as tf

def get_ops_with_dtypes(dtypes):
    from tensorflow.python.framework import ops
    valid_ops = []
    dtype_enums = set(dtype.as_datatype_enum for dtype in dtypes)
    reg_ops = ops.op_def_registry.get_registered_ops()
    for op in reg_ops.values():
        for attr in op.attr:
            if (attr.type == 'type' and
                any(t in dtype_enums for t in attr.allowed_values.list.type)):
                valid_ops.append(op)
                break
    # Sort by name for convenience
    return sorted(valid_ops, key=lambda op: op.name)

complex_dtypes = [tf.complex64, tf.complex128]
complex_ops = get_ops_with_dtypes(complex_dtypes)

# Print one op
print(complex_ops[0])
# name: "AccumulateNV2"
# input_arg {
#   name: "inputs"
#   type_attr: "T"
#   number_attr: "N"
# }
# output_arg {
#   name: "sum"
#   type_attr: "T"
# }
# attr {
#   name: "N"
#   type: "int"
#   has_minimum: true
#   minimum: 1
# }
# attr {
#   name: "T"
#   type: "type"
#   allowed_values {
#     list {
#       type: DT_FLOAT
#       type: DT_DOUBLE
#       type: DT_INT32
#       type: DT_UINT8
#       type: DT_INT16
#       type: DT_INT8
#       type: DT_COMPLEX64
#       type: DT_INT64
#       type: DT_QINT8
#       type: DT_QUINT8
#       type: DT_QINT32
#       type: DT_BFLOAT16
#       type: DT_UINT16
#       type: DT_COMPLEX128
#       type: DT_HALF
#       type: DT_UINT32
#       type: DT_UINT64
#     }
#   }
# }
# attr {
#   name: "shape"
#   type: "shape"
# }
# is_aggregate: true
# is_commutative: true

# Print op names
print(*(op.name for op in complex_ops), sep='\n')
# AccumulateNV2
# AccumulatorApplyGradient
# AccumulatorTakeGradient
# Acos
# Acosh
# Add
# AddN
# AddV2
# Angle
# ApplyAdaMax
# ...
这里,
complex_ops
中的元素是您可以检查的消息,以找出op的确切结构。在这种情况下,
get_ops_with_dtypes
只返回其
type
属性中具有给定数据类型之一的每个op,因此复杂值可以应用于其中一个输入或输出

另一种选择是直接查找处理您感兴趣的数据类型的内核。内核被存储为消息,这些消息不包含关于op的所有信息,但它们包含关于可以在其上运行的设备的信息,因此您还可以查询支持特定设备的内核

import tensorflow as tf

def get_kernels_with_dtypes(dtypes, device_type=None):
    from tensorflow.python.framework import kernels
    valid_kernels = []
    dtype_enums = set(dtype.as_datatype_enum for dtype in dtypes)
    reg_kernels = kernels.get_all_registered_kernels()
    for kernel in reg_kernels.kernel:
        if device_type and kernel.device_type != device_type:
            continue
        for const in kernel.constraint:
            if any(t in dtype_enums for t in const.allowed_values.list.type):
                valid_kernels.append(kernel)
                break
    # Sort by name for convenience
    return sorted(valid_kernels, key=lambda kernel: kernel.op)

complex_dtypes = [tf.complex64, tf.complex128]
complex_gpu_kernels = get_kernels_with_dtypes(complex_dtypes, device_type='GPU')

# Print one kernel
print(complex_gpu_kernels[0])
# op: "Add"
# device_type: "GPU"
# constraint {
#   name: "T"
#   allowed_values {
#     list {
#       type: DT_COMPLEX64
#     }
#   }
# }

# Print kernel op names
print(*(kernel.op for kernel in complex_gpu_kernels), sep='\n')
# Add
# Add
# AddN
# AddN
# AddV2
# AddV2
# Assign
# Assign
# AssignVariableOp
# AssignVariableOp
# ...
问题是,在Python中使用TensorFlow编程时,您从未真正直接使用ops或内核。Python函数接受您提供的参数,验证它们,并在图中生成一个或多个新的操作,通常将最后一个操作的输出值返回给您。因此,最终找出与您相关的操作/内核需要进行一些检查。例如,考虑下面的例子:

import tensorflow as tf

with tf.Graph().as_default():
    # Matrix multiplication: (2, 3) x (3, 4)
    tf.matmul([[1, 2, 3], [4, 5, 6]], [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
    # Print all op names and types
    all_ops = tf.get_default_graph().get_operations()
    print(*(f'Op name: {op.name}, Op type: {op.type}' for op in all_ops), sep='\n')
    # Op name: MatMul/a, Op type: Const
    # Op name: MatMul/b, Op type: Const
    # Op name: MatMul, Op type: MatMul

with tf.Graph().as_default():
    # Matrix multiplication: (1, 2, 3) x (1, 3, 4)
    tf.matmul([[[1, 2, 3], [4, 5, 6]]], [[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]])
    # Print all op names and types
    all_ops = tf.get_default_graph().get_operations()
    print(*(f'Op name: {op.name}, Op type: {op.type}' for op in all_ops), sep='\n')
    # Op name: MatMul/a, Op type: Const
    # Op name: MatMul/b, Op type: Const
    # Op name: MatMul, Op type: BatchMatMul
在这里,相同的Python函数在每种情况下都产生了op类型。前两个操作在两种情况下都是
Const
,这是将给定列表转换为TensorFlow张量的结果,但第三个操作在一种情况下是
MatMul
,在另一种情况下是
BatchedMatMul
,因为在第二种情况下,输入有一个额外的初始维度

在任何情况下,如果您可以结合上述方法查找有关您感兴趣的一个op名称的所有op和内核信息:

def get_op_info(op_name):
    from tensorflow.python.framework import ops
    from tensorflow.python.framework import kernels
    reg_ops = ops.op_def_registry.get_registered_ops()
    op_def = reg_ops[op_name]
    op_kernels = list(kernels.get_registered_kernels_for_op(op_name).kernel)
    return op_def, op_kernels

# Get MatMul information
matmul_def, matmul_kernels = get_op_info('MatMul')

# Print op definition
print(matmul_def)
# name: "MatMul"
# input_arg {
#   name: "a"
#   type_attr: "T"
# }
# input_arg {
#   name: "b"
#   type_attr: "T"
# }
# output_arg {
#   name: "product"
#   type_attr: "T"
# }
# attr {
#   name: "transpose_a"
#   type: "bool"
#   default_value {
#     b: false
#   }
# }
# attr {
#   name: "transpose_b"
#   type: "bool"
#   default_value {
#     b: false
#   }
# }
# attr {
#   name: "T"
#   type: "type"
#   allowed_values {
#     list {
#       type: DT_BFLOAT16
#       type: DT_HALF
#       type: DT_FLOAT
#       type: DT_DOUBLE
#       type: DT_INT32
#       type: DT_COMPLEX64
#       type: DT_COMPLEX128
#     }
#   }
# }

# Total number of matrix multiplication kernels
print(len(matmul_kernels))
# 24

# Print one kernel definition
print(matmul_kernels[0])
# op: "MatMul"
# device_type: "CPU"
# constraint {
#   name: "T"
#   allowed_values {
#     list {
#       type: DT_FLOAT
#     }
#   }
# }

非常感谢您用示例对我进行了清晰的解释,对于我这个tensorflow的初学者来说,这真的很好,也很容易理解!