from datasets import load_dataset
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template


def formatting_prompts_func(examples):
    """Render each conversation into a single training string via the chat template."""
    # Uses the module-level `tokenizer` defined below; it is bound before .map() runs.
    convos = examples["conversations"]
    texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
    return { "text" : texts, }
pass

def format_to_conversations(examples):
    """Wrap each (code, refined code, summary) row into a user/assistant conversation."""
    conversations = []
    codes = examples["code"]
    refined_codes = examples["refined code"]
    summaries = examples["summary"]
    for i in range(len(refined_codes)):
        # User turn: ask for a refinement of the original C# snippet.
        user_content = (
            "Refine the C# code enclosed within tags [C#] and [/C#].\n\n"
            f"[C#]\n{codes[i]}\n[/C#]\n"
        )
        # Assistant turn: the refined code plus a summary of the changes.
        assistant_content = (
            f"[refined_C#]\n{refined_codes[i]}\n[/refined_C#]\n"
            f"[code_changes]\n{summaries[i]}\n[/code_changes]\n"
        )
        conversation = [
            {'content': user_content, 'role': 'user'},
            {'content': assistant_content, 'role': 'assistant'},
        ]
        conversations.append(conversation)

    return { "conversations" : conversations }
pass
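
# Illustrative check (hypothetical one-row batch, not from the real dataset): the
# mapper above should produce one [user, assistant] pair per input row.
_demo = format_to_conversations({
    "code": ["int x=1;"],
    "refined code": ["int x = 1;"],
    "summary": ["Added spacing around the assignment operator."],
})
assert _demo["conversations"][0][0]["role"] == "user"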

max_seq_length = 32768 # Choose any length; Unsloth supports RoPE scaling internally.
dtype = None # None for auto detection; float16 for Tesla T4/V100, bfloat16 for Ampere+.
load_in_4bit = True # Use 4-bit quantization to reduce memory usage. Can be False.

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Meta-Llama-3.1-8B-Instruct",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

tokenizer = get_chat_template(
    tokenizer,
    chat_template = "llama-3.1",
)
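
# Optional sanity check (illustrative only): render a tiny dummy conversation to
# confirm the "llama-3.1" template wraps user/assistant turns as expected before
# mapping the full dataset.
_probe = [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi"}]
print(tokenizer.apply_chat_template(_probe, tokenize = False, add_generation_prompt = False))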

dataset = load_dataset("atharva2721/refined-train-aggregated", split = "train")

dataset = dataset.map(format_to_conversations, batched = True,)
dataset = dataset.map(formatting_prompts_func, batched = True,)
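
# Quick inspection (illustrative, not required by the pipeline): print the start of
# the first formatted example to verify the "text" column was built correctly.
print(dataset[0]["text"][:500])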

dataset.push_to_hub('llama-standardized-refined-test-aggregated')
print('Dataset pushed to hub')
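
# Downstream usage sketch (assumes the dataset lands under your own Hugging Face
# namespace; <username> is a placeholder):
# reloaded = load_dataset("<username>/llama-standardized-refined-test-aggregated", split = "train")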