TensorFlow 2 error in a custom-written layer that works fine in TensorFlow 1 (assert t.graph == cond_graph; AssertionError)

Tags: tensorflow, graph, keras-layer, tensorflow2.0, tf.keras

I have written a custom layer (implemented in Keras). The layer works fine in TF versions 1.13 and 1.15, which I have tested myself. It also runs without errors in TF 2.0 when I call
tf.compat.v1.disable_v2_behavior()
right after importing the TensorFlow module. However, if I comment out the
tf.compat.v1.disable_v2_behavior()
line to keep TF2 behavior enabled, an AssertionError is raised.
Here is the custom-written layer:
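For clarity, the compatibility toggle I am referring to is just this one-line fragment placed directly after the import (when this line is active, training works; when it is commented out, the AssertionError below appears):

```python
import tensorflow as tf

# Restores TF1 graph-mode semantics (disables eager execution, cond_v2, etc.).
# With this line active the custom layer trains without errors.
tf.compat.v1.disable_v2_behavior()
```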

from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer

import tensorflow as tf
import numpy as np
import math


class Conv1D_Sinc(tf.keras.layers.Layer):
    @staticmethod
    def to_mel(hz):
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        return 700 * (10 ** (mel / 2595) - 1)

    def __init__(self,
                 filters,
                 kernel_size,
                 strides=1,
                 padding='same',
                 data_format='channels_last',
                 sample_rate=16000,
                 min_low_hz=50,
                 min_high_hz=50,
                 low_hz=30,
                 **kwargs):

        super(Conv1D_Sinc, self).__init__(**kwargs)

        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.data_format = data_format
        self.sample_rate = sample_rate
        self.min_low_hz = min_low_hz
        self.min_high_hz = min_high_hz
        self.low_hz = low_hz

        self.weights_init_for_build = True
        self.high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_high_hz)

        if self.padding == 'causal':
            if self.data_format != 'channels_last':
                raise ValueError('When using causal padding in `Conv1D`, '
                                 '`data_format` must be "channels_last" '
                                 '(temporal data).')
        try:
            kernel_size_number = self.kernel_size[0]
            if kernel_size_number % 2 == 0:
                raise ValueError('''Speed Issue:
                Length of filters ought to be odd (i.e, symmetric).''')
        except:
            if self.kernel_size % 2 == 0:
                raise ValueError('''Speed Issue:
                Length of filters ought to be odd (i.e, symmetric).''')


        if self.data_format == 'channels_first':
            self.channel_axis = 1
        else:
            self.channel_axis = -1


        # mel : (81,) np array
        mel = np.linspace(self.to_mel(self.low_hz), self.to_mel(self.high_hz),
                          self.filters + 1)
        # hz : (81,) np array
        hz = self.to_hz(mel)
        # Hamming window
        n_lin = np.linspace(0, self.kernel_size / 2 - 1, num=int(self.kernel_size / 2))  # computing
        # only half of the window
        self.window_ = np.array((0.54 - 0.46 * np.cos(2 * math.pi * n_lin / self.kernel_size)))
        # (kernel_size, 1)
        n = (self.kernel_size - 1) / 2.0
        self.n_ = np.array(2 * math.pi *
                           np.reshape(np.arange(-n, 0), (1, -1)) /
                           self.sample_rate)  # Due to symmetry, I only need half of the time axes

        # filter lower frequency band (filters, 1) - TRAINABLE VECTORS
        # hz[:,-1] : (80,)
        self.low_vec = np.array(np.reshape(hz[:-1], (-1, 1)))  # low_hz_ (80, 1)

        # filter frequency band (filters, 1) - TRAINABLE VECTORS
        # diff(hz) : (80,)
        self.high_vec = np.array(np.reshape(np.diff(hz), (-1, 1)))  # band_hz_ (80, 1)

    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        if input_shape[self.channel_axis] is None:
            raise ValueError('The channel dimension of the inputs '+\
                             'should be defined. Found `None`.')
        input_dim = input_shape[self.channel_axis]

        # kernel_size = 251
        # kernel_shape = (251,) + (1, 80)
        # kernel_shape = (251, 1, 80)
        self.kernel_shape = (self.kernel_size,) + (input_dim, self.filters)

        self.kernel_low = self.add_weight(shape=(self.filters, 1),
                                      initializer='zeros', # should be zero.
                                          # I will change it manually in call().
                                      name='kernel_low',
                                      regularizer=None,
                                      constraint=None)

        self.kernel_high = self.add_weight(shape=(self.filters, 1),
                                      initializer='zeros', # should be zero.
                                           # I will change it manually in call().
                                      name='kernel_high',
                                      regularizer=None,
                                      constraint=None)


    def call(self, inputs):
        # input of the call function is : (None, 3200, 1) === (?, 3200, 1)

        if self.weights_init_for_build:
            self.weights_init_for_build = False
            self.kernel_low = self.kernel_low + self.low_vec
            self.kernel_high = self.kernel_high + self.high_vec


        # low = self.min_low_hz + tf.math.abs(self.kernel_low)
        low = self.min_low_hz + K.abs(self.kernel_low)
        high = tf.clip_by_value(low + self.min_high_hz +
                                K.abs(self.kernel_high),
                                self.min_low_hz, self.sample_rate / 2)
        band = (high - low)[:, 0]

        f_times_t_low = K.dot(low, K.variable(self.n_))
        f_times_t_high = K.dot(high, K.variable(self.n_))

        band_pass_left = ((K.sin(f_times_t_high) - K.sin(f_times_t_low)) / (
        self.n_ / 2)) * self.window_  # Equivalent of Eq.4 of the reference paper
        # (SPEAKER RECOGNITION FROM RAW WAVEFORM WITH SINCNET). I just have expanded
        # the sinc and simplified the terms. This way I avoid several useless computations.

        band_pass_center = 2 * K.reshape(band, (-1, 1))
        # band_pass_right = tf.reverse(band_pass_left, axis=[1])
        band_pass_right = K.reverse(band_pass_left, axes=1)
        band_pass = K.concatenate([band_pass_left, band_pass_center, band_pass_right], axis=1)
        band_pass = band_pass / (2 * band[:, None])

        self.kernel_all = K.reshape(band_pass, self.kernel_shape)


        # if self.rank == 1:
        outputs = K.conv1d(
            inputs,
            self.kernel_all,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format)
        return outputs



    def get_config(self):
        config = super(tf.keras.layers.Conv1D, self).get_config()
        if 'rank' in config.keys():
            config.pop('rank')
        return config
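For context, the filter-bank initialization in __init__ can be checked in isolation with plain NumPy. This is a standalone sketch of the same math (not part of my model; filters=80 and kernel_size=251 are example values matching the shape comments in the code above), and it produces the expected shapes, so I believe the problem is not in the constructor:

```python
import numpy as np

def to_mel(hz):
    return 2595 * np.log10(1 + hz / 700)

def to_hz(mel):
    return 700 * (10 ** (mel / 2595) - 1)

filters, kernel_size = 80, 251
sample_rate, low_hz = 16000, 30
min_low_hz = min_high_hz = 50
high_hz = sample_rate / 2 - (min_low_hz + min_high_hz)  # 7900.0

# 81 band edges, equally spaced on the mel scale, converted back to Hz
hz = to_hz(np.linspace(to_mel(low_hz), to_mel(high_hz), filters + 1))

low_vec = hz[:-1].reshape(-1, 1)       # (80, 1) lower band edges
high_vec = np.diff(hz).reshape(-1, 1)  # (80, 1) bandwidths

# half Hamming window and half time axis, as in __init__
n_lin = np.linspace(0, kernel_size / 2 - 1, num=int(kernel_size / 2))
window_ = 0.54 - 0.46 * np.cos(2 * np.pi * n_lin / kernel_size)
n = (kernel_size - 1) / 2.0
n_ = 2 * np.pi * np.arange(-n, 0).reshape(1, -1) / sample_rate

print(low_vec.shape, high_vec.shape, window_.shape, n_.shape)
```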

This is the error:

Traceback (most recent call last):
  File "/media/deep/6BED7674319D2F8E/m-a-dastgheib/projects-code/keras/lightNet_v2/main_script_train.py", line 56, in <module>
    history_of_batch = ml_model.train_on_batch(np.expand_dims(input_batch, -1), label_of_batch)
  File "/home/deep/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 973, in train_on_batch
    class_weight=class_weight, reset_metrics=reset_metrics)
  File "/home/deep/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py", line 264, in train_on_batch
    output_loss_metrics=model._output_loss_metrics)
  File "/home/deep/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_eager.py", line 311, in train_on_batch
    output_loss_metrics=output_loss_metrics))
  File "/home/deep/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_eager.py", line 268, in _process_single_batch
    grads = tape.gradient(scaled_total_loss, trainable_weights)
  File "/home/deep/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow_core/python/eager/backprop.py", line 1014, in gradient
    unconnected_gradients=unconnected_gradients)
  File "/home/deep/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow_core/python/eager/imperative_grad.py", line 76, in imperative_grad
    compat.as_str(unconnected_gradients.value))
  File "/home/deep/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow_core/python/eager/backprop.py", line 138, in _gradient_function
    return grad_fn(mock_op, *out_grads)
  File "/home/deep/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow_core/python/ops/cond_v2.py", line 166, in _IfGrad
    true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
  File "/home/deep/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow_core/python/ops/cond_v2.py", line 424, in _resolve_grad_inputs
    assert t.graph == cond_graph
AssertionError
Thanks in advance for any help.