Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/tensorflow/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python Keras应用-imagenet上的VGG16低精度_Python_Tensorflow_Keras_Google Colaboratory_Tensorflow Datasets - Fatal编程技术网

Python Keras应用-imagenet上的VGG16低精度

Python Keras应用-imagenet上的VGG16低精度,python,tensorflow,keras,google-colaboratory,tensorflow-datasets,Python,Tensorflow,Keras,Google Colaboratory,Tensorflow Datasets,我试图复制这里提到的VGG-16的性能: 但是,当我在来自tensorflow数据集的imagenet数据集上运行该模型时,我得到了较低的top5精度0.866 这是我的代码: import tensorflow_datasets as tfds import tensorflow as tf from tensorflow.keras import applications import tensorflow.keras.applications.vgg16 as vgg16 def sc

我试图复制这里提到的VGG-16的性能:

但是,当我在来自tensorflow数据集的imagenet数据集上运行该模型时,我得到了较低的top5精度0.866

这是我的代码:

import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow.keras import applications
import tensorflow.keras.applications.vgg16 as vgg16

def scale16(image, label):
  i = image
  i = tf.cast(i, tf.float32)
  i = tf.image.resize(i, (224,224))
  i = vgg16.preprocess_input(i)
  return (i, label)

def batch_set(dataset, batch_size):
    return dataset.map(scale16) \
                  .shuffle(1000) \
                  .batch(batch_size) \
                  .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

def create_batched_datasets(map_fn, data_dir = "/content", batch_size = 64):
    datasets, info = tfds.load(name="imagenet2012", 
                            with_info=True, 
                            as_supervised=True, 
                            download=False, 
                            data_dir=data_dir
                            )
    train = batch_set(datasets['train'], batch_size)
    val = batch_set(datasets['validation'], batch_size)
    return train, val, info


train, test_dataset, info = create_batched_datasets(scale16)

model = vgg16.VGG16(weights='imagenet', include_top=True)

model.compile('sgd', 'categorical_crossentropy', 
              ['sparse_categorical_accuracy','sparse_top_k_categorical_accuracy'])

model.evaluate(test_dataset)

我错过了什么?我正在google colab上运行代码。

代码没有正确预处理图像。方法tf.image.resize()将图像缩小。但根据keras网站的说法,224x224x3的图像应该由中心裁剪创建。更改scale16()方法可以解决以下问题:


def resize_image(image, shape = (224,224)):
  target_width = shape[0]
  target_height = shape[1]
  initial_width = tf.shape(image)[0]
  initial_height = tf.shape(image)[1]
  im = image
  ratio = 0
  if(initial_width < initial_height):
    ratio = tf.cast(256 / initial_width, tf.float32)
    h = tf.cast(initial_height, tf.float32) * ratio
    im = tf.image.resize(im, (256, h), method="bicubic")
  else:
    ratio = tf.cast(256 / initial_height, tf.float32)
    w = tf.cast(initial_width, tf.float32) * ratio
    im = tf.image.resize(im, (w, 256), method="bicubic")
  width = tf.shape(im)[0]
  height = tf.shape(im)[1]
  startx = width//2 - (target_width//2)
  starty = height//2 - (target_height//2)
  im = tf.image.crop_to_bounding_box(im, startx, starty, target_width, target_height)
  return im

def scale16(image, label):
  i = image
  i = tf.cast(i, tf.float32)
  i = resize_image(i, (224,224))
  i = vgg16.preprocess_input(i)
  return (i, label)


def resize_图像(图像,形状=(224224)):
目标宽度=形状[0]
目标高度=形状[1]
初始宽度=tf.形状(图像)[0]
初始高度=tf.形状(图像)[1]
im=图像
比率=0
如果(初始宽度<初始高度):
比率=tf.cast(256/初始宽度,tf.float32)
h=tf.铸件(初始高度,tf.浮动32)*比率
im=tf.image.resize(im,(256,h),method=“双三次”)
其他:
比率=tf.cast(256/初始高度,tf.32)
w=tf.cast(初始宽度,tf.float32)*比率
im=tf.image.resize(im,(w,256),method=“双三次”)
宽度=tf.形状(im)[0]
高度=tf.形状(im)[1]
startx=宽度//2-(目标宽度//2)
starty=高度//2-(目标高度//2)
im=tf.image.crop\u to\u bounding\u box(im、startx、starty、target\u width、target\u height)
返回即时消息
def刻度16(图像、标签):
i=图像
i=tf.cast(i,tf.float32)
i=调整图像大小(i,(224))
i=vgg16.预处理输入(i)
退货(一、标签)

如果达到86.6%的准确率,您所做的有什么问题?您使用什么数据进行测试?