# llama_finetuned_inference.py
# Run the fine-tuned Llama model over the CodeBud test dataset and push the results to the Hugging Face Hub.
# Author: Atharva Jadhav
from datasets import load_dataset, Dataset
import re
from unsloth import FastLanguageModel

def data_generator(dataset):
    """Yield the rows of *dataset* one by one.

    Serves as the generator function handed to ``Dataset.from_generator``.

    Args:
        dataset: Any iterable of rows (here a list of dicts).

    Yields:
        Each row of *dataset*, unchanged and in order.
    """
    yield from dataset

# Loader configuration for Unsloth's FastLanguageModel.
max_seq_length = 32768  # Max tokens per sequence (prompt + generation); Unsloth handles RoPE scaling internally.
dtype = None  # None = auto-detect; float16 for Tesla T4/V100, bfloat16 for Ampere and newer GPUs.
load_in_4bit = True  # 4-bit quantization to reduce GPU memory usage; can be set to False.

# Load the fine-tuned checkpoint. To run the base (non-fine-tuned) model instead,
# use model_name = "unsloth/Meta-Llama-3.1-8B-Instruct".
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "atharva2721/llama_finetuned_model",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)
FastLanguageModel.for_inference(model)  # Switch Unsloth into its native (advertised ~2x faster) inference mode.

reference_dataset = load_dataset("atharva2721/codebud-test-dataset", split="train", trust_remote_code=True)

# Number of generation attempts per example before it is counted as failed.
max_attempts = 3
# Tag patterns the prompt asks the model to emit; compiled once instead of on every attempt.
code_pattern = re.compile(r'\[refined_C#\](.*?)\[/refined_C#\]', re.DOTALL)
summary_pattern = re.compile(r'\[code_changes\](.*?)\[/code_changes\]', re.DOTALL)

inference_output = []  # rows for successfully refined examples
code_no = 0  # examples processed so far
inferred_no = 0  # examples refined successfully
failed_no = 0  # examples that failed every attempt

for example in reference_dataset:
    code_no += 1

    # NOTE(review): presumably this prompt (including its indentation) must match the
    # fine-tuning prompt verbatim; confirm against the training script before editing.
    content = f'''
    Refine the C# code enclosed within tags [C#] and [/C#]. 
    Provide the refined code enclosed within tags [refined_C#] and [/refined_C#]
    The summary of changes must be enclosed within tags [code_changes] and [/code_changes].
    
    [C#]
    {example["code"]}
    [/C#]
    '''
    messages = [
        {"role": "user", "content": content},
    ]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize = True,
        add_generation_prompt = True, # Must add for generation
        return_tensors = "pt",
    ).to("cuda")
    prompt_len = inputs.shape[-1]  # generated tokens start after the prompt

    output = ""
    for retry_no in range(max_attempts):
        print(f'Trying {code_no} for {retry_no} time', flush=True)

        # Request sampling explicitly: temperature only applies when do_sample is on, and
        # under greedy decoding every retry would regenerate identical text. Don't rely on
        # the checkpoint's generation_config to enable it.
        output_tensor = model.generate(
            input_ids = inputs,
            max_length = max_seq_length,
            do_sample = True,
            temperature = 0.6,
            repetition_penalty = 1.1,
        )

        # Decode only the newly generated tokens (the assistant reply). This keeps `output`
        # a string and does not depend on the assistant header appearing exactly once.
        output = tokenizer.decode(output_tensor[0][prompt_len:])
        print(output)

        code_matches = code_pattern.search(output)
        summary_matches = summary_pattern.search(output)
        if code_matches and summary_matches:
            inference_output.append({
                'code': example["code"],
                'finetuned inference': code_matches.group(1),
                'finetuned summary': summary_matches.group(1),
                'reference inference': example["refined code"],
                'reference summary': example["summary"],
            })
            print(f'Code no. {code_no} refined successfully', flush=True)
            inferred_no += 1
            break
    else:
        # Only reached when no attempt produced both tagged sections. A success on the last
        # attempt breaks out above, so it is never also counted as a failure.
        print(f'Failed to refine code at {code_no}. Final try output: \n [failed_output]{output}[/failed_output]', flush=True)
        failed_no += 1

# Upload the successful refinements as a Hub dataset, then report totals.
# `datasets` needs at least one example (or explicit features) to infer a schema, so an
# all-failed run skips the upload instead of crashing before the summary is printed.
if inference_output:
    new_dataset = Dataset.from_generator(data_generator, gen_kwargs={"dataset": inference_output})
    new_dataset.push_to_hub('llama_inference_output')
else:
    print('No examples were refined successfully; skipping push to hub.', flush=True)
print(f'Created and pushed total of {inferred_no} examples from total of {code_no} codes. Total failed inferences are {failed_no}', flush=True)