Python 三维张量与四维张量张量流的比较

Python 三维张量与四维张量张量流的比较,python,tensorflow,Python,Tensorflow,我有下面的U-Net,我用它来分割灰度PNG图像 import cv2 import os from sklearn.utils import shuffle import tensorflow as tf import numpy as np OVERALLSIZE = int(float(input('Choose the number of images you want (<5635) : '))) PATH = input('give absolute path to ima

我有下面的U-Net,我用它来分割灰度PNG图像

import cv2
import os
from sklearn.utils import shuffle
import tensorflow as tf
import numpy as np


OVERALLSIZE = int(float(input('Choose the number of images you want (<5635) : ')))
PATH = input('give absolute path to image')
TESTSIZE = int(float(input('Choose the number of the data you want to use as test (<5635) : ')))

######################################################################################################################

images = [img for img in os.listdir(PATH + '/Xtrain') if img.endswith('png')]
# put random_state to 1
images = shuffle(images,random_state = 0)
masks = [name[:-4]+'_mask.png' for name in images]

images, masks = images[:OVERALLSIZE], masks[:OVERALLSIZE]
images_, masks_  = [cv2.imread(PATH + '/Xtrain/' + img, cv2.IMREAD_GRAYSCALE).astype(np.int) for img in images], \
                   [cv2.imread(PATH + '/ytrain/' + msk, cv2.IMREAD_GRAYSCALE).astype(np.int) for msk in masks]

######################################################################################################################


X_train, y_train, X_test, y_test = np.asarray(images_[TESTSIZE:])/255., \
                                   np.asarray(masks_[TESTSIZE:]), \
                                   np.asarray(images_[:TESTSIZE])/255., \
                                   np.asarray(masks_[:TESTSIZE])



x = tf.placeholder(tf.float32, shape=[None, 420, 580], name='x')
y_ = tf.placeholder(tf.float32, shape=[None, 420, 580], name='y_')

sess = tf.InteractiveSession()

def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev = 0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

def convoer(inputs, shape, flag):
    W = weight_variable(shape)
    b = bias_variable([shape[3]])

    temp = shape
    temp[2] = shape[3]

    Wa = weight_variable(temp)
    ba = bias_variable([shape[3]])

    conv = tf.nn.relu(conv2d(inputs, W) + b)
    conv = tf.nn.relu(conv2d(conv, Wa) + ba)
    pool = max_pool_2x2(conv)

    if flag: return pool
    elif not flag: return conv

def upconvoer(inputs, shape, height, width):
    W = weight_variable(shape)
    b = bias_variable([shape[3]])

    temp = shape
    temp[2] = shape[3]

    Wa = weight_variable(temp)
    ba = bias_variable([shape[3]])

    up = tf.image.resize_images(inputs, height, width)
    conv = tf.nn.relu(conv2d(up, W) + b)
    conv = tf.nn.relu(conv2d(conv, Wa) + ba)

    return conv

def conv2d(x, W):
    return tf.nn.conv2d(x,W,strides = [1,1,1,1], padding='SAME')

def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize = [1,2,2,1], strides = [1,2,2,1], padding='SAME')

def U():
    inputs = tf.reshape(x, [-1,420,580,1])

    pool1 = convoer(inputs, [3,3,1,32], True)

    pool2 = convoer(pool1, [3,3,32,64], True)
    pool3 = convoer(pool2, [3,3,64,128], True)
    pool4 = convoer(pool3, [3,3,128,256], True)
    conv5 = convoer(pool4, [3,3,256,512], False)

    conv6 = upconvoer(conv5, [3,3,512,256], 73, 53)
    conv7 = upconvoer(conv6, [3,3,256,128], 145, 105)
    conv8 = upconvoer(conv7, [3,3,128,64], 290, 210)
    conv9 = upconvoer(conv8, [3,3,64,32], 420, 580)

    W10 = weight_variable([1,1,32,1])
    b10 = bias_variable([1])

    conv10 = tf.nn.sigmoid(conv2d(conv9, W10) + b10)

    y = conv10

    return y

y = U()

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

sess.run(train_step, feed_dict={x: X_train, y_: y_train})

我尝试在3个不同的位置重塑y:在U函数中,在定义y的最开始位置,在交叉熵中,但没有一个起作用。

你能发布更多的上下文,这样我们就可以知道哪一个op特别是广播失败了吗?你有没有尝试通过移除第二个张量的最后一个维度来切片第二个张量?方法如下:使用挤压功能代替切片。
ValueError: Incompatible shapes for broadcasting: (?, 420, 580) and (?, 420, 580, 1)