# sampling.py
# Committed by Nishtha Jain
from sklearn.model_selection import train_test_split
import pymongo
from config import MONGO_HOST
from joblib import dump, load


def get_data_from_mongo(class_group):
	"""Fetch the documents of one class group from MongoDB and cache them.

	Queries ``client[MONGO_DB][MONGO_COLLECTION]`` for every document whose
	``title`` is one of the titles in ``CLASS_GROUP[class_group]``, writes
	the result to ``datasets/<DATASET_NAMES[class_group]>.joblib`` and
	returns it.

	Args:
		class_group: key into ``CLASS_GROUP`` and ``DATASET_NAMES``.

	Returns:
		A list with the matching documents.

	NOTE(review): ``MONGO_DB``, ``MONGO_COLLECTION``, ``CLASS_GROUP`` and
	``DATASET_NAMES`` are not imported in this file (only ``MONGO_HOST``
	is) -- presumably they are meant to come from ``config``; confirm and
	import them.
	"""
	titles = list(CLASS_GROUP[class_group])

	# The context manager closes the connection once the query is consumed.
	with pymongo.MongoClient(MONGO_HOST) as client:
		collection = client[MONGO_DB][MONGO_COLLECTION]
		# Materialise the cursor: a live cursor is tied to the client, so it
		# can neither be pickled by joblib nor used after the client closes.
		# ``$in`` is the equality-on-one-field form of the former ``$or``.
		data = list(collection.find({'title': {'$in': titles}}))

	# Write to the same path that load_data() reads the cache from.
	dump(data, 'datasets/' + DATASET_NAMES[class_group] + '.joblib')
	return data


def load_data(from_saved=True, class_group=None):
	"""Return the dataset of a class group, from the cache or from MongoDB.

	Args:
		from_saved: when true, load ``datasets/<DATASET_NAMES[class_group]>.joblib``;
			otherwise fetch the data with ``get_data_from_mongo``.
		class_group: key identifying the dataset. When omitted, a
			module-level variable named ``class_group`` is used if one exists.

	Returns:
		Whatever the joblib file contains, or the result of
		``get_data_from_mongo(class_group)``.

	Raises:
		ValueError: if no class group was given and none is defined at
			module level.
	"""
	# Previously ``class_group`` was read as an undefined free variable;
	# keep honouring a module-level value so existing callers still work.
	if class_group is None:
		class_group = globals().get('class_group')
	if class_group is None:
		raise ValueError(
			"class_group must be passed to load_data() or defined at module level"
		)

	if from_saved:
		return load('datasets/' + DATASET_NAMES[class_group] + '.joblib')
	return get_data_from_mongo(class_group)

'''
sampling -> one of: random, upsample, downsample, weighted
(only 'random' is implemented so far; the other strategies are still TO-DO)
'''
def data_selection(data, sampling, test_size, masking=True, class_group=None):
	"""Split ``data`` into train and test sets and cache the split on disk.

	Args:
		data: the samples to split; passed straight to ``train_test_split``.
		sampling: name of the sampling strategy. Only ``'random'`` is
			implemented; ``'upsample'``, ``'downsample'`` and ``'weighted'``
			are still TO-DO.
		test_size: forwarded to ``train_test_split`` as ``test_size``.
		masking: currently unused; kept for interface compatibility.
		class_group: key into ``DATASET_NAMES`` used to name the cache file.
			When omitted, a module-level variable named ``class_group`` is
			used if one exists.

	Returns:
		The tuple ``(train, test)``.

	Raises:
		NotImplementedError: for a known but not yet implemented strategy.
		ValueError: for an unknown strategy, or when no class group is
			available.

	NOTE(review): ``SEED`` and ``DATASET_NAMES`` are not imported in this
	file -- presumably they are meant to come from ``config``; confirm.
	"""
	if sampling != 'random':
		# Fail loudly instead of returning (None, None) and persisting a
		# useless [None, None] file for strategies that do not exist yet.
		if sampling in ('upsample', 'downsample', 'weighted'):
			raise NotImplementedError(
				"sampling strategy %r is not implemented yet" % (sampling,)
			)
		raise ValueError("unknown sampling strategy %r" % (sampling,))

	# Previously ``class_group`` was read as an undefined free variable;
	# keep honouring a module-level value so existing callers still work.
	if class_group is None:
		class_group = globals().get('class_group')
	if class_group is None:
		raise ValueError(
			"class_group must be passed to data_selection() or defined at module level"
		)

	train, test = train_test_split(data, test_size=test_size, random_state=SEED)

	# NOTE(review): the original indexed DATASET_NAMES with the tuple
	# ('datasets', class_group, sampling, test_size); the file name below
	# follows the 'datasets/<name>' layout used by load_data() -- confirm
	# the intended naming scheme.
	cache_path = (
		'datasets/' + DATASET_NAMES[class_group]
		+ '_' + sampling + '_' + str(test_size) + '.joblib'
	)
	dump([train, test], cache_path)

	return train, test