sampling.py 1.32 KB
Newer Older
Nishtha Jain's avatar
Nishtha Jain committed
1
2
from sklearn.model_selection import train_test_split
import pymongo
Nishtha Jain's avatar
Nishtha Jain committed
3
from config import MONGO_HOST, MONGO_DB, MONGO_COLLECTION, CLASS_GROUP, DATASET_NAMES, SEED
Nishtha Jain's avatar
Nishtha Jain committed
4
5
6
7
from joblib import dump, load


def get_data_from_mongo(class_group):
	"""Fetch all documents of a class group from MongoDB and cache them to disk.

	Args:
		class_group: key into ``CLASS_GROUP``; its value is the collection of
			``title`` values whose documents are selected.

	Returns:
		list of the matching MongoDB documents. The same list is written with
		``joblib.dump`` to ``DATASET_NAMES['datasets', class_group] + '.joblib'``,
		which is the file ``load_data(..., from_saved=True)`` reads back.
	"""
	print('processing sampling.get_data_from_mongo ...')

	# Materialise as a list so the query value is always BSON-encodable.
	titles = list(CLASS_GROUP[class_group])

	client = pymongo.MongoClient(MONGO_HOST)
	try:
		collection = client[MONGO_DB][MONGO_COLLECTION]
		# `$in` matches the same documents as an `$or` of equality tests on
		# `title`, but an empty list is valid for `$in` (it matches nothing),
		# whereas the server rejects an empty `$or` array.
		data = list(collection.find({'title': {'$in': titles}}))
	finally:
		# Always release the client's connections, even if the query fails.
		client.close()

	dump(data, DATASET_NAMES['datasets', class_group] + '.joblib')
	return data


Nishtha Jain's avatar
Nishtha Jain committed
19
20
21
def load_data(class_group, from_saved=True):
	"""Return the documents belonging to *class_group*.

	When ``from_saved`` is truthy (the default), the list previously written by
	``get_data_from_mongo`` is read back from its ``.joblib`` file. Otherwise the
	documents are fetched from MongoDB again, which also rewrites that file.
	"""
	print('processing sampling.load_data ...')

	if from_saved:
		cache_path = DATASET_NAMES['datasets',class_group] + '.joblib'
		return load(cache_path)
	return get_data_from_mongo(class_group)

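'''
Values for the `sampling` argument of data_selection():
'random' is implemented; 'upsample', 'downsample' and 'weighted' are planned but not yet implemented.
'''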
Nishtha Jain's avatar
Nishtha Jain committed
31
32
33
def data_selection(data, class_group, sampling, test_size, masking=True):
	"""Split *data* into train and test sets and cache the split to disk.

	Args:
		data: the samples to split.
		class_group: class-group key; used here only to build the cache file name.
		sampling: sampling strategy. Only ``'random'`` is implemented, as
			``train_test_split(data, test_size=test_size, random_state=SEED)``.
			Any other value (e.g. ``'downsample'``) is a placeholder that
			yields ``(None, None)``.
		test_size: forwarded to ``train_test_split``; also part of the cache key.
		masking: currently unused; kept for interface compatibility.

	Returns:
		``(train, test)``. For an implemented strategy the pair is also written
		with ``joblib.dump`` to
		``DATASET_NAMES['datasets', class_group, sampling, test_size] + '.joblib'``.
	"""
	print('processing sampling.data_selection ...')

	if sampling == 'random':
		train, test = train_test_split(data, test_size=test_size, random_state=SEED)
	else:
		# TODO: implement 'downsample', 'upsample' and 'weighted'.
		# Do not persist the placeholder result: dumping [None, None] would
		# overwrite any previously saved split stored under the same name.
		print('sampling strategy {!r} is not implemented; returning (None, None)'.format(sampling))
		return None, None

	dump([train, test], DATASET_NAMES['datasets', class_group, sampling, test_size] + '.joblib')
	return train, test