tensorflow训练手写数字并识别手写数字的全过程代码

代码语言:html

所属分类:其他

代码描述:tensorflow训练手写数字并识别手写数字的全过程代码

代码标签: 数字 识别 手写 数字 全过程

下面为部分代码预览,完整代码请点击下载或在bfwstudio webide中打开

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <style>
        #container {
            width: 300px;
            /*text-align: center;*/
            /*display: flex;*/
        }

        .canvas-container {}
        .button-container {
            display: flex;
            /*justify-content: center;*/
        }
        .btn {
            width: 140px;
        }
    </style>

</head>
<body>
    <div class="tfjs-example-container">
        <section class='title-area'>
            <h1>TensorFlow.js:手写数字识别</h1>
            <p class='subtitle'>
                使用 tf.layers API 训练 MNIST 手写数字数据并识别手写数字。
            </p>
        </section>
        <section>
            <p class='section-head'>
                描述
            </p>
            <p>
                此示例允许您使用卷积神经网络(也称为ConvNet或CNN)或完全连接的神经网络(也称为DenseNet)来训练手写数字识别器。
                MNIST数据集用作训练数据。
            </p>
        </section>
        <section>
            <p class='section-head'>
                训练参数
            </p>
            <div>
                <label>Model Type:</label>
                <select id="model-type">
                    <option>Logistic</option>
                    <option>DenseNet</option>
                    <option>ConvNet</option>
                </select>
            </div>
            <div>
                <label># of training epochs:</label>
                <input id="train-epochs" value="3">
            </div>
            <button id="train">加载数据训练模型</button>
        </section>
        <section>
            <p class='section-head'>
                训练过程
            </p>
            <p id="status"></p>
            <p id="message"></p>
            <div id="stats">
                <div class="canvases">
                    <label id="loss-label"></label>
                    <div id="loss-canvas"></div>
                </div>
                <div class="canvases">
                    <label id="accuracy-label"></label>
                    <div id="accuracy-canvas"></div>
                </div>
            </div>
        </section>
        <section>
            <p class='section-head'>
                推理例子
            </p>
            <div id="images"></div>
        </section>
    </div>
  
    <!-- All third-party scripts are loaded over HTTPS so that they are not
         blocked as mixed content when this page itself is served over HTTPS. -->
    <script src='https://cdnjs.cloudflare.com/ajax/libs/fabric.js/1.4.0/fabric.min.js'></script>
    <script src='https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.13.3'></script>
    <script src='https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis'></script>
    <script>
      
        /**
        * @license
        * Copyright 2018 Google LLC. All Rights Reserved.
        * Licensed under the Apache License, Version 2.0 (the "License");
        * you may not use this file except in compliance with the License.
        * You may obtain a copy of the License at
        *
        * http://www.apache.org/licenses/LICENSE-2.0
        *
        * Unless required by applicable law or agreed to in writing, software
        * distributed under the License is distributed on an "AS IS" BASIS,
        * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
        * See the License for the specific language governing permissions and
        * limitations under the License.
        * =============================================================================
        */

        // Geometry of a single example image, in pixels.
        const IMAGE_H = 28;
        const IMAGE_W = 28;
        // Number of pixels (and therefore float values) per example.
        const IMAGE_SIZE = IMAGE_W * IMAGE_H;

        // Number of label classes per example.
        const NUM_CLASSES = 10;

        // Total number of examples, and how many of them form the training set;
        // whatever is left over is used as the test set.
        const NUM_DATASET_ELEMENTS = 65000;
        const NUM_TRAIN_ELEMENTS = 55000;
        const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

        // Remote locations of the image sprite and of the raw label bytes.
        const MNIST_IMAGES_SPRITE_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
        const MNIST_LABELS_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';

        /**
        * Fetches the sprited MNIST dataset and exposes it as tf.Tensors.
        *
        * `load()` must be awaited before `getTrainData()` or `getTestData()` is
        * called, because those methods read the arrays that `load()` populates.
        */
        class MnistData {
            constructor() {}

            /**
            * Downloads the image sprite and the label file, converts the sprite
            * pixels to floats in [0, 1], and splits images and labels into train
            * and test sets (`trainImages`, `testImages`, `trainLabels`,
            * `testLabels`).
            *
            * @returns {Promise<void>} Resolves once all data has been decoded.
            * @throws {Error} If the sprite cannot be loaded or decoded, or if the
            *     label request does not succeed.
            */
            async load() {
                // Make a request for the MNIST sprited image.
                const img = new Image();
                const canvas = document.createElement('canvas');
                const ctx = canvas.getContext('2d');
                const imgRequest = new Promise((resolve, reject) => {
                    img.crossOrigin = '';
                    // Without an error handler a failed download would leave this
                    // promise pending forever and load() would never settle.
                    img.onerror = () => {
                        reject(new Error(
                            `Failed to load MNIST images from ${MNIST_IMAGES_SPRITE_PATH}`));
                    };
                    img.onload = () => {
                        // Exceptions thrown inside an event handler do not reach the
                        // promise on their own, so forward them explicitly.
                        try {
                            img.width = img.naturalWidth;
                            img.height = img.naturalHeight;

                            // One float (4 bytes) per pixel of every example.
                            const datasetBytesBuffer =
                                new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

                            // The sprite is decoded in horizontal strips of
                            // `chunkSize` rows so that only one strip has to be
                            // held on the canvas at a time.
                            // NOTE(review): assumes the sprite is IMAGE_SIZE pixels
                            // wide with one example per row - confirm against the
                            // hosted image.
                            const chunkSize = 5000;
                            canvas.width = img.width;
                            canvas.height = chunkSize;

                            for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
                                const datasetBytesView = new Float32Array(
                                    datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
                                    IMAGE_SIZE * chunkSize);
                                ctx.drawImage(
                                    img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
                                    chunkSize);

                                const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

                                for (let j = 0; j < imageData.data.length / 4; j++) {
                                    // All channels hold an equal value since the image is
                                    // grayscale, so just read the red channel.
                                    datasetBytesView[j] = imageData.data[j * 4] / 255;
                                }
                            }
                            this.datasetImages = new Float32Array(datasetBytesBuffer);

                            resolve();
                        } catch (err) {
                            reject(err);
                        }
                    };
                    img.src = MNIST_IMAGES_SPRITE_PATH;
                });

                // Download the labels in parallel with the sprite.
                const [, labelsResponse] =
                    await Promise.all([imgRequest, fetch(MNIST_LABELS_PATH)]);
                if (!labelsResponse.ok) {
                    throw new Error(
                        `Failed to fetch MNIST labels from ${MNIST_LABELS_PATH}: ` +
                        `${labelsResponse.status} ${labelsResponse.statusText}`);
                }

                this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

                // Slice the images and labels into train and test sets. Labels take
                // NUM_CLASSES bytes per example.
                this.trainImages =
                    this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
                this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
                this.trainLabels =
                    this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
                this.testLabels =
                    this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
            }

            /**
            * Get all training data as a data tensor and a labels tensor.
            *
            * The caller owns the returned tensors and is responsible for disposing
            * them.
            *
            * @returns
            *   xs: The data tensor, of shape `[numTrainExamples, 28, 28, 1]`.
            *   labels: The one-hot encoded labels tensor, of shape
            *     `[numTrainExamples, 10]`.
            */
            getTrainData() {
                const xs = tf.tensor4d(
                    this.trainImages,
                    [this.trainImages.length / IMAGE_SIZE, IMAGE_H, IMAGE_W, 1]);
                const labels = tf.tensor2d(
                    this.trainLabels,
                    [this.trainLabels.length / NUM_CLASSES, NUM_CLASSES]);
                return {
                    xs, labels
                };
            }

            /**
            * Get all test data as a data tensor and a labels tensor.
            *
            * The caller owns the returned tensors and is responsible for disposing
            * them.
            *
            * @param {number} numExamples Optional number of examples to get. If not
            *     provided,
            *   all test examples will be returned.
            * @returns
            *   xs: The data tensor, of shape `[numTestExamples, 28, 28, 1]`.
            *   labels: The one-hot encoded labels tensor, of shape
            *     `[numTestExamples, 10]`.
            */
            getTestData(numExamples) {
                let xs = tf.tensor4d(
                    this.testImages,
                    [this.testImages.length / IMAGE_SIZE, IMAGE_H, IMAGE_W, 1]);
                let labels = tf.tensor2d(
                    this.testLabels,
                    [this.testLabels.length / NUM_CLASSES, NUM_CLASSES]);

                if (numExamples != null) {
                    // Only the slices are handed to the caller, so the full-size
                    // tensors must be released here or their memory is leaked.
                    const allXs = xs;
                    const allLabels = labels;
                    xs = allXs.slice([0, 0, 0, 0], [numExamples, IMAGE_H, IMAGE_W, 1]);
                    labels = allLabels.slice([0, 0], [numExamples, NUM_CLASSES]);
                    allXs.dispose();
                    allLabels.dispose();
                }
                return {
                    xs,
                    labels
                };
            }
        }

        /**
        * @license
        * Copyright 2018 Google LLC. All Rights Reserved.
        * Licensed under the Apache License, Version 2.0 (the "License");
        * you may not use this file except in compliance with the License.
        * You may obtain a copy of the License at
        *
        * http://www.apache.org/licenses/LICENSE-2.0
        *
        * Unless required by applicable law or agreed to in writing, software
        * distributed under the License is distributed on an "AS IS" BASIS,
        * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
        * See the License for the specific language governing permissions and
        * limitations under the License.
        * =============================================================================
        */

        //import * as tfvis from '@tensorflow/tfjs-vis';

        // Page elements written to by the status, training-log and test-result
        // helpers below, looked up once by id.
        const [statusElement, messageElement, imagesElement] =
            ['status', 'message', 'images'].map((id) => document.getElementById(id));

        /**
        * Replaces the text of the status element with the given message.
        *
        * @param {string} message Text to display.
        */
        function logStatus(message) {
            const target = statusElement;
            target.innerText = message;
        }

        /**
        * Shows a training message in the message element (followed by a line
        * break) and mirrors the unmodified message to the console.
        *
        * @param {string} message Text to display and log.
        */
        function trainingLog(message) {
            const displayedText = `${message}\n`;
            messageElement.innerText = displayedText;
            console.log(message);
        }

        /**
        * Rebuilds the contents of the images element: one container per example
        * in the batch, holding the predicted value and a canvas with the drawn
        * input image.
        *
        * @param batch Object whose `xs` tensor holds the examples along its
        *     first axis.
        * @param predictions Predicted value for each example, by index.
        * @param labels Expected value for each example, by index; a prediction
        *     counts as correct when it is strictly equal to its label.
        */
        function showTestResults(batch, predictions, labels) {
            const exampleCount = batch.xs.shape[0];

            // Start from an empty list on every call.
            imagesElement.innerHTML = '';

            for (let index = 0; index < exampleCount; index += 1) {
                // Single example cut out of the batch.
                const exampleImage = batch.xs.slice([index, 0], [1, batch.xs.shape[1]]);

                const wrapper = document.createElement('div');
                wrapper.className = 'pred-container';

                const canvas = document.createElement('canvas');
                canvas.className = 'prediction-canvas';
                draw(exampleImage.flatten(), canvas);

                const predictionBox = document.createElement('div');

                const predicted = predictions[index];
                const expected = labels[index];
                let resultClass;
                if (predicted === expected) {
                    resultClass = 'pred-correct';
                } else {
                    resultClass = 'pred-incorrect';
                }

                predictionBox.className = `pred ${resultClass}`;
                predictionBox.innerText = `pred: ${predicted}`;

                // The prediction text goes above the canvas.
                wrapper.appendChild(predictionBox);
                wrapper.appendChild(canvas);

                imagesElement.appendChild(wrapper);
            }
        }

        // Label elements for the loss and accuracy charts, looked up once by id.
        const [lossLabelElement, accuracyLabelElement] =
            ['loss-label', 'accuracy-label'].map((id) => document.getElementById(id));
        // Accumulated loss points, one array per chart series
        // (index 0: 'train', index 1: 'validation').
        const lossValues = Array.from({ length: 2 }, () => []);
        /**
        * Records a loss value and redraws the loss line chart.
        *
        * @param {number} batch Batch number, used as the x coordinate.
        * @param {number} loss Loss value, used as the y coordinate.
        * @param {string} set 'train' adds the point to the training series; any
        *     other value adds it to the validation series.
        */
        function plotLoss(batch, loss, set) {
            let seriesIndex = 1;
            if (set === 'train') {
                seriesIndex = 0;
            }
            lossValues[seriesIndex].push({ x: batch, y: loss });

            const lossContainer = document.getElementById('loss-canvas');
            const chartData = {
                values: lossValues,
                series: ['train', 'validation'],
            };
            const chartOptions = {
                xLabel: 'Batch #',
                yLabel: 'Loss',
                width: 400,
                height: 300,
            };
            tfvis.render.linechart(chartData, lossContainer, chartOptions);

            // Show the latest value rounded to three decimals.
            lossLabelElement.innerText = `last loss: ${loss.toFixed(3)}`;
        }

        const accuracyValues = [[], []];
        function plotAccuracy(batch, accuracy, set) {
            const accura.........完整代码请登录后点击上方下载按钮下载查看

网友评论0