Skip to content
Snippets Groups Projects
Commit 1d668cd3 authored by Atharva Jadhav's avatar Atharva Jadhav
Browse files

Stabilize fine-tuning

parent bf928bca
No related branches found
No related tags found
No related merge requests found
......@@ -41,7 +41,7 @@ tokenizer = get_chat_template(
)
dataset = load_dataset("atharva2721/standardized-refined-train-aggregated", split = "train")
validation_dataset = load_dataset("atharva2721/standardized-refined-val-aggregated", split = "train")
validation_dataset = load_dataset("atharva2721/standardized-refined-val-test-aggregated", split = "train")
wandb.init(project="codebud")
......@@ -65,16 +65,16 @@ trainer = SFTTrainer(
fp16 = not is_bfloat16_supported(),
bf16 = is_bfloat16_supported(),
eval_strategy="steps",
eval_steps=656,
eval_steps=410,
per_device_eval_batch_size = 1,
fp16_full_eval = not is_bfloat16_supported(),
bf16_full_eval = is_bfloat16_supported(),
logging_steps = 10,
save_steps = 656,
save_steps = 410,
optim = "paged_adamw_8bit", # Save more memory
weight_decay = 0.01,
lr_scheduler_type = "cosine",
seed = 3407,
remove_unused_columns=False,
output_dir = "outputs",
report_to = "wandb", # Use this for WandB etc
run_name = "run-name"
......@@ -117,10 +117,10 @@ print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")
print(f'Pushing model and tokenizer at {datetime.datetime.now()}', flush=True)
model.save_pretrained("models/finetuned_model_with_eval") # Local saving
tokenizer.save_pretrained("models/finetuned_model_with_eval")
model.push_to_hub("finetuned_model_with_eval") # Online saving
tokenizer.push_to_hub("finetuned_model_with_eval") # Online saving
model.save_pretrained("models/finetuned_model_with_three_epochs_eval") # Local saving
tokenizer.save_pretrained("models/finetuned_model_with_three_epochs_eval")
model.push_to_hub("finetuned_model_with_three_epochs_eval") # Online saving
tokenizer.push_to_hub("finetuned_model_with_three_epochs_eval") # Online saving
wandb.finish()
print(f'Run complete at {datetime.datetime.now()}', flush=True)
\ No newline at end of file
......@@ -9,7 +9,7 @@
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=5
#SBATCH --gres=gpu:1
#SBATCH --time=12:00:00
#SBATCH --time=22:00:00
###------------------------------------------------------------------------------------------------------------------------------
......
0% Loading.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment