Tensorflow 如何预处理Mapdataset以适应模型输入?

Tensorflow 如何预处理Mapdataset以适应模型输入?,tensorflow,tfrecord,Tensorflow,Tfrecord,我使用由文本中的标签和字符串中的浮点向量组成的MapDataset。 以下是我阅读TFR记录内容的方式: def extract_data(tfrecord_ds): feature_description = { 'classes_text': tf.io.FixedLenFeature((), tf.string), 'data': tf.io.FixedLenFeature([], tf.string) } def _parse_data_

我使用由文本中的标签和字符串中的浮点向量组成的MapDataset。 以下是我阅读TFR记录内容的方式:

def extract_data(tfrecord_ds):
    feature_description = {
        'classes_text': tf.io.FixedLenFeature((), tf.string),
        'data': tf.io.FixedLenFeature([], tf.string)
    }

def _parse_data_function(example_proto):
    return tf.compat.v1.parse_single_example(example_proto, feature_description)
parsed_dataset = tfrecord_ds.map(_parse_data_function)

dataset = parsed_dataset.cache().shuffle(1000).batch(32).prefetch(tf.data.AUTOTUNE)
return dataset
我想根据label.txt文件将label\u文本转换为int,并将数据字符串转换为浮点向量

我想使用这些数据来训练自定义模型,如下所示:

my_model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(1024), dtype=tf.float32,
                              name='input_embedding'),
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dense(num_classes)
    ], name='audio_detector')
如何将MapDataset从(字符串、字符串)处理为(int、float_数组)以训练模型

编辑:

以下是我对数据进行编码的方式:

 features = {}
                                features['classes_text'] = tf.train.Feature(
                                    bytes_list=tf.train.BytesList(value=[audio_data_generator.label.encode()]))
                                bytes = embedding.numpy().tobytes()
                                features['data'] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[bytes]))
                                tf_example = tf.train.Example(features=tf.train.Features(feature=features))
                                writer.write(tf_example.SerializeToString())

使用
tf.train.FloatList
对嵌入进行编码更容易

写入TFR记录时,请使用:

features = {
  'classes_text': tf.train.Feature(bytes_list=tf.train.BytesList(value=[label.encode()])),
  'data': tf.train.Feature(float_list=tf.train.FloatList(value=embedding))
}
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
读取时,将嵌入大小指定给
tf.io.FixedLenFeature
,例如:

embedding_size = 10
feature_description = {
  'classes_text': tf.io.FixedLenFeature((), tf.string),
  'data': tf.io.FixedLenFeature([embedding_size], tf.float32)
}
要将标签文本转换为int,可以使用
tf.lookup.StaticVocabularyTable

# Assuming lable.txt contains a single label per line.
with open('label.txt', 'r') as fin:
  categories = [line.strip() for line in fin.readlines()]
init = tf.lookup.KeyValueTensorInitializer(
    keys=tf.constant(categories),
    values=tf.constant(list(range(len(categories))), dtype=tf.int64))
label_table = tf.lookup.StaticVocabularyTable(
   init,
   num_oov_buckets=1)

feature_description = {
  'classes_text': tf.io.FixedLenFeature((), tf.string),
  'data': tf.io.FixedLenFeature([embedding_size], tf.float32)
}

def _parse_data_function(example_proto):
  example = tf.compat.v1.parse_single_example(example_proto, feature_description)
  # Apply the label lookup.
  example['classes_text'] = label_table.lookup(example['classes_text'])
  return example

parsed_dataset = tfrecord_ds.map(_parse_data_function)

dataset = parsed_dataset.cache().shuffle(1000).batch(32).prefetch(tf.data.AUTOTUNE)
编辑 如果希望保持保存数据的方式,可以使用
np.frombuffer
将numpy向量转换为from二进制Sting。不过,您必须将此代码包装在tf.function和tf.py_函数中

def decode_embedding(embedding_bytes):
  return np.frombuffer(embedding_bytes.numpy())
@tf.function()
def tf_decode_embedding(embedding_bytes):
  return tf.py_function(decode_embedding, inp=[embedding_bytes], Tout=tf.float32)

feature_description = {
        'classes_text': tf.io.FixedLenFeature((), tf.string),
        'data': tf.io.FixedLenFeature([], tf.string)
    }

def _parse_data_function(example_proto):
    example = tf.compat.v1.parse_single_example(example_proto, feature_description)
    example['classes_text'] = label_table.lookup(example['classes_text'])
    example['data'] = tf_decode_embedding(example['data'])
    return example

parsed_dataset = tfrecord_ds.map(_parse_data_function)

dataset = parsed_dataset.cache().shuffle(1000).batch(32).prefetch(tf.data.AUTOTUNE)

数据字符串是如何编码的?是逗号分隔的字符串吗?@Dimosthenis我修改了我的第一条消息来回答你的问题。谢谢。我不想使用浮动列表,因为它迫使我指定数据的大小。我的案例我的数据大小可能不同于tf记录。我可以像修改标签一样修改嵌入功能吗?如果嵌入大小没有固定大小,您可以填写零(并具有最大大小参数)或使用
tf.io.VarLenFeature
。我已经编辑了答案。另一个选项是在保存时使用
tf.io.serialize_tensor(embedding.numpy()
,在解析时使用
tf.io.parse_tensor(例如['data'],tf.float32)
,而不需要使用tf.py_函数。谢谢您的回答。非常清楚