tensorflow数据训练拟合的全过程

代码语言:html

所属分类:其他

代码描述:tensorflow数据训练拟合的全过程

代码标签: 拟合 全过程

下面为部分代码预览,完整代码请点击下载或在bfwstudio webide中打开

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">


</head>
<body translate="no">
    <div style="width:800px;">
        <canvas id="myChart" width="800" height="800"></canvas>
    </div>
    <button onclick="train()">训练模型一次</button>

    <script type="text/javascript" src="http://repo.bfw.wiki/bfwrepo/js/tf.min.js"></script>
    <script type="text/javascript" src="http://repo.bfw.wiki/bfwrepo/js/chart.js"></script>
    <script>
        const trainX = [
            3.3,
            4.4,
            5.5,
            6.71,
            6.93,
            4.168,
            9.779,
            6.182,
            7.59,
            2.167,
            7.042,
            10.791,
            5.313,
            7.997,
            5.654,
            9.27,
            3.1
        ];
        const trainY = [
            1.7,
            2.76,
            2.09,
            3.19,
            1.694,
            1.573,
            3.366,
            2.596,
            2.53,
            1.221,
            2.827,
            3.465,
            1.65,
            2.904,
            2.42,
            2.94,
            1.3
        ];

        const m = tf.variable(tf.scalar(Math.random()));
        const b = tf.variable(tf.scalar(Math.random()));

        function predict(x) {
            return tf.tidy(function() {
                return m.mul(x).add(b);
            });
        }

        function loss(prediction, labels) {
            //subtracts the two arrays & squares each element of the tensor then finds the mean.
            const error = prediction
            .sub(labels)
            .square()
            .mean();
            return error;
        }

        function train() {
            const learningRate = 0.005;
            const optimizer = tf.train.sgd(learningRate);

            optimizer.minimize(function() {
                const predsYs = predict(tf.tensor1d(trainX));
                console.log(predsYs);
                stepLoss = loss(predsYs, .........完整代码请登录后点击上方下载按钮下载查看

网友评论0