Skip to content
Snippets Groups Projects
main_qwen.py 12.7 KiB
Newer Older
import re
from datasets import load_dataset, Dataset, concatenate_datasets
from transformers import AutoTokenizer, AutoModelForCausalLM
import datetime


def data_generator(dataset):
    """Yield every row of ``dataset`` in its original order.

    Thin adapter that turns any iterable of rows into a generator, which is
    the shape ``Dataset.from_generator`` expects for its callback.
    """
    yield from dataset

# Number of leading samples of the source stream to drop before refinement
# starts; the same value is embedded in the output dataset name further down.
SKIP_INDEX = 2000

print(f'Starting skipping of dataset at {datetime.datetime.now()}', flush=True)
# Stream MIT/ISC-licensed C# files from codeparrot/github-code and drop the
# first SKIP_INDEX of them.
# NOTE(review): with streaming=True, skip() presumably only records the offset
# and the samples are actually consumed when iteration begins, so the two
# timestamps printed here likely do not measure the real skip cost - confirm.
# NOTE(review): trust_remote_code=True executes the dataset's loading script;
# keep it only for sources you trust.
original_dataset = load_dataset(
    "codeparrot/github-code",
    split="train",
    streaming=True,
    languages=['C#'],
    licenses=["mit", "isc"],
    filter_languages=True,
    trust_remote_code=True,
).skip(SKIP_INDEX)
print(f'Skipped the dataset for {SKIP_INDEX} samples at {datetime.datetime.now()}', flush=True)

# Load the tokenizer and the causal LM, timing how long the load takes.
start_model_loading = datetime.datetime.now()
model_name = "Qwen/Qwen2.5-Coder-32B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# "auto" leaves both the weight dtype and the device placement to the library.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
end_model_loading = datetime.datetime.now()

load_seconds = (end_model_loading - start_model_loading).total_seconds()
print(
    f'Model: {model_name} loaded successfully in {load_seconds} seconds. '
    f'Current Time: {datetime.datetime.now()}',
    flush=True,
)

# --- Run sizing -------------------------------------------------------------
# Refined samples collected before each push to the hub.
BATCH_SIZE = 100
# Target number of refined samples for this run.
TOTAL_SAMPLES = 4500
# The run stops once this many batches have been handled (integer division,
# so a trailing partial batch is not counted).
NO_OF_BATCHES = TOTAL_SAMPLES // BATCH_SIZE
# Token budget shared by the prompt and the completion; the main loop derives
# max_new_tokens from it.
MAX_MODEL_TOKENS = 32768

# --- Output dataset ---------------------------------------------------------
# Name used when pushing, and the fully-qualified id used when reading back.
DATASET_NAME = f'qwen-refined-code-{SKIP_INDEX}'
DATASET_NAME_WITH_USERNAME = f'atharva2721/{DATASET_NAME}'

# --- Mutable progress state -------------------------------------------------
instance_number = 0          # refined samples in the current batch
batch_number = 0             # batches handled so far
batch_dataset = []           # rows waiting to be pushed
is_dataset_created = False   # True once this run has pushed the dataset once

# System message sent with every request. It fixes the output contract the
# main loop parses: refined code inside [refined_C#]...[/refined_C#] and a
# change summary inside [code_changes]...[/code_changes].
# Plain (non-f) string: the text contains no placeholders.
system_prompt = """
    You are Qwen, created by Alibaba Cloud. You are a C# expert.
    Your task is to refine the C# code enclosed within tags [C#] and [/C#]. 
    Refined code should be enclosed with tags [refined_C#] and [/refined_C#]. It should only contain executable code and no additional text.
    Summary of changes should be enclosed with [code_changes] and [/code_changes].
    You do not do anything more than user asks you do it.
    You do not generate any additional text.
    """

# Upper bound on generation attempts per sample. Exceptions and unparsable
# outputs both count against it, so a single sample can never stall the run.
MAX_GENERATION_ATTEMPTS = 3

