﻿// -----------------------------------------------------------------------
// <copyright file="CreateFirewall.cs" company="Microsoft IT">
// TODO: Update copyright text.
// </copyright>
// -----------------------------------------------------------------------

namespace Microsoft.Cis.Security.Tools.FirewallBenchmarks
{
    using System;
    using System.Collections.Generic;
    using System.Collections.Specialized;
    using System.Configuration;
    using System.IO;
    using System.Linq;
    using System.Net;
    using System.Text;

    /// <summary>
    /// Selects which rule(s) <c>CreateFirewall.CreateRule</c> produces.
    /// </summary>
    public enum RuleMode
    {
        /// <summary>An allow rule followed by the block rules derived from its remote range.</summary>
        ALLOW_BLOCK = 0,

        /// <summary>A single allow rule.</summary>
        ALLOW = 1,

        /// <summary>A single block rule.</summary>
        BLOCK = 2,
    }

    /// <summary>
    /// Strategy used by <c>CreateFirewall.CreateRules</c> to fill the rule set.
    /// </summary>
    public enum GeneratioMode
    {
        /// <summary>Each step adds one allow or one block rule, chosen by a coin toss.</summary>
        RANDOM = 0,

        /// <summary>Each step adds an allow rule plus block rules inside its remote range.</summary>
        REALISTIC = 1,
    }

    /// <summary>
    /// Generates synthetic firewall rules and firewall queries for benchmarking and
    /// writes them to text files. Address, port and sizing parameters are read from
    /// the application's appSettings configuration section in the constructor.
    /// </summary>
    public class CreateFirewall
    {
        // Sizing and range parameters; all assigned once in the constructor.
        private readonly UInt32 NumberOfRules;
        private readonly UInt32 NumberOfQueries;
        private readonly UInt32 BegAllowIPRange;
        private readonly UInt32 EndAllowIPRange;
        private readonly UInt32 BegPortRange;
        private readonly UInt32 EndPortRange;
        private readonly UInt32 IPRangeLen;
        private readonly UInt32 PortRangeLen;
        private readonly UInt32 PercentRangeLen;
        private readonly UInt32 BlockPerAllow;

        // Local address shared by every generated rule, picked once per instance.
        private readonly string LocalIP;

        private readonly string[] protocols;
        private readonly string[] directions;

        // Generated output; appended to by the Create* methods, read by the Write* methods.
        private readonly List<FirewallRule> Rules;
        private readonly List<FirewallQuery> Queries;

        /// <summary>
        /// Initializes a generator that targets <paramref name="numberOfRules"/> rules.
        /// </summary>
        /// <param name="numberOfRules">Rule count that <see cref="CreateRules"/> generates up to.</param>
        /// <remarks>
        /// Reads the appSettings keys NoOfQueries, BegAllowIPRange, EndAllowIPRange,
        /// BegPortRange, EndPortRange, IPRangeLen, PortRangeLen, PercentRangeLen and
        /// BlockPerAllow. A missing or malformed numeric value makes the corresponding
        /// Parse call throw. BegPortRange and EndPortRange are parsed as 16-bit values.
        /// </remarks>
        public CreateFirewall(UInt32 numberOfRules)
        {
            NameValueCollection appsettings = ConfigurationManager.AppSettings;
            this.protocols = new string[] { "TCP", "UDP" };
            this.directions = new string[] { "IN", "OUT" };
            this.NumberOfRules = numberOfRules;
            this.NumberOfQueries = UInt32.Parse(appsettings["NoOfQueries"]);
            this.BegAllowIPRange = Conv.IP2Num(appsettings["BegAllowIPRange"]);
            this.EndAllowIPRange = Conv.IP2Num(appsettings["EndAllowIPRange"]);
            this.BegPortRange = UInt16.Parse(appsettings["BegPortRange"]);
            this.EndPortRange = UInt16.Parse(appsettings["EndPortRange"]);
            this.IPRangeLen = uint.Parse(appsettings["IPRangeLen"]);
            this.PortRangeLen = uint.Parse(appsettings["PortRangeLen"]);
            this.PercentRangeLen = uint.Parse(appsettings["PercentRangeLen"]);
            this.BlockPerAllow = uint.Parse(appsettings["BlockPerAllow"]);
            this.Rules = new List<FirewallRule>();
            this.Queries = new List<FirewallQuery>();
            this.LocalIP = Utility.ChooseRandomIP(this.BegAllowIPRange, this.EndAllowIPRange);
        }

