Commit ab7f20a0 authored by Atharva Jadhav

Add cosine similarity evaluation

parent d7b728a9
File added: cosine_similarity_evaluation.py
import csv

import pandas as pd
from datasets import load_dataset
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity


def create_dataframe(matrix, tokens):
    # Label rows doc_1, doc_2, ... so the similarity matrix stays readable.
    doc_names = [f"doc_{i+1}" for i in range(len(matrix))]
    df = pd.DataFrame(data=matrix, index=doc_names, columns=tokens)
    return df


reference_dataset = load_dataset("atharva2721/qwen_inference_output_complete", split="train", trust_remote_code=True)

code_number = 0
with open('eval_reports/cosine-similarity-qwen-base-responses-evaluation-pass.csv', 'w') as f:
    fieldnames = ['code number', 'cosine similarity score (CountVectorizer)', 'cosine similarity score (TfidfVectorizer)']
    writer = csv.DictWriter(f, fieldnames=fieldnames)
    writer.writeheader()

    for example in reference_dataset:
        code_number += 1
        # Compare the base model's refinement against the reference refinement.
        doc_1 = f"{example['base inference']}"
        doc_2 = f"{example['reference inference']}"
        data = [doc_1, doc_2]

        # Bag-of-words cosine similarity.
        count_vectorizer = CountVectorizer()
        vector_matrix_count = count_vectorizer.fit_transform(data)
        count_tokens = count_vectorizer.get_feature_names_out()
        df_count_tokens = create_dataframe(vector_matrix_count.toarray(), count_tokens)
        cosine_similarity_matrix_count = cosine_similarity(vector_matrix_count)
        df_cosine_count = create_dataframe(cosine_similarity_matrix_count, ["doc_1", "doc_2"])
        print(f"Cosine Similarity (CountVectorizer) for code number {code_number} :: {df_cosine_count}")

        # -----------------------------------------------------------------------------------------------
        # TF-IDF-weighted cosine similarity.
        tfidf_vectorizer = TfidfVectorizer()
        vector_matrix_tfidf = tfidf_vectorizer.fit_transform(data)
        tfidf_tokens = tfidf_vectorizer.get_feature_names_out()
        df_tfidf_tokens = create_dataframe(vector_matrix_tfidf.toarray(), tfidf_tokens)
        cosine_similarity_matrix_tfidf = cosine_similarity(vector_matrix_tfidf)
        df_cosine_tfidf = create_dataframe(cosine_similarity_matrix_tfidf, ["doc_1", "doc_2"])
        print(f"Cosine Similarity (TfidfVectorizer) for code number {code_number} :: {df_cosine_tfidf}")

        writer.writerow({'code number': code_number, 'cosine similarity score (CountVectorizer)': df_cosine_count.loc["doc_1", "doc_2"], 'cosine similarity score (TfidfVectorizer)': df_cosine_tfidf.loc["doc_1", "doc_2"]})

print("Calculate cosine similarity")
code number,cosine similarity score (CountVectorizer),cosine similarity score (TfidfVectorizer)
1,0.9829590656993242,0.9669012267796804
2,0.9905181941480535,0.9881702658468443
3,0.9579932440261575,0.9360869024484828
4,0.6642038652847895,0.5672351996087892
5,0.5179533782291896,0.40291367040422665
6,0.4319312868010303,0.31829582177823007
7,0.5896482075326526,0.4382943894216728
8,0.35163633764743635,0.22466601519503562
9,0.967150520150034,0.9492452906461634
10,0.9859071344374633,0.9803953956218493
11,0.9529220779220777,0.9110880036477345
12,0.7519060529531866,0.6711803046733371
13,0.9562166326624674,0.9303368295815007
14,0.5721208860316485,0.4145838273203879
15,0.6168762782575283,0.5083145112170189
16,0.8044818277847895,0.6756519867952987
17,0.995727575352705,0.9918137796456394
18,0.8690955388663847,0.7718155973688722
19,0.8406891761549491,0.7652744089784883
20,0.8640699102191518,0.801415428263494
21,0.9833092087209556,0.9689895110588486
22,0.9515953832852092,0.9367795707002863
23,0.49169966286739647,0.38166351453176733
24,0.9635293180817429,0.9304757734604177
25,0.9461817959669848,0.9239677244327656
26,0.404623693746184,0.2906781234742853
27,0.8574719598590294,0.809708060160666
28,0.7348469228349535,0.6306777057184418
29,0.6929202119807368,0.5853719913353513
30,0.34829888285959554,0.23780907592894035
31,0.9565692372573213,0.9288467058303025
32,0.969166010214795,0.9460906606477772
33,0.9122455416873133,0.8650673921100038
34,0.9676633447888103,0.9461259532586188
35,0.6259698357701571,0.5026188038264596
36,0.9899486164622846,0.9874566875086238
37,0.8963720835433008,0.8206547886714061
38,0.828618615335137,0.7686977099905108
39,0.8960787490449202,0.8266887304757253
40,0.46483484010068415,0.3252126926673298
41,0.9424949289491302,0.9112053490715386
42,0.7729715070359924,0.6574654448019361
43,0.7008262319629648,0.5834363407606968
44,0.799168959218739,0.7918816341208937
45,0.33887715959556886,0.24967743734456235
46,0.655671427128544,0.5336658704807009
47,0.9442278164425926,0.9082977098881965
48,0.9985945179280502,0.9972293990417763
49,0.7884449104233787,0.6770542320965478
50,0.7467528773091743,0.6305537845363173
51,0.7155668192482633,0.5916404085815423
52,0.8214486599543218,0.7260953596418699
53,0.9644917517533788,0.9542247492020246
54,0.5612091839337283,0.44493742743571235
55,0.8885542168674707,0.8014403316244239
56,0.9656040244825144,0.9441698275152057
57,0.9852675084097874,0.9805833006882317
58,0.9130854212497931,0.8463744693755527
59,0.8314876315471741,0.7838899214351002
60,0.46245444833152677,0.3421745944258118
61,0.9920588282895633,0.989943710795231
62,0.8556199275317793,0.762263661274845
63,0.5120537441696312,0.372159750190583
64,0.27383278566830965,0.1987313324015373
65,0.8796894875233187,0.8670708710428576
66,0.9285255785984343,0.8830059870247231
67,0.7206804901169632,0.5972163537537634
68,0.5521610645599694,0.41810575047485454
69,0.5318442867591466,0.44176882006729185
70,0.7980378792007151,0.7028409227791071
71,0.9067452960520181,0.864664550178747
72,0.9490407233483938,0.916263112180336
73,0.8657059739530922,0.825247299380852
74,0.7419604810332548,0.6480269660878143
75,0.9334226886211596,0.8922630189032867
76,0.6763264125613534,0.5165066071060762
77,0.005713912572995367,0.00406552862555413
78,0.8409727994175388,0.764123742925926
79,0.8244661756273934,0.7418756052644963
80,0.7499540005176346,0.6517902018541523
81,0.9227053833742914,0.9026126945641325
82,0.9279609284904535,0.8693826593938924
83,0.8485281374238571,0.7478786017823759
84,0.7617025894181404,0.6506611788881385
85,0.9930317775253296,0.9891589943448454
86,0.9620736927999172,0.938520117613679
87,0.41478067789217016,0.31077653333922794
88,0.7311587354999313,0.6181338057867134
89,0.41045953246761585,0.2882055468618911
code number,cosine similarity score (CountVectorizer),cosine similarity score (TfidfVectorizer)
1,0.9248352667510482,0.888898076952228
2,0.9999999999999991,0.9999999999999997
3,0.9599107066985334,0.9411287746868221
4,0.44664979757141304,0.3155061726034177
5,0.903225806451613,0.8253257937898935
6,0.9541851959520042,0.913372655998072
7,0.8628394488506886,0.8190029868630698
8,0.6408069008390094,0.47457001254098174
9,0.4613912046706676,0.3483062654237991
10,1.0000000000000004,1.0000000000000002
11,0.9276808410695255,0.8806075722930505
12,0.7252273450266299,0.6325394044080468
13,0.8677185216256361,0.801425146374899
14,0.959183673469388,0.9224609847314504
15,0.9186579947707756,0.8717953085014636
16,0.9481631865541554,0.911913679270532
17,0.993909870697976,0.9899620646788805
18,0.9988597897288605,0.997888706369987
19,0.7774172399551486,0.6765754964897805
20,0.7220214660489161,0.6153551916720854
21,0.6976214523986426,0.5618385369931845
22,0.9024461873830468,0.8740726108199721
23,0.7287564307374957,0.5854233313161834
24,0.9117647058823527,0.8395169509190922
25,0.844672423245055,0.7908637379977258
26,0.7881667843105441,0.6733664940587953
27,0.8777390474228787,0.8366886588400855
28,0.6116777418411966,0.5742880257524781
29,0.9014544973096625,0.8676851792366456
30,0.769233891102556,0.6394178287275498
31,0.9215231752008689,0.8794345955029164
32,0.8903935066381196,0.8457861852772193
33,0.9697142417369154,0.9513369873659678
34,0.9346274370747527,0.8884845542037165
35,0.9740382707330959,0.9500889294058446
36,1.0000000000000007,1.0
37,0.8219512259484445,0.7266984763029041
38,0.9335534674397444,0.8878784137731
39,0.8044287632446305,0.6926560600740324
40,0.5548278171588212,0.4405325123097934
41,0.87726811294792,0.8224426989137398
42,0.8842978932443868,0.8144139698349708
43,0.8767080593199744,0.814911981448098
44,0.9498748669880929,0.9192492652088391
45,0.8645346421168717,0.7850328883185647
46,0.8363636363636364,0.7212519607199548
47,0.9623459052686925,0.9385372243942065
48,0.9957815756240276,0.9917032206773166
49,0.99172244532499,0.9860679981436891
50,0.8090071484586969,0.7149873893187821
51,1.0000000000000002,1.0000000000000002
52,0.8422647207140463,0.7583934851593025
53,1.0000000000000002,1.0
54,0.9998032851362889,0.9996115345882909
55,0.9969879518072298,0.9940676321058516
56,0.9648469191870738,0.9404758558006368
57,0.9877374816795109,0.9858683052218079
58,0.5825149464430002,0.4252245265169706
59,0.8874276947019175,0.8534573078424692
60,0.9628328483303837,0.9554304079439051
61,0.9999999999999998,1.0000000000000002
62,0.9883720930232559,0.9772885929743134
63,0.7792101505892761,0.6501322504376899
64,0.9375,0.8836351388995087
65,0.9614704275628803,0.9465228335025084
66,0.9084105570318233,0.8520679532750199
67,1.0000000000000009,1.0
68,0.6498832197686679,0.5339988308414109
69,0.926568460146416,0.8972226458097631
70,0.9221056660207593,0.8676993775573723
71,0.9088280368981341,0.8598390766871996
72,0.571693736843898,0.4361184986892866
73,0.9323326521012025,0.8874752500531621
74,0.8510168175933222,0.7900188375601348
75,0.9543584952011676,0.9147442703676752
76,0.7468105728257773,0.721309528447763
77,0.9033585097851496,0.8793524777449662
78,0.9792938238560599,0.9673718012883574
79,0.9509220540435406,0.9208645388180952
80,0.9654224289254271,0.9423025200644959
81,0.9333753694122966,0.8948613536801993
82,0.9652509652509651,0.9336092202270857
83,0.9707253433941506,0.9562617132099339
84,0.9638311194183349,0.9545145804083411
85,0.9642746420003188,0.9395344485215906
86,0.9984120682810639,0.9968681578662569
87,0.6236095644623234,0.49724916915987416
88,0.9564910135147456,0.9246273025519147
89,0.4703476359705358,0.3490911427687785
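
The two result tables above are easiest to compare after loading them back into pandas. A minimal sketch, assuming the first table sits at the path the script writes to; the second table's file name is not visible in this commit:

import pandas as pd

# Path taken from the script above; the second table's file name is not shown here.
scores = pd.read_csv("eval_reports/cosine-similarity-qwen-base-responses-evaluation-pass.csv")

count_col = "cosine similarity score (CountVectorizer)"
tfidf_col = "cosine similarity score (TfidfVectorizer)"

# Summary statistics per vectorizer (count, mean, std, quartiles).
print(scores[[count_col, tfidf_col]].describe())

# Fraction of examples where the TF-IDF score does not exceed the raw-count score.
print((scores[tfidf_col] <= scores[count_col]).mean())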
@@ -6,7 +6,7 @@ import glob
 csv_files = sorted(glob.glob("eval_reports/qwen-base-responses-evaluation-pass*.csv"))[:10]
 print(csv_files)
-# Create a subplot grid: 2 rows x 5 columns for 10 plots
-fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(30, 8))
+# Create a subplot grid: 5 rows x 2 columns for 10 plots
+fig, axes = plt.subplots(nrows=5, ncols=2, figsize=(10, 30))
 axes = axes.flatten()
 for i, csv_file in enumerate(csv_files):
@@ -40,5 +40,5 @@ for i, csv_file in enumerate(csv_files):
     ax.tick_params(axis='x', rotation=45)
 plt.tight_layout()
-plt.savefig("Prometheus Evaluation Results - Base Model.pdf")
+plt.savefig("Prometheus Evaluation Results - Base Model - spaced.pdf")
 plt.show()
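
The hunks above change only the grid shape and the output file name; the per-file plotting code between them is elided. As a rough sketch of the script's overall shape (the histogram and the plotted column are assumptions here, since the CSV schema and the real panel contents are not shown in this diff):

import glob

import matplotlib.pyplot as plt
import pandas as pd

csv_files = sorted(glob.glob("eval_reports/qwen-base-responses-evaluation-pass*.csv"))[:10]

fig, axes = plt.subplots(nrows=5, ncols=2, figsize=(10, 30))
axes = axes.flatten()
for i, csv_file in enumerate(csv_files):
    df = pd.read_csv(csv_file)
    ax = axes[i]
    # Assumption: the second column holds a numeric score worth plotting.
    ax.hist(df.iloc[:, 1], bins=20)
    ax.set_title(csv_file, fontsize=8)
    ax.tick_params(axis='x', rotation=45)
plt.tight_layout()
plt.savefig("example-grid.pdf")  # placeholder name, not the commit's output file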
#!/usr/bin/zsh
### Add basic configuration for job
#SBATCH --job-name=cosine_similarity_evaluation
#SBATCH --output=logs/cosine_similarity_evaluation_analysis_%j.log
#SBATCH --error=logs/cosine_similarity_evaluation_analysis_error_%j.log
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=1
#SBATCH --time=00:30:00
###------------------------------------------------------------------------------------------------------------------------------
### Run the project in the work directory of the cluster (configure based on need!)
### RWTH File System: https://help.itc.rwth-aachen.de/en/service/rhr4fjjutttf/article/da307ec2c60940b29bd42ac483fc3ea7/
cd $HPCWORK
cd codebud/evaluation
###------------------------------------------------------------------------------------------------------------------------------
### JOB SCRIPT RUN
module load GCCcore/.13.2.0
module load Python/3.11.5
module load CUDA
source ../../venvs/codebud/bin/activate
echo $VIRTUAL_ENV
python --version
python cosine_similarity_evaluation.py
module unload CUDA
module unload Python/3.11.5
deactivate
echo "Script ran successfully"
@@ -19,7 +19,7 @@ model, tokenizer = FastLanguageModel.from_pretrained(
 )
 FastLanguageModel.for_inference(model) # Enable native 2x faster inference
 
-reference_dataset = load_dataset("atharva2721/qwen_inference_output", split="train", trust_remote_code=True)
+reference_dataset = load_dataset("atharva2721/codebud-test-dataset", split="train", trust_remote_code=True)
 inference_output = []
 code_no = 0
 inferred_no = 0
@@ -49,9 +49,11 @@ for example in reference_dataset:
     should_retry = True
     retry_no = 0
     while should_retry:
+        print(f'Trying {code_no} for {retry_no} time', flush=True)
         output_tensor = model.generate(input_ids = inputs,
                                        max_length = max_seq_length,
-                                       temperature = 0.6
+                                       temperature = 0.6,
+                                       repetition_penalty = 1.1
                                        )
         decoded = tokenizer.batch_decode(output_tensor)
@@ -59,10 +61,11 @@ for example in reference_dataset:
         output = ""
         for text in decoded:
             output += text
-        print(output)
-        output = output.split('<|im_start|>assistant')
+        output = output.split('<|start_header_id|>assistant<|end_header_id|>')
         if len(output) == 2:
             output = output[1]
+            print(output)
             code_pattern = r'\[refined_C#\](.*?)\[/refined_C#\]'
             summary_pattern = r'\[code_changes\](.*?)\[/code_changes\]'
@@ -73,20 +76,18 @@ for example in reference_dataset:
             if code_matches and summary_matches:
                 refined_code = code_matches.group(1)
                 summary = summary_matches.group(1)
-                inference_output.append({'code': example["code"], 'base inference':refined_code, 'base summary': summary,'finetuned inference': example["finetuned inference"], 'finetuned summary': example["finetuned summary"], 'reference inference': example["reference inference"], 'reference summary': example["reference summary"]})
+                inference_output.append({'code': example["code"], 'finetuned inference':refined_code, 'finetuned summary': summary, 'reference inference': example["refined code"], 'reference summary': example["summary"]})
                 print(f'Code no. {code_no} refined successfully', flush=True)
                 should_retry = False
                 inferred_no += 1
-        if retry_no == 2 and should_retry:
+        if retry_no == 2:
             should_retry = False
             print(f'Failed to refine code at {code_no}. Final try output: \n [failed_output]{output}[/failed_output]', flush=True)
             failed_no +=1
         retry_no += 1
-    if code_no == 5:
-        break
 
-# new_dataset = Dataset.from_generator(data_generator, gen_kwargs={"dataset": inference_output})
-# new_dataset.push_to_hub('llama_inference_output')
-# print(f'Created and pushed total of {inferred_no} examples from total of {code_no} codes. Total failed inferences are {failed_no}', flush=True)
+new_dataset = Dataset.from_generator(data_generator, gen_kwargs={"dataset": inference_output})
+new_dataset.push_to_hub('llama_inference_output')
+print(f'Created and pushed total of {inferred_no} examples from total of {code_no} codes. Total failed inferences are {failed_no}', flush=True)
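
For reference, the tag-extraction logic that the hunks above leave mostly as context can be exercised in isolation. A self-contained sketch: the sample model output is invented, and re.DOTALL is an assumption (the flags the real script passes are not visible in this diff):

import re

# Invented sample of what a successful model response looks like.
output = """[refined_C#]
public int Add(int a, int b) => a + b;
[/refined_C#]
[code_changes]Replaced the method body with an expression body.[/code_changes]"""

code_pattern = r'\[refined_C#\](.*?)\[/refined_C#\]'
summary_pattern = r'\[code_changes\](.*?)\[/code_changes\]'

# re.DOTALL lets .*? span the multi-line code block between the tags.
code_matches = re.search(code_pattern, output, re.DOTALL)
summary_matches = re.search(summary_pattern, output, re.DOTALL)

if code_matches and summary_matches:
    refined_code = code_matches.group(1).strip()
    summary = summary_matches.group(1).strip()
    print(refined_code)
    print(summary)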
@@ -9,7 +9,7 @@
 #SBATCH --ntasks=1
 #SBATCH --cpus-per-task=3
 #SBATCH --gres=gpu:1
-#SBATCH --time=1:00:00
+#SBATCH --time=05:00:00
 ###------------------------------------------------------------------------------------------------------------------------------