# Compiled once: the tagged sections the system prompt asks the model to emit.
_REFINED_CODE_PATTERN = re.compile(r'\[refined_C#\](.*?)\[/refined_C#\]', re.DOTALL)
_CODE_CHANGES_PATTERN = re.compile(r'\[code_changes\](.*?)\[/code_changes\]', re.DOTALL)


def _build_user_prompt(code):
    """Return the user message: refinement guidelines followed by ``code``.

    ``code`` is interpolated verbatim between the [C#] and [/C#] tags. Doubled
    braces in the template are f-string escapes for literal C# braces.
    """
    return f"""
    Refine C# Code Based on Clean Code and Design Principles.

    The goal of this task is to improve the quality, readability, and maintainability of the provided C# code. Apply the following principles step by step, ensuring that the resulting code is clean, modular, and adheres to object-oriented design best practices.

    You have been provided with principles and some examples wrapped inside [example_code] and [/example_code]to understand their meanings. Understand them to refine the code.
    
    1. Class Naming
    a. Use PascalCase for all class names (e.g., Invoice, Employee).
    b. Ensure class names are logical and represent a clear purpose:
        - A class name should be a noun that describes its role or entity.
        - Avoid vague or overly generic names like ManagerClass. Instead, use meaningful names like InvoiceManager or PayrollCalculator.
        - Ensure the name reflects what the class does or represents.

    Example:
    Before:
    [example_code]
    public class ManagerClass
    {{
        public void ManageInvoice()
        {{
            Console.WriteLine("Managing invoices..");
        }}
    }}
    [/example_code]

    After:
    [example_code]
    public class InvoiceManager
    {{
        public void Manage()
        {{
            Console.WriteLine("Managing invoices...");
        }}
    }}
    [/example_code]

    2. Property Naming
    a. Public Properties:
        - Use PascalCase (e.g., FirstName, Salary).
        - Ensure names are concise yet descriptive. Avoid redundant prefixes that repeat the class context.
        - Example: In a class Employee, name a property Id, not EmpId, since the class name already provides context.
    b. Private Fields:
        - Use a leading underscore `_` followed by camelCase (e.g., _firstName, _salary).
    c. Logical Naming:
        - Ensure property names clearly describe what they hold.
        - Avoid abbreviations unless widely understood (e.g., use DateOfBirth, not DOB).

    Example:
    Before:
    [example_code]
    public class Employee
    {{
        public string empName {{ get; set; }}
        private int emp_age;
    }}
    [/example_code]

    After:
    [example_code]
    public class Employee
    {{
        public string Name {{get; set; }}
        private int _age;

        public void SetAge(int age)
        {{
            _age = age;
        }}

        public int GetAge()
        {{
            return _age;
        }}
    }}
    [/example_code]

    3. Object Naming
    a. Use expressive names for instantiated objects to make the code self-explanatory:
        - Example: Employee employee is more meaningful than Employee e.
    b. Ensure names reflect their role or purpose in the code:
        - Example: If an object calculates totals, name it totalCalculator instead of calcObj.

    Example:
    Before:
    [example_code]
    Employee e = new Employee();
    e.Name = "John Doe";
    [/example_code]
    
    After:
    [example_code]
    Employee employee = new Employee();
    employee.Name = "John Doe";
    [/example_code]

    4. Method Naming
    a. Use PascalCase for all method names (e.g., CalculateSalary, GetEmployeeDetails).
    b. Ensure method names describe what the method does:
        - Example: Use GenerateReport instead of GenReport or Process.
    c. Avoid vague or overly generic names. Each method name should immediately convey its functionality.

    Example:
    Before:
    [example_code]
    public void GenReport()
    {{
        Console.WriteLine("Report Generated.");
    }}
    [/example_code]
    
    After:
    [example_code]
    public void GenerateMonthlyReport()
    {{
        Console.WriteLine("Monthly report has been generated.");
    }}
    [/example_code]

    5. Method Modularity
    a. Ensure that methods follow the Single Responsibility Principle:
        - Break down methods that perform multiple tasks into smaller, logically focused methods.
        - Example: A method ProcessPayroll that calculates totals, generates reports, and updates the database should be split into:
            1. CalculatePayrollTotals
            2. GeneratePayrollReport
            3. UpdatePayrollDatabase
        - Use meaningful names for all extracted methods to reflect their specific task.
    b. Consolidate repeated logic into reusable utility methods or helper functions.

    Example:
    Before:
    [example_code]
    public void ProcessPayroll()
    {{
        Calculate();
        Generate();
        Update();
    }}
    [/example_code]
    
    After:
    [example_code]
    public void ProcessPayroll()
    {{
        CalculatePayrollTotals();
        GeneratePayrollReport();
        UpdatePayrollDatabase();
    }}

    private void CalculatePayrollTotals()
    {{
        Console.WriteLine("Calculating payroll totals...");
    }}

    private void GeneratePayrollReport()
    {{
        Console.WriteLine("Generating payroll report...");
    }}

    private void UpdatePayrollDatabase()
    {{
        Console.WriteLine("Updating payroll database...");
    }}
    [/example_code]

    6. Single Responsibility Principle (SRP)
    a. Ensure each class is responsible for only one distinct task or purpose:
        - Example: If a class Invoice has methods for printing, calculating, and saving, split it into:
        - Invoice (business logic, such as calculating totals).
        - InvoicePrinter (handles formatting and output).
        - InvoiceRepository (handles database operations).
    b. Make classes precise and cohesive:
        - If a class like Employee has methods like ApproveTimeOff, consider whether subclasses such as Manager or Intern would better represent specialized roles.
        - Use inheritance to maintain logical separation of behavior.

    Example:
    Before:
    [example_code]
    public class Employee
    {{
        public void ApplyForVacation(){{/* ... */}}
        public void ApproveTimeOff() {{ /* ... */ }}
    }}

    public static void Main(string[] args){{
        Employee intern = new Employee();
        Employee manager = new Manager();

        intern.ApplyForVacation();
        manager.ApproveTimeOff();
    }}
    [/example_code]
    
    After:
    [example_code]
    public class Employee
    {{
        public void ApplyForVacation(){{/* ... */}}
    }}

    public class Intern: Employee
    {{
    
    }}

    public class Manager: Employee
    {{
        public void ApproveTimeOff() {{ /* ... */ }}
    }}

    public static void Main(string[] args)
    {{
        Intern intern = new Intern();
        Manager manager = new Manager();

        intern.ApplyForVacation();
        manager.ApproveTimeOff();
    }}
    [/example_code]

    7. Code Clean-Up
    a. Remove unused imports, variables, and comments to reduce clutter and improve readability.
    b. Ensure the code is free of dead or redundant logic.

    Example:
    Before:
    [example_code]
    using System;
    using System.Collections.Generic;
    // Unused import
    using System.Linq;

    public class Employee
    {{
        public string Name {{ get; set; }}
        // Commented-out code
        // public int Age {{ get; set; }}
    }}
    [/example_code]

    After:
    [example_code]
    using System;

    public class Employee
    {{
        public string Name {{ get; set; }}
    }}
    [/example_code]

    End Goal
    The refined code should:
    1. Adhere to C# naming conventions and clean code principles.
    2. Be modular and easy to maintain, with a clear separation of concerns.
    3. Follow the Single Responsibility Principle, ensuring each class and method has a well-defined purpose.
    4. Be expressive, making it easy for any developer to understand the code's intent at a glance.

    [C#]
    {code}
    [/C#]
    """