        /// <summary>
        /// Appends NoOfQueries queries. Each query has a random direction and a random
        /// remote IP range inside the allow range. The range length is PercentRangeLen
        /// percent of IPRangeLen. All other query fields are "*".
        /// </summary>
        public void CreateQueries()
        {
            // Loop-invariant: every query uses the same range length.
            UInt32 rangeLen = ScaleByPercent(this.IPRangeLen, this.PercentRangeLen);

            for (UInt32 i = 0; i < this.NumberOfQueries; i++)
            {
                string iprange = Utility.ChooseRandomIP(this.BegAllowIPRange, this.EndAllowIPRange, rangeLen);
                string dir = Utility.ChooseRandomStr(directions);
                this.Queries.Add(new FirewallQuery(dir, "*", "*", iprange, "*", "*"));
            }
        }

        /// <summary>
        /// Appends rules until the rule list holds at least the number of rules
        /// passed to the constructor.
        /// </summary>
        /// <param name="mode">
        /// <see cref="GeneratioMode.RANDOM"/> adds one allow rule or one block rule per step,
        /// chosen by a coin toss. <see cref="GeneratioMode.REALISTIC"/> adds one allow rule
        /// plus BlockPerAllow block rules per step, so the final count can exceed the
        /// target by up to BlockPerAllow rules.
        /// </param>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="mode"/> is not a defined value and rules still need to be generated.
        /// Without this check the loop would never terminate.
        /// </exception>
        public void CreateRules(GeneratioMode mode)
        {
            while (Rules.Count < this.NumberOfRules)
            {
                switch (mode)
                {
                    case GeneratioMode.RANDOM:
                        {
                            if (Utility.CoinToss())
                            {
                                this.Rules.Add(CreateAllowRule(false));
                            }
                            else
                            {
                                this.Rules.Add(CreateBlockRule(this.PercentRangeLen));
                            }

                            break;
                        }
                    case GeneratioMode.REALISTIC:
                        {
                            this.Rules.AddRange(CreateAllowBlockRule(this.BlockPerAllow, this.PercentRangeLen));
                            break;
                        }
                    default:
                        throw new ArgumentOutOfRangeException("mode", mode, "Unsupported generation mode.");
                }
            }
        }

        /// <summary>
        /// Builds the rule(s) for a single <paramref name="ruleMode"/> without adding
        /// them to this instance's rule list.
        /// </summary>
        /// <param name="ruleMode">Which kind of rule(s) to build.</param>
        /// <returns>
        /// The new rules. An unrecognized <paramref name="ruleMode"/> yields an empty list.
        /// </returns>
        public List<FirewallRule> CreateRule(RuleMode ruleMode)
        {
            List<FirewallRule> rules = new List<FirewallRule>();

            switch (ruleMode)
            {
                case RuleMode.ALLOW_BLOCK:
                    rules.AddRange(CreateAllowBlockRule(this.BlockPerAllow, this.PercentRangeLen));
                    break;
                case RuleMode.ALLOW:
                    rules.Add(CreateAllowRule(false));
                    break;
                case RuleMode.BLOCK:
                    rules.Add(CreateBlockRule(this.PercentRangeLen));
                    break;
            }

            return rules;
        }

