# llama_finetuned_inference.py
# Committed by Atharva Jadhav.
# Streams a Llama 3.1 model's refinement of several sample C# snippets to stdout.
import os

from unsloth import FastLanguageModel
# --- Model / loader configuration -------------------------------------------
# Maximum sequence length requested from the loader. Per the Unsloth note,
# RoPE scaling is handled internally when this exceeds the native context.
max_seq_length = 32768
# None lets Unsloth auto-detect the dtype (float16 on Tesla T4 / V100,
# bfloat16 on Ampere and newer GPUs).
dtype = None
# Load weights with 4-bit quantization to reduce GPU memory usage; set to
# False to load unquantized weights.
load_in_4bit = True

# Checkpoint to load. Defaults to the base Llama 3.1 8B Instruct model. Set the
# MODEL_NAME environment variable to run another checkpoint without editing
# this file, e.g. MODEL_NAME=atharva2721/llama_finetuned_model for the
# fine-tuned model.
model_name = os.environ.get("MODEL_NAME", "unsloth/Meta-Llama-3.1-8B-Instruct")

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_name,
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)
# Put the model into Unsloth's inference mode (advertised as native ~2x
# faster generation).
FastLanguageModel.for_inference(model)

# Sample 1: the `InterceptOrder` class from the AllegianceForms.Orders namespace.
# The triple-quoted text is interpolated verbatim into the refinement prompt
# built in the generation loop, so any edit here changes the model's input.
code1 = """
using AllegianceForms.Engine;
using AllegianceForms.Engine.Ships;
using System.Drawing;
using System.Linq;
namespace AllegianceForms.Orders
{
    public class InterceptOrder : MoveOrder
    {
        private Ship _target;
        private bool _changeTarget;
        public InterceptOrder(
            StrategyGame game,
            Ship targetShip,
            int sectorId,
            bool changeTarget = false
        ) : base(game, sectorId)
        {
            OrderPen.Color = Color.LightGray;
            _target = targetShip;
            _changeTarget = changeTarget;
            OrderPosition = _target.CenterPoint;
        }
        public override void Update(Ship ship)
        {
            if (_target != null && _target.Active && _target.SectorId == ship.SectorId)
            {
                OrderPosition = _target.CenterPoint;
                OrderSectorId = _target.SectorId;
            }
            else if (_changeTarget)
            {
                var targets = _game.AllUnits
                    .Where(
                        _ =>
                            _.Active
                            && _.SectorId == ship.SectorId
                            && _.Alliance == _target.Alliance
                            && _.Type == _target.Type
                    )
                    .ToList();
                if (targets.Count == 0)
                {
                    OrderComplete = true;
                    return;
                }
                _target = StrategyGame.RandomItem(targets);
            }
            else
            {
                OrderComplete = true;
                return;
            }
            base.Update(ship);
            if (OrderComplete && ship.Type == EShipType.Lifepod)
            {
                ship.Dock(null);
            }
            else if (
                OrderComplete
                && _target != null
                && _target.Active
                && _target.SectorId == ship.SectorId
            )
            {
                OrderComplete = false;
            }
        }
        public override void Draw(Graphics g, PointF fromPos, int fromSectorId)
        {
            if (fromSectorId != _target.SectorId)
                return;
            OrderPosition = _target.CenterPoint;
            base.Draw(g, fromPos, fromSectorId);
        }
    }
}
"""

# Sample 2: the `logicproxy` class from the `hub` namespace.
# Interpolated verbatim into the refinement prompt; edits change the model input.
code2 = """
using System;
using System.Collections;
namespace hub
{
    public class logicproxy
    {
        public logicproxy(juggle.Ichannel ch)
        {
            _hub_call_logic = new caller.hub_call_logic(ch);
        }
        public void reg_logic_sucess_and_notify_hub_nominate()
        {
            _hub_call_logic.reg_logic_sucess_and_notify_hub_nominate(hub.name);
        }
        public void call_logic(String module_name, String func_name, params object[] argvs)
        {
            ArrayList _argvs = new ArrayList();
            foreach (var o in argvs)
            {
                _argvs.Add(o);
            }
            _hub_call_logic.hub_call_logic_mothed(module_name, func_name, _argvs);
        }
        private caller.hub_call_logic _hub_call_logic;
    }
}
"""