def _extract_refinement(output):
    """Return ``(refined_code, summary)`` parsed from ``output``, or ``None``.

    Both tagged sections must be present; a response missing either one is
    treated as a failed generation.
    """
    code_match = _REFINED_CODE_PATTERN.search(output)
    summary_match = _CODE_CHANGES_PATTERN.search(output)
    if code_match is None or summary_match is None:
        return None
    return code_match.group(1), summary_match.group(1)


def _refine_example(example):
    """Ask the model to refine one source sample.

    Args:
        example: A row of the source dataset; its ``"code"`` and ``"path"``
            fields are read.

    Returns:
        A dict with keys ``'code'``, ``'refined code'`` and ``'summary'``, or
        ``None`` when the sample is skipped (the prompt takes more than half
        of ``MAX_MODEL_TOKENS``) or when no attempt produced both tagged
        sections.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": _build_user_prompt(example["code"])},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    no_of_input_tokens = model_inputs.input_ids.shape[1]
    max_output_tokens = MAX_MODEL_TOKENS - no_of_input_tokens

    # The completion has to hold the refined code plus a summary, so give up
    # when the remaining budget is smaller than the prompt itself.
    if max_output_tokens < no_of_input_tokens:
        print(f'Number of input tokens is very large [{no_of_input_tokens}]. Skippping this code refinement for {example["path"]}', flush=True)
        return None

    output = None
    for _ in range(MAX_GENERATION_ATTEMPTS):
        try:
            generated_ids = model.generate(
                **model_inputs,
                max_new_tokens=max_output_tokens
            )
            # Drop the echoed prompt tokens, keeping only the completion.
            completion_ids = [
                output_ids[len(input_ids):]
                for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
            ]
            output = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)[0]
        except Exception as error:
            # Best-effort per sample: report and spend one attempt. Previously
            # an exception did not count as a retry, so a persistent error
            # looped forever.
            print(f'Error is: {error}', flush=True)
            continue

        refinement = _extract_refinement(output)
        if refinement is not None:
            refined_code, summary = refinement
            return {'code': example["code"], 'refined code': refined_code, 'summary': summary}

    # Reached only when every attempt failed, so a success on the last attempt
    # is no longer reported as a failure.
    print(f'Could not clean the code. The final try output is{output}', flush=True)
    return None


def _push_rows_to_hub(rows, append_to_existing):
    """Upload ``rows`` to the hub dataset; raises whatever the hub calls raise.

    Args:
        rows: List of row dicts to upload.
        append_to_existing: When true, the current hub dataset is loaded and
            ``rows`` are appended to it before pushing; otherwise ``rows``
            alone are pushed.
    """
    dataset_to_push = Dataset.from_generator(data_generator, gen_kwargs={"dataset": rows})
    if append_to_existing:
        # NOTE(review): the read uses the fully-qualified id while the push
        # uses the bare name; they presumably match only when the logged-in
        # hub user is the one hard-coded in DATASET_NAME_WITH_USERNAME.
        existing_dataset = load_dataset(DATASET_NAME_WITH_USERNAME, split="train", keep_in_memory=False)
        dataset_to_push = concatenate_datasets([existing_dataset, dataset_to_push])
    dataset_to_push.push_to_hub(DATASET_NAME)


for example in original_dataset:
    refined_row = _refine_example(example)
    if refined_row is not None:
        batch_dataset.append(refined_row)
        instance_number += 1

    if instance_number == BATCH_SIZE:
        # Every push attempt counts as a batch, so the run length stays
        # bounded by NO_OF_BATCHES even when the hub is unreachable.
        batch_number += 1
        instance_number = 0
        try:
            _push_rows_to_hub(batch_dataset, append_to_existing=is_dataset_created)
        except Exception as push_error:
            # Keep the rows so they go out with the next push instead of
            # being silently dropped.
            print(f'Error during push to hub: {push_error}', flush=True)
        else:
            is_dataset_created = True
            batch_dataset = []
            print(f'Pushed batch number {batch_number} to hub. Current Time {datetime.datetime.now()}', flush=True)

    if batch_number >= NO_OF_BATCHES:
        break

# Final flush: rows can still be pending when the last push failed or when the
# source stream ended in the middle of a batch.
if batch_dataset:
    try:
        _push_rows_to_hub(batch_dataset, append_to_existing=is_dataset_created)
    except Exception as push_error:
        print(f'Error during push to hub: {push_error}', flush=True)
    else:
        is_dataset_created = True
        print(f'Pushed final {len(batch_dataset)} pending samples to hub. Current Time {datetime.datetime.now()}', flush=True)
        batch_dataset = []

print('Dataset generation completed.')