        /// <summary>
        /// Builds one wildcard-protocol allow rule, followed by
        /// <paramref name="noBlockRules"/> block rules.
        /// </summary>
        /// <param name="noBlockRules">Number of block rules to generate.</param>
        /// <param name="percentLen">Block range length as a percentage of IPRangeLen and PortRangeLen.</param>
        /// <returns>The allow rule first, then the block rules.</returns>
        /// <remarks>
        /// Each block rule draws its remote range from inside the allow rule's remote range
        /// and uses the allow rule's direction.
        /// </remarks>
        public List<FirewallRule> CreateAllowBlockRule(uint noBlockRules, uint percentLen)
        {
            List<FirewallRule> rules = new List<FirewallRule>();
            FirewallRule allowRule = CreateAllowRule(true);
            string[] remoteIPRange = SplitIPRange(allowRule);

            // Loop-invariant bounds of the allow rule's remote range.
            UInt32 begRemoteIP = Conv.IP2Num(remoteIPRange[0]);
            UInt32 endRemoteIP = Conv.IP2Num(remoteIPRange[1]);

            rules.Add(allowRule);

            for (uint i = 0; i < noBlockRules; i++)
            {
                rules.Add(CreateBlockRule(begRemoteIP, endRemoteIP, percentLen, allowRule.Direction));
            }

            return rules;
        }

        /// <summary>
        /// Builds an "ALLOW" rule with a GUID name, a random direction, and a random remote
        /// IP range of length IPRangeLen inside the allow range. The remote port is "*".
        /// </summary>
        /// <param name="wildcardProtocol">If true the protocol is "*"; otherwise TCP or UDP is chosen at random.</param>
        public FirewallRule CreateAllowRule(bool wildcardProtocol)
        {
            string name = Guid.NewGuid().ToString();
            string remoteIP = Utility.ChooseRandomIP(this.BegAllowIPRange, this.EndAllowIPRange, this.IPRangeLen);
            string dir = Utility.ChooseRandomStr(directions);
            string protocol = wildcardProtocol ? "*" : Utility.ChooseRandomStr(protocols);

            return new FirewallRule(name, "ALLOW", dir, this.LocalIP, "*", remoteIP, "*", protocol);
        }

        /// <summary>
        /// Builds a "BLOCK" rule inside the configured allow range. The direction and
        /// protocol are random.
        /// </summary>
        /// <param name="percentRangeLen">Range length as a percentage of IPRangeLen (IP) and PortRangeLen (port).</param>
        public FirewallRule CreateBlockRule(UInt32 percentRangeLen)
        {
            string name = Guid.NewGuid().ToString();
            UInt32 rangeLenIP = ScaleByPercent(this.IPRangeLen, percentRangeLen);
            UInt32 rangeLenPort = ScaleByPercent(this.PortRangeLen, percentRangeLen);
            string remoteIP = Utility.ChooseRandomIP(this.BegAllowIPRange, this.EndAllowIPRange, rangeLenIP);
            // NOTE(review): the port range length is truncated to 16 bits by this cast.
            string remotePort = Utility.ChooseRandomPort((ushort)this.BegPortRange, (ushort)this.EndPortRange, (ushort)rangeLenPort);
            string protocol = Utility.ChooseRandomStr(protocols);
            string dir = Utility.ChooseRandomStr(directions);

            return new FirewallRule(name, "BLOCK", dir, this.LocalIP, "*", remoteIP, remotePort, protocol);
        }

