Warning: file_get_contents(/data/phpspider/zhask/data//catemap/9/java/361.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Java 如何训练OCR神经网络?_Java_Machine Learning_Neural Network_Handwriting Recognition - Fatal编程技术网

Java 如何训练OCR神经网络?

Java 如何训练OCR神经网络?,java,machine-learning,neural-network,handwriting-recognition,Java,Machine Learning,Neural Network,Handwriting Recognition,对于我的APCS最终项目,我正在制作一个应用程序: 允许用户在绘图面板上绘制数字 将每个笔划(由x-y坐标列表表示)缩放/转换为100x100 从缩放笔划生成图像 从该图像生成一个二进制2D数组(0表示白色,否则为1) 并将二进制数组传递给神经元对象进行字符识别 以下类别代表神经元: import java.awt.*; import java.util.*; import java.io.*; public class Neuron { private double[][] we

对于我的APCS最终项目,我正在制作一个应用程序:

  • 允许用户在绘图面板上绘制数字
  • 将每个笔划(由x-y坐标列表表示)缩放/转换为100x100
  • 从缩放笔划生成图像
  • 从该图像生成一个二进制2D数组(0表示白色,否则为1)
  • 并将二进制数组传递给神经元对象进行字符识别
以下类别代表神经元:

import java.awt.*;
import java.util.*;
import java.io.*;

public class Neuron
{
    private double[][] weights;
    public static double LEARNING_RATE = 0.01;

    /**
     *Initialize weights
     *Assign random double values to weights
     */
    public Neuron(int r, int c)
    {
        weights = new double[r][c];

        PrintWriter printer = null;
        try
        {
            printer = new PrintWriter("training.txt");
        }
        catch (FileNotFoundException e) {};
        for (int i = 0; i < weights.length; i++)
        {
            for (int j = 0; j < weights[i].length; j++)
            {
                weights[i][j] = 2 * Math.random() - 1; //Generates random number between -1 and 1
                if (j < weights[i].length - 1)
                    printer.print(weights[i][j] + " ");
                else
                    printer.print(weights[i][j]);
            }
            printer.println();
        }
        printer.close();
    }

    public Neuron(String fileName)
    {
        File data = new File(fileName);
        Scanner input = null;
        try
        {
            input = new Scanner(data);
        }
        catch (FileNotFoundException e)
        {
            System.out.println("Error: could not open " + fileName);
            System.exit(1);
        }

        int r = Drawing.DEF_HEIGHT, c = Drawing.DEF_WIDTH;
        weights = new double[r][c];

        int i = 0, j = 0;
        while (input.hasNext())
        {
            weights[i][j] = input.nextDouble();
            j++;
            if (j > weights[i].length - 1)
            {
                i++;
                j = 0;
            }
        }

        for (double[] a : weights)
            System.out.println(Arrays.toString(a));

    }

    /**
     *1. Initialize a sum variable
     *2. Multiply each index of weights by each index of bin
     *3. Sum these values
     *4. Return the activated sum
     */
    public int feedforward(int[][] bin) //bin represents 2D array of binary values for a binary image
    {
        double sum = 0;
        for (int i = 0; i < weights.length; i++)
        {
            for (int j = 0; j < weights[i].length; j++)
                sum += weights[i][j] * bin[i][j];
        }
        return activate(sum);
    }

    /**
     *1. Generate a sigmoid (logistic) value from a sum
     *2. "Digitize" the sigmoid value
     *3. Return the digitized value, which corresponds to a number
     */
    public int activate(double n)
    {
        double sig = 1.0/(1+Math.exp(-1*n));
        int digitized = 0;

        if (sig < 0.1)
            digitized = 0;
        else if (sig >= 0.1 && sig < 0.2)
            digitized = 1;
        else if (sig >= 0.2 && sig < 0.3)
            digitized = 2;
        else if (sig >= 0.3 && sig < 0.4)
            digitized = 3;
        else if (sig >= 0.4 && sig < 0.5)
            digitized = 4;
        else if (sig >= 0.5 && sig < 0.6)
            digitized = 5;
        else if (sig >= 0.6 && sig < 0.7)
            digitized = 6;
        else if (sig >= 0.7 && sig < 0.8)
            digitized = 7;
        else if (sig >= 0.8 && sig < 0.9)
            digitized = 8;
        else if (sig >= 0.9)
            digitized = 9;

        System.out.println("Sigmoid value: " + sig + "\nDigitized value: " + digitized);
        return digitized;
    }

    /**
     * 1. Provide inputs and "known" answer
     * 2. Guess according to the inputs using feedforward(inputs)
     * 3. Compute the error
     * 4. Adjust all weights according to the error and learning rate
     */
    public void train(int[][] bin, int desired)
    {
        int guess = feedforward(bin);
        int error = desired-guess;

        for (int i = 0; i < weights.length; i++)
        {
            for (int j = 0; j < weights[i].length; j++)
                weights[i][j] += LEARNING_RATE * error * bin[i][j];
        }
    }

}
import java.awt.*;
导入java.util.*;
导入java.io.*;
公共类神经元
{
私人双[]权重;
公共静态双学习率=0.01;
/**
*初始化权重
*将随机双精度值指定给权重
*/
公共神经元(intr,intc)
{
权重=新的双[r][c];
PrintWriter打印机=空;
尝试
{
打印机=新的PrintWriter(“training.txt”);
}
catch(filenotfounde){};
对于(int i=0;i权重[i]。长度-1)
{
i++;
j=0;
}
}
对于(双[]a:重量)
System.out.println(Arrays.toString(a));
}
/**
*1.初始化sum变量
*2.将每个权重指数乘以每个仓位指数
*3.将这些值相加
*4.返回激活的总和
*/
公共int前馈(int[][]bin)//bin表示二进制图像的二维二进制值数组
{
双和=0;
对于(int i=0;i=0.1&&sig<0.2)
数字化=1;
否则如果(sig>=0.2&&sig<0.3)
数字化=2;
否则如果(sig>=0.3&&sig<0.4)
数字化=3;
否则如果(sig>=0.4&&sig<0.5)
数字化=4;
否则如果(sig>=0.5&&sig<0.6)
数字化=5;
否则如果(sig>=0.6&&sig<0.7)
数字化=6;
否则如果(sig>=0.7&&sig<0.8)
数字化=7;
否则如果(sig>=0.8&&sig<0.9)
数字化=8;
否则如果(sig>=0.9)
数字化=9;
System.out.println(“Sigmoid值:+sig+”\n数字化值:+Digitalized”);
返回数字化;
}
/**
*1.提供输入和“已知”答案
*2.使用前馈(输入)根据输入进行猜测
*3.计算误差
*4.根据误差和学习率调整所有权重
*/
公共无效列车(int[][]箱,需要int)
{
int guess=前馈(bin);
int error=期望猜测;
对于(int i=0;i
我使用不同的类来“训练”神经元。另一个类–TrainingConsole.java–基本上使用随机生成的组件获取“training.txt”,为其提供训练示例(图像-->二进制2D数组),并根据错误、学习率和bin的相应值调整权重:

   import java.awt.image.BufferedImage;
import java.io.*;
import java.util.Arrays;
import java.util.Scanner;

import javax.imageio.ImageIO;

public class TrainingConsole
{

    private File folder;
    private File data;

    public TrainingConsole(String dataFileName, String folderName)
    {
        data = new File(dataFileName);
        folder = new File(folderName);
    }

    public void changeFolder(String folderName)
    {
        folder = new File(folderName);
    }

    public void feedAll(int desired)
    {
        System.out.println(Arrays.toString(folder.listFiles()));
        for (int i = 1; i < folder.listFiles().length; i++) //To exclude folder
        {
            BufferedImage img = new BufferedImage(Drawing.DEF_WIDTH,Drawing.DEF_HEIGHT,BufferedImage.TYPE_INT_RGB);
            try
            {

                String name = folder.listFiles()[i].getName();
                if (name.substring(name.length()-4).equals(".png"))
                    img = ImageIO.read(folder.listFiles()[i]);
            }
            catch(IOException e)
            {System.out.println("Error?");}

            int[][] bin = new int[Drawing.DEF_WIDTH][Drawing.DEF_HEIGHT];

            if (img != null)
            {
                for (int y = 0; y < img.getHeight(); y++)
                {
                    for (int x = 0; x < img.getWidth(); x++)
                    {
                        int rgb = img.getRGB(x,y);
                        //System.out.println(rgb);
                        if (rgb == -1) //White
                            bin[y][x] = 0;
                        else
                            bin[y][x] = 1;
                    }
                }
                for (int[] a : bin)
                    System.out.println(Arrays.toString(a));
                train(bin,desired);
            }
        }
    }

     public void train(int[][] bin, int desired) {
         int guess = feedforward(bin);
         int error = desired - guess;

         Scanner input = null;
         try {
             input = new Scanner(data);
         } catch (FileNotFoundException e) {
             System.exit(1);
         }
         double[][] weights = new double[Drawing.DEF_HEIGHT][Drawing.DEF_WIDTH];
         int i = 0, j = 0;
         while (input.hasNext() && i < Drawing.DEF_HEIGHT) {
             weights[i][j] = input.nextDouble();
             j++;
             if (j > weights[i].length - 1) {
                 i++;
                 j = 0;
             }
         }

         for (int k = 0; k < weights.length; k++) {
             for (int l = 0; l < weights[k].length; l++)
                 weights[k][l] += IMGNeuron.LEARNING_RATE * error * bin[k][l];
         }

         data = new File(data.getName());
         PrintWriter output = null;
         try {
             output = new PrintWriter(data);
         } catch (FileNotFoundException e) {
             System.out.println("Cannot find data");
         }
         for (int m = 0; m < weights.length; m++) {
             for (int n = 0; n < weights[m].length - 1; n++)
                 output.print(weights[m][n] + " ");
             output.print(weights[m][weights[m].length - 1]);
             output.println();
         }
         output.close();
     }

    public int feedforward(int[][] bin)
    {
        double sum = 0;

        Scanner input = null;
        try
        {
            input = new Scanner(data);
        }
        catch(FileNotFoundException e)
        {
            System.out.println("Could not locate data");
        }
        double[][] weights = new double[Drawing.DEF_HEIGHT][Drawing.DEF_WIDTH];
        int i = 0, j = 0;
        while (i < Drawing.DEF_HEIGHT && j < Drawing.DEF_WIDTH)
        {
            //System.out.println("( " + i + " , " + j + " )");
            weights[i][j] = input.nextDouble();
            j++;
            if (j > weights[i].length - 1)
            {
                i++;
                j = 0;
            }
        }

        for (int m = 0; m < weights.length; m++)
        {
            for (int n = 0; n < weights[m].length; n++)
                sum += weights[m][n] * bin[m][n];
        }
        return activate(sum);
    }

    public int activate(double n)
    {
        double sig = 1.0/(1+Math.exp(-1*n));
        int digitized = 0;

        if (sig < 0.1)
            digitized = 0;
        else if (sig >= 0.1 && sig < 0.2)
            digitized = 1;
        else if (sig >= 0.2 && sig < 0.3)
            digitized = 2;
        else if (sig >= 0.3 && sig < 0.4)
            digitized = 3;
        else if (sig >= 0.4 && sig < 0.5)
            digitized = 4;
        else if (sig >= 0.5 && sig < 0.6)
            digitized = 5;
        else if (sig >= 0.6 && sig < 0.7)
            digitized = 6;
        else if (sig >= 0.7 && sig < 0.8)
            digitized = 7;
        else if (sig >= 0.8 && sig < 0.9)
            digitized = 8;
        else if (sig >= 0.9)
            digitized = 9;

        return digitized;
    }

    public static void main(String[] args)
    {
        Scanner input = new Scanner(System.in);
        TrainingConsole trainer = new TrainingConsole("training.txt","Training_000");

        System.out.println("--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------");
        System.out.println("Training Console");
        System.out.println("--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------");

        for (int i = 0; i <= 9; i++) {
            //System.out.print("Folder with training data for desired = " + i + ", or enter \"skip\" to skip: ");
            //String folderName = input.nextLine().trim();
            String folderName = "Training_00" + i;
            //System.out.println(folderName);
            if (!folderName.toLowerCase().equals("skip"))
            {
                trainer.changeFolder(folderName);
//              System.out.print("Press enter to run: ");
//              String noReason = input.nextLine();
                trainer.feedAll(i);
            }
            System.out.println("----------------------------------------------------------------------------------------------------ava----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------");
        }
    }

}
导入java.awt.image.buffereImage;
导入java.io.*;
导入java.util.array;
导入java.util.Scanner;
导入javax.imageio.imageio;
公共类培训控制台
{
私人文件夹;
私有文件数据;
公共培训控制台(字符串数据文件名、字符串文件夹名)
{
数据=新文件(数据文件名);
文件夹=新文件(folderName);
}
公用void changeFolder(字符串folderName)
{
文件夹=新文件(folderName);
}
公共无效feedAll(需要整数)
{
System.out.println(Arrays.toString(folder.listFiles());
用于(int i=1;i