Skip to content
Snippets Groups Projects
Commit 4ddff3d4 authored by Rawel's avatar Rawel
Browse files

temporarily added method to collect missing training data

parent 075a1733
No related branches found
No related tags found
No related merge requests found
import glob
import os
import random
from multiprocessing import Pool
from joblib import load
from Classifier.commit_features import CommitFeatures
from Data.Database.db_repository import DBRepository
from Data.Utils.CommitUtils import CommitUtils
from Data.Utils.utils import get_config_nodes_repo_dict
abspath = os.path.dirname(os.path.abspath(__file__))
os.chdir(abspath)
def todo_jobs():
vcc_distribution = load("vcc_distribution.joblib")
todo = []
for config_code in vcc_distribution:
for year in vcc_distribution[config_code]:
unclassified_files = glob.glob(f"Training/unclassified/{config_code}/{year}/*.json")
if len(unclassified_files) < vcc_distribution[config_code][year] * 2:
missing = vcc_distribution[config_code][year] * 2 - len(unclassified_files)
todo.append((config_code, year, missing))
print(todo)
return todo
def choose_random(config_code, year):
print(f"Choosing random for {config_code} in {year}")
db_repo = DBRepository()
repo_node = get_config_nodes_repo_dict()[config_code]
commit_utils = CommitUtils(False, repo_node, None)
vccs = set(db_repo.get_all_vccs())
commits = commit_utils.get_commits_between_years(year, year)
commits = list(filter(lambda com: com.hexsha not in vccs, commits))
while True:
try:
random_commit = random.choice(commits)
commit_features = CommitFeatures(repo_node.find("./path").text, random_commit.hexsha)
commit_features.extract_features()
commit_features.save_features_to_json(f"Training/unclassified/{config_code}/{year}")
db_repo.save_commit(random_commit, config_code)
print(f"Saved {random_commit.hexsha} for {config_code} in {year}")
break
except ValueError as e:
print(e)
def main():
todos = todo_jobs()
tasks = []
for todo in todos:
for _ in range(todo[2]):
tasks.append((todo[0], todo[1]))
p = Pool()
p.starmap(choose_random, tasks)
p.close()
p.join()
if __name__ == "__main__":
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment