Python PyBrain-out=fnn.activateOnDataset(griddata)
我一直在使用神经网络对PyBrain教程中的图像进行分类: 它以png格式输入图像数据,每个图像都指定了一个特定的类 直到:Python PyBrain-out=fnn.activateOnDataset(griddata),python,image-processing,neural-network,pybrain,Python,Image Processing,Neural Network,Pybrain,我一直在使用神经网络对PyBrain教程中的图像进行分类: 它以png格式输入图像数据,每个图像都指定了一个特定的类 直到: out = fnn.activateOnDataset(griddata) 它返回的消息是:AssertionError:(3,2) 我很确定如何声明griddata数据集是个问题,但我不知道具体是什么 在教程版本上,它运行良好 我的代码: from pybrain.datasets import ClassificationDataSet fr
out = fnn.activateOnDataset(griddata)
它返回的消息是:AssertionError:(3,2)
我很确定如何声明griddata数据集是个问题,但我不知道具体是什么
在教程版本上,它运行良好
我的代码:
from pybrain.datasets import ClassificationDataSet
from pybrain.utilities import percentError
from pybrain.tools.shortcuts import buildNetwork
from pybrain.supervised.trainers import BackpropTrainer
from pybrain.structure.modules import SoftmaxLayer
from pylab import ion, ioff, figure, draw, contourf, clf, show, hold, plot
from scipy import diag, arange, meshgrid, where
from numpy.random import multivariate_normal
import cv2
from pyroc import *
#Creates cover type array based on color of pixels in roadmap
coverType = [(255,225,104,3), #Road
(254,254,253,0), #Other
(254,254,254,3), #Road
(253,254,253,0),#Other
(253,225,158,0),#Other
] # have other cover type but sample amount included
coverTypes = len(coverType)
print coverTypes #to count
#Creates dataset
alldata = ClassificationDataSet(3,1,nb_classes=10)
"""Classifies Roadmap Sub-Images by type and loads matching Satellite Sub-Image
with classification into dataset."""
for eachFile in glob.glob('Roadmap Sub-Images/*'):
image = Image.open(eachFile)
fileName = eachFile
newFileName = fileName.replace("Roadmap Sub-Images", "Satellite Sub-Images")
colors = image.convert('RGB').getcolors() #Finds all colors in image and their frequency
colors.sort() #Sorts colors in image by their frequency
colorMostFrequent = colors[-1][1] #Finds last element in array, the most frequent color
for eachColor in range(1,151): #151 number of element in CoverType array
if colorMostFrequent[0] == coverType[eachColor][0] and colorMostFrequent[1] == coverType[eachColor][1] and colorMostFrequent[2] == coverType[eachColor][2]:
print newFileName #Check new route
image = cv2.imread(newFileName)
meanImage = cv2.mean(image) #Take average color
meanImageRGB = meanImage[:3] #Converts to RGB scale, excluding "alpha"
print meanImageRGB #Check RGB average colors
alldata.addSample(meanImageRGB,coverType[eachColor][3])
tstdata, trndata = alldata.splitWithProportion( 0.25 )
trndata._convertToOneOfMany( )
tstdata._convertToOneOfMany( )
fnn = buildNetwork( trndata.indim, 5, trndata.outdim, outclass=SoftmaxLayer )
trainer = BackpropTrainer( fnn, dataset=trndata, momentum=0.1, verbose=True, weightdecay=0.01)
ticks = arange(-3.,6.,0.2)
X, Y = meshgrid(ticks, ticks)
#I think every thing is good to here problem with the griddata dataset I think?
# need column vectors in dataset, not arrays
griddata = ClassificationDataSet(2,1, nb_classes=4)
for i in xrange(X.size):
griddata.addSample([X.ravel()[i],Y.ravel()[i]], [0])
griddata._convertToOneOfMany() # this is still needed to make the fnn feel comfy
for i in range(20):
trainer.trainEpochs( 1 )
trnresult = percentError( trainer.testOnClassData(),
trndata['class'] )
tstresult = percentError( trainer.testOnClassData(
dataset=tstdata ), tstdata['class'] )
print "epoch: %4d" % trainer.totalepochs, \
" train error: %5.2f%%" % trnresult, \
" test error: %5.2f%%" % tstresult
out = fnn.activateOnDataset(alldata)
out = out.argmax(axis=1) # the highest output activation gives the class
out = out.reshape(X.shape)
figure(1)
ioff() # interactive graphics off
clf() # clear the plot
hold(True) # overplot on
for c in [0,1,2]:
here, _ = where(tstdata['class']==c)
plot(tstdata['input'][here,0],tstdata['input'][here,1],'o')
if out.max()!=out.min(): # safety check against flat field
contourf(X, Y, out) # plot the contour
ion() # interactive graphics on
draw() # update the plot
ioff()
show()
我相信这与初始数据集的维度与griddata的维度不一致有关
alldata=ClassificationDataSet(3,1,nb\u classes=10)
griddata=ClassificationDataSet(2,1,nb\u classes=4)
它们都应该是3,1。然而,当我调整这个时,我的代码在稍后的阶段失败了,所以我对此也很好奇