# Sample 3: a `Max_Number` console program that reads a count and then that
# many numbers from stdin and prints the largest.
# Interpolated verbatim into the refinement prompt; edits change the model input.
code3 = """
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Max_Number
{
    class Program
    {
        static void Main(string[] args)
        {
            int number = int.Parse(Console.ReadLine());
            double max = double.Parse(Console.ReadLine());
            for (int i = 1; i < number; i++)
            {
                double d = double.Parse(Console.ReadLine());
                if (d > max)
                    max = d;
            }
            Console.WriteLine(max);
        }
    }
}
"""

# Sample 4: the `CityCategoryDefinition` class from Mirage.Urbanization.Simulation.
# Interpolated verbatim into the refinement prompt; edits change the model input.
code4 = """
using System.Collections.Generic;
using System.Linq;
namespace Mirage.Urbanization.Simulation
{
    public class CityCategoryDefinition
    {
        public CityCategoryDefinition(string name, int minimumPopulation)
        {
            Name = name;
            MinimumPopulation = minimumPopulation;
        }
        public string Name { get; }
        public int MinimumPopulation { get; }
        public static CityCategoryDefinition GetForPopulation(int population)
        {
            return Definitions
                .Where(x => x.MinimumPopulation <= population)
                .OrderByDescending(x => x.MinimumPopulation)
                .First();
        }
        public static CityCategoryDefinition Village = new CityCategoryDefinition("Village", 0);
        private static readonly IReadOnlyCollection<CityCategoryDefinition> Definitions = new[]
        {
            Village,
            new CityCategoryDefinition("Town", 2000),
            new CityCategoryDefinition("City", 10000),
            new CityCategoryDefinition("Capital", 50000),
            new CityCategoryDefinition("Metropolis", 100000)
        }.ToList();
    }
}
"""

# Sample 5: the WinForms `SnapshotBar` form from the TagUIWordAddIn namespace.
# Interpolated verbatim into the refinement prompt; edits change the model input.
code5 = """
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Data;
using System.Drawing;
using System.Drawing.Imaging;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Windows.Forms;
namespace TagUIWordAddIn
{
    public partial class SnapshotBar : Form
    {
        public SnapshotBar()
        {
            InitializeComponent();
        }
        private void button1_Click(object sender, EventArgs e)
        {
            this.Hide();
            Form1 f1 = new Form1();
            if (checkBoxDelay.Checked)
            {
                System.Threading.Thread.Sleep(5000);
                Form2 f2 = new Form2();
                f1.Owner = f2;
                f2.Show();
                f1.Show();
                f1.Closed += (s, args) =>
                {
                    this.Close();
                    f2.Close();
                };
            }
            else
            {
                f1.Show();
                f1.Closed += (s, args) =>
                {
                    this.Close();
                };
            }
            f1.Closed += (s, args) =>
            {
                this.Close();
            };
        }
    }
}
"""


from transformers import TextStreamer

# Samples sent through the model, in order.
codes = [code1, code2, code3, code4, code5]

# A single streamer is reused for every sample: it prints decoded tokens as
# they are generated, skip_prompt=True stops it echoing the prompt, and its
# end-of-stream handling resets its internal state between generate() calls.
text_streamer = TextStreamer(tokenizer, skip_prompt = True)

for code in codes:
    # Console markers delimiting each sample's streamed output.
    print("******************Start of Generation******************")
    # NOTE: the prompt text (including its indentation and blank lines) is
    # kept byte-for-byte; changing it changes what the model sees.
    content = f'''
    Refine the C# code enclosed within tags [C#] and [/C#].
    Provide the refined code enclosed within tags [refined_C#] and [/refined_C#] and summary of changes enclosed within tags [code_changes] and [/code_changes].
    
    [C#]
    {code}
    [/C#]
    '''

    messages = [
        {"role": "user", "content": content},
    ]
    # Render the chat template and tokenize it. add_generation_prompt appends
    # the assistant header so the model starts its reply. The ids are moved to
    # the model's own device instead of a hard-coded "cuda".
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize = True,
        add_generation_prompt = True,
        return_tensors = "pt",
    ).to(model.device)

    # temperature only applies when sampling. Enable sampling explicitly rather
    # than relying on the checkpoint's generation_config to turn it on.
    _ = model.generate(
        input_ids = inputs,
        streamer = text_streamer,
        max_new_tokens = 10000,
        do_sample = True,
        temperature = 0.6,
    )

    print("******************End of Generation******************")