SGD (stochastic gradient descent) example in C#

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace ConsoleApp4
{
    /// <summary>
    /// Console demo that fits the linear model y = theta1*x1 + theta2*x2 + b to a
    /// small hard-coded data set using mini-batch stochastic gradient descent.
    /// </summary>
    class Program
    {
        // A single generator reused for every random draw in the program.
        // Creating a fresh time-seeded Random for each draw can produce identical
        // values when the instances are constructed in quick succession, which
        // would give equal initial weights and repeated mini-batches.
        private static readonly Random SharedRandom = new Random();

        /// <summary>
        /// Trains the model on the built-in samples, printing the batch cost every
        /// 100 epochs, then prints the fitted equation and waits for a line of input.
        /// </summary>
        /// <param name="args">Command-line arguments (not used).</param>
        static void Main(string[] args)
        {
            // Training samples: each row is (x1, x2); inputs_y holds the target for
            // the row with the same index.
            List<float[]> inputs_x = new List<float[]>
            {
                new float[] { 0.9f, 0.6f },
                new float[] { 2f, 2.5f },
                new float[] { 2.6f, 2.3f },
                new float[] { 2.7f, 1.9f }
            };

            List<float> inputs_y = new List<float> { 2.5f, 2.5f, 3.5f, 4.2f };

            // weights[0] is the bias b, weights[1] is theta1, weights[2] is theta2.
            float[] weights = new float[3];
            for (var i = 0; i < weights.Length; i++)
            {
                weights[i] = (float)SharedRandom.NextDouble();
            }

            int epoch = 30000;
            float epsilon = 0.00001f;
            float lr = 0.01f;
            int batchSize = 2;

            float lastCost = 0;

            for (var epoch_i = 0; epoch_i <= epoch; epoch_i++)
            {
                // Draw a random mini-batch of training samples.
                var batch = GetRandomBatch(inputs_x, inputs_y, batchSize);

                // Per-weight sums of (error * partial derivative) over the batch.
                float[] gradientSums = new float[weights.Length];

                foreach (var x_y in batch)
                {
                    float x1 = x_y.Item1[0];
                    float x2 = x_y.Item1[1];

                    // Error is target minus prediction, so adding lr * error * dy/dw
                    // below moves the weights downhill on the half squared error.
                    float diffWithTargetY = x_y.Item2 - fun(x1, x2, weights[1], weights[2], weights[0]);

                    gradientSums[0] += diffWithTargetY * dy_b(x1, x2);
                    gradientSums[1] += diffWithTargetY * dy_theta1(x1, x2);
                    gradientSums[2] += diffWithTargetY * dy_theta2(x1, x2);
                }

                for (var i = 0; i < weights.Length; i++)
                {
                    weights[i] += lr * gradientSums[i];
                }

                // Cost of the updated weights, measured on the same mini-batch.
                float cost = ComputeBatchCost(batch, weights);

                // Stop once the cost changes by no more than epsilon between epochs.
                // NOTE(review): the two costs come from different random batches, so
                // this criterion is noisy; evaluating on the full data set would be
                // steadier.
                if (Math.Abs(cost - lastCost) <= epsilon)
                {
                    Console.WriteLine("EPOCH {0}", epoch_i);
                    Console.WriteLine("LAST MSE {0}", lastCost);
                    Console.WriteLine("MSE {0}", cost);
                    break;
                }

                lastCost = cost;

                if (epoch_i % 100 == 0 || epoch_i == epoch)
                {
                    Console.WriteLine("MSE {0}", cost);
                }
            }

            print(weights[1], weights[2], weights[0]);

            Console.ReadLine();
        }

        /// <summary>
        /// Returns the mean of the half squared errors (error^2 / 2) of the model
        /// over the given batch.
        /// </summary>
        /// <param name="batch">Samples as (features, target) pairs; features are (x1, x2).</param>
        /// <param name="weights">Model weights laid out as [b, theta1, theta2].</param>
        /// <returns>The average half squared error; NaN when the batch is empty (0 / 0).</returns>
        private static float ComputeBatchCost(List<Tuple<float[], float>> batch, float[] weights)
        {
            float totalErrorCost = 0f;
            foreach (var x_y in batch)
            {
                float diff = x_y.Item2 - fun(x_y.Item1[0], x_y.Item1[1], weights[1], weights[2], weights[0]);
                totalErrorCost += diff * diff / 2f;
            }

            return totalErrorCost / batch.Count;
        }

        /// <summary>
        /// Picks <paramref name="maxCount"/> samples uniformly at random, with
        /// replacement, pairing each feature row with the target at the same index.
        /// </summary>
        /// <param name="inputs_x">Feature rows.</param>
        /// <param name="inputs_y">Targets; must have the same length as <paramref name="inputs_x"/>.</param>
        /// <param name="maxCount">Number of samples to draw; zero or less yields an empty list.</param>
        /// <returns>A new list holding the drawn (features, target) pairs.</returns>
        /// <exception cref="ArgumentNullException">An input list is null and samples were requested.</exception>
        /// <exception cref="ArgumentException">The lists differ in length, or are empty while samples were requested.</exception>
        private static List<Tuple<float[], float>> GetRandomBatch(List<float[]> inputs_x, List<float> inputs_y, int maxCount)
        {
            List<Tuple<float[], float>> lst = new List<Tuple<float[], float>>();

            // Nothing requested: return the empty list without touching the inputs.
            if (maxCount <= 0)
            {
                return lst;
            }

            if (inputs_x == null)
            {
                throw new ArgumentNullException("inputs_x");
            }
            if (inputs_y == null)
            {
                throw new ArgumentNullException("inputs_y");
            }
            if (inputs_x.Count != inputs_y.Count)
            {
                throw new ArgumentException("inputs_x and inputs_y must contain the same number of elements.");
            }
            if (inputs_x.Count == 0)
            {
                throw new ArgumentException("Cannot draw samples from an empty data set.", "inputs_x");
            }

            for (int count = 0; count < maxCount; count++)
            {
                int rndIndex = SharedRandom.Next(inputs_x.Count);
                lst.Add(Tuple.Create<float[], float>(inputs_x[rndIndex], inputs_y[rndIndex]));
            }

            return lst;
        }

        /// <summary>Writes the fitted model equation to the console.</summary>
        /// <param name="theta1">Coefficient of x1.</param>
        /// <param name="theta2">Coefficient of x2.</param>
        /// <param name="b">Constant (bias) term.</param>
        private static void print(float theta1, float theta2, float b)
        {
            Console.WriteLine("y={0}*x1+{1}*x2+{2}", theta1, theta2, b);
        }

        /// <summary>Evaluates the model: theta1 * x1 + theta2 * x2 + b.</summary>
        /// <param name="x1">First feature.</param>
        /// <param name="x2">Second feature.</param>
        /// <param name="theta1">Coefficient of x1.</param>
        /// <param name="theta2">Coefficient of x2.</param>
        /// <param name="b">Constant (bias) term.</param>
        /// <returns>The model output for the given features and parameters.</returns>
        private static float fun(float x1, float x2, float theta1, float theta2, float b)
        {
            return theta1 * x1 + theta2 * x2 + b;
        }

        /// <summary>Partial derivative of the model output with respect to theta1, which is x1.</summary>
        private static float dy_theta1(float x1, float x2)
        {
            return x1;
        }

        /// <summary>Partial derivative of the model output with respect to theta2, which is x2.</summary>
        private static float dy_theta2(float x1, float x2)
        {
            return x2;
        }

        /// <summary>Partial derivative of the model output with respect to the bias b, which is 1.</summary>
        private static float dy_b(float x1, float x2)
        {
            return 1;
        }
    }
}
相關文章
相關標籤/搜索