使用 TensorFlow.js 训练手写数字识别模型并识别手写数字的全过程代码
代码语言:html
所属分类:其他
代码描述:使用 TensorFlow.js 训练手写数字识别模型并识别手写数字的全过程代码
下面为部分代码预览，完整代码请点击下载或在 bfwstudio WebIDE 中打开
<!DOCTYPE html>
<html lang="zh-CN">
<head>
  <meta charset="UTF-8">
  <title>TensorFlow.js:手写数字识别</title>
  <style>
    #container {
      width: 300px;
    }

    .button-container {
      display: flex;
    }

    .btn {
      width: 140px;
    }
  </style>
</head>
<body>
  <div class="tfjs-example-container">
    <section class='title-area'>
      <h1>TensorFlow.js:手写数字识别</h1>
      <p class='subtitle'>使用 tf.layers API 训练 MNIST 手写数字数据，并识别手写数字。</p>
    </section>

    <section>
      <p class='section-head'>描述</p>
      <p>
        此示例允许您使用卷积神经网络(也称为ConvNet或CNN)、完全连接的神经网络(也称为DenseNet)或逻辑回归(Logistic)模型来训练手写数字识别器。
        MNIST数据集用作训练数据。
      </p>
    </section>

    <section>
      <p class='section-head'>训练参数</p>
      <div>
        <label for="model-type">Model Type:</label>
        <select id="model-type">
          <option>Logistic</option>
          <option>DenseNet</option>
          <option>ConvNet</option>
        </select>
      </div>
      <div>
        <label for="train-epochs"># of training epochs:</label>
        <input id="train-epochs" value="3">
      </div>
      <button id="train">加载数据训练模型</button>
    </section>

    <section>
      <p class='section-head'>训练过程</p>
      <p id="status"></p>
      <p id="message"></p>
      <div id="stats">
        <div class="canvases">
          <label id="loss-label"></label>
          <div id="loss-canvas"></div>
        </div>
        <div class="canvases">
          <label id="accuracy-label"></label>
          <div id="accuracy-canvas"></div>
        </div>
      </div>
    </section>

    <section>
      <p class='section-head'>推理例子</p>
      <div id="images"></div>
    </section>
  </div>

  <!-- Loaded over HTTPS: an http:// script is blocked as mixed content when
       this page is itself served over HTTPS. -->
  <script src='https://cdnjs.cloudflare.com/ajax/libs/fabric.js/1.4.0/fabric.min.js'></script>
  <script src='https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.13.3'></script>
  <!-- NOTE(review): tfjs-vis is loaded without a version pin while tfjs is
       pinned to 0.13.3. The plotting code on this page calls
       tfvis.render.linechart(data, container, options); confirm that the
       resolved tfjs-vis release still accepts that argument order, then pin it. -->
  <script src='https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis'></script>
  <script>
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ const IMAGE_H = 28; const IMAGE_W = 28; const IMAGE_SIZE = IMAGE_H * IMAGE_W; const NUM_CLASSES = 10; const NUM_DATASET_ELEMENTS = 65000; const NUM_TRAIN_ELEMENTS = 55000; const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS; const MNIST_IMAGES_SPRITE_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png'; const MNIST_LABELS_PATH = 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8'; /** * A class that fetches the sprited MNIST dataset and provide data as * tf.Tensors. */ class MnistData { constructor() {} async load() { // Make a request for the MNIST sprited image. 
const img = new Image(); const canvas = document.createElement('canvas'); const ctx = canvas.getContext('2d'); const imgRequest = new Promise((resolve, reject) => { img.crossOrigin = ''; img.onload = () => { img.width = img.naturalWidth; img.height = img.naturalHeight; const datasetBytesBuffer = new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4); const chunkSize = 5000; canvas.width = img.width; canvas.height = chunkSize; for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) { const datasetBytesView = new Float32Array( datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4, IMAGE_SIZE * chunkSize); ctx.drawImage( img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width, chunkSize); const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height); for (let j = 0; j < imageData.data.length / 4; j++) { // All channels hold an equal value since the image is grayscale, so // just read the red channel. datasetBytesView[j] = imageData.data[j * 4] / 255; } } this.datasetImages = new Float32Array(datasetBytesBuffer); resolve(); }; img.src = MNIST_IMAGES_SPRITE_PATH; }); const labelsRequest = fetch(MNIST_LABELS_PATH); const [imgResponse, labelsResponse] = await Promise.all([imgRequest, labelsRequest]); this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer()); // Slice the the images and labels into train and test sets. this.trainImages = this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS); this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS); this.trainLabels = this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS); this.testLabels = this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS); } /** * Get all training data as a data tensor and a labels tensor. * * @returns * xs: The data tensor, of shape `[numTrainExamples, 28, 28, 1]`. * labels: The one-hot encoded labels tensor, of shape * `[numTrainExamples, 10]`. 
*/ getTrainData() { const xs = tf.tensor4d( this.trainImages, [this.trainImages.length / IMAGE_SIZE, IMAGE_H, IMAGE_W, 1]); console.log(xs); const labels = tf.tensor2d( this.trainLabels, [this.trainLabels.length / NUM_CLASSES, NUM_CLASSES]); return { xs, labels }; } /** * Get all test data as a data tensor a a labels tensor. * * @param {number} numExamples Optional number of examples to get. If not * provided, * all test examples will be returned. * @returns * xs: The data tensor, of shape `[numTestExamples, 28, 28, 1]`. * labels: The one-hot encoded labels tensor, of shape * `[numTestExamples, 10]`. */ getTestData(numExamples) { let xs = tf.tensor4d( this.testImages, [this.testImages.length / IMAGE_SIZE, IMAGE_H, IMAGE_W, 1]); let labels = tf.tensor2d( this.testLabels, [this.testLabels.length / NUM_CLASSES, NUM_CLASSES]); if (numExamples != null) { xs = xs.slice([0, 0, 0, 0], [numExamples, IMAGE_H, IMAGE_W, 1]); labels = labels.slice([0, 0], [numExamples, NUM_CLASSES]); } return { xs, labels }; } } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ //import * as tfvis from '@tensorflow/tfjs-vis'; const statusElement = document.getElementById('status'); const messageElement = document.getElementById('message'); const imagesElement = document.getElementById('images'); function logStatus(message) { statusElement.innerText = message; } function trainingLog(message) { messageElement.innerText = `${message}\n`; console.log(message); } function showTestResults(batch, predictions, labels) { const testExamples = batch.xs.shape[0]; imagesElement.innerHTML = ''; for (let i = 0; i < testExamples; i++) { const image = batch.xs.slice([i, 0], [1, batch.xs.shape[1]]); const div = document.createElement('div'); div.className = 'pred-container'; const canvas = document.createElement('canvas'); canvas.className = 'prediction-canvas'; draw(image.flatten(), canvas); const pred = document.createElement('div'); const prediction = predictions[i]; const label = labels[i]; const correct = prediction === label; pred.className = `pred ${(correct ? 'pred-correct': 'pred-incorrect')}`; pred.innerText = `pred: ${prediction}`; div.appendChild(pred); div.appendChild(canvas); imagesElement.appendChild(div); } } const lossLabelElement = document.getElementById('loss-label'); const accuracyLabelElement = document.getElementById('accuracy-label'); const lossValues = [[], []]; function plotLoss(batch, loss, set) { const series = set === 'train' ? 0: 1; lossValues[series].push({ x: batch, y: loss }); const lossContainer = document.getElementById('loss-canvas'); tfvis.render.linechart( { values: lossValues, series: ['train', 'validation']}, lossContainer, { xLabel: 'Batch #', yLabel: 'Loss', width: 400, height: 300, }); lossLabelElement.innerText = `last loss: ${loss.toFixed(3)}`; } const accuracyValues = [[], []]; function plotAccuracy(batch, accuracy, set) { const accura.........完整代码请登录后点击上方下载按钮下载查看
网友评论0