        /// <summary>
        /// Builds a "BLOCK" rule whose remote IP range lies between
        /// <paramref name="begIPRange"/> and <paramref name="endIPRange"/>. The protocol is random.
        /// </summary>
        /// <param name="begIPRange">Lower bound (numeric IP) the remote range is drawn from.</param>
        /// <param name="endIPRange">Upper bound (numeric IP) the remote range is drawn from.</param>
        /// <param name="percentRangeLen">Range length as a percentage of IPRangeLen (IP) and PortRangeLen (port).</param>
        /// <param name="dir">Direction to use for the rule, passed through unchanged.</param>
        public FirewallRule CreateBlockRule(UInt32 begIPRange, UInt32 endIPRange, UInt32 percentRangeLen, string dir)
        {
            string name = Guid.NewGuid().ToString();
            UInt32 rangeLenIP = ScaleByPercent(this.IPRangeLen, percentRangeLen);
            UInt32 rangeLenPort = ScaleByPercent(this.PortRangeLen, percentRangeLen);
            string remoteIP = Utility.ChooseRandomIP(begIPRange, endIPRange, rangeLenIP);
            // NOTE(review): the port range length is truncated to 16 bits by this cast.
            string remotePort = Utility.ChooseRandomPort((ushort)this.BegPortRange, (ushort)this.EndPortRange, (ushort)rangeLenPort);
            string protocol = Utility.ChooseRandomStr(protocols);

            return new FirewallRule(name, "BLOCK", dir, this.LocalIP, "*", remoteIP, remotePort, protocol);
        }

        /// <summary>
        /// Writes <see cref="FirewallQuery.Header"/> followed by one line per generated query.
        /// </summary>
        /// <param name="filename">Destination file; overwritten if it exists.</param>
        public void WriteQueries(string filename)
        {
            List<string> queries = new List<string>();
            queries.Add(FirewallQuery.Header());

            foreach (FirewallQuery query in this.Queries)
            {
                queries.Add(query.ToString());
            }

            this.Write(filename, queries.ToArray());
        }

        /// <summary>
        /// Writes <see cref="FirewallRule.Header"/> followed by one line per generated rule.
        /// </summary>
        /// <param name="filename">Destination file; overwritten if it exists.</param>
        public void WriteFirewall(string filename)
        {
            List<string> rules = new List<string>();
            rules.Add(FirewallRule.Header());

            foreach (FirewallRule rule in this.Rules)
            {
                rules.Add(rule.ToString());
            }

            this.Write(filename, rules.ToArray());
        }

        /// <summary>
        /// Writes each element of <paramref name="lines"/> as a line of
        /// <paramref name="filename"/>, overwriting any existing file.
        /// </summary>
        /// <param name="filename">Destination file path.</param>
        /// <param name="lines">Lines to write, in order.</param>
        public void Write(string filename, string[] lines)
        {
            // 'using' closes the file even if a write throws; the original leaked the handle in that case.
            using (StreamWriter writer = new StreamWriter(filename))
            {
                foreach (string line in lines)
                {
                    writer.WriteLine(line);
                }
            }
        }

        /// <summary>
        /// Splits the remote address of <paramref name="rule"/> into its lower and upper part.
        /// </summary>
        /// <param name="rule">Rule whose <c>RemoteAddress</c> is either "a-b" or a single address.</param>
        /// <returns>
        /// A two-element array {lower, upper}. An address without '-' is treated as a
        /// one-address range, so both elements are that address. Callers index [0] and [1]
        /// unconditionally; previously this case returned null.
        /// </returns>
        private string[] SplitIPRange(FirewallRule rule)
        {
            string address = rule.RemoteAddress;

            if (!address.Contains("-"))
            {
                return new string[] { address, address };
            }

            return address.Split('-');
        }

        /// <summary>
        /// Returns <paramref name="percent"/> percent of <paramref name="value"/>, rounded down.
        /// </summary>
        /// <remarks>
        /// The product is computed in 64 bits so it cannot wrap before the division, as a
        /// 32-bit multiplication of a large range length would. Results above
        /// UInt32.MaxValue (possible only when percent exceeds 100) are truncated by the cast.
        /// </remarks>
        private static UInt32 ScaleByPercent(UInt32 value, UInt32 percent)
        {
            return (UInt32)((UInt64)value * percent / 100);
        }
    }
}
