sampling.py 1.5 KB
Newer Older
Nishtha Jain's avatar
Nishtha Jain committed
1
2
from sklearn.model_selection import train_test_split
import pymongo
Nishtha Jain's avatar
Nishtha Jain committed
3
from config import MONGO_HOST, MONGO_DB, MONGO_COLLECTION, CLASS_GROUP, DATASET_NAMES, SEED
Nishtha Jain's avatar
Nishtha Jain committed
4
5
6
7
from joblib import dump, load


def get_data_from_mongo(class_group):
	"""Fetch the documents of a class group from MongoDB and cache them on disk.

	Queries the collection ``MONGO_DB``/``MONGO_COLLECTION`` on ``MONGO_HOST``
	for every document whose ``title`` is one of the titles listed in
	``CLASS_GROUP[class_group]``. The result is dumped with joblib to
	``DATASET_NAMES['datasets', class_group] + '.joblib'`` and returned.

	Args:
		class_group: key into ``CLASS_GROUP`` selecting the titles to fetch.

	Returns:
		list: the matching documents, fully materialized from the cursor.
	"""
	print('processing sampling.get_data_from_mongo ...')

	titles = list(CLASS_GROUP[class_group])

	# Use the client as a context manager so its connections are released
	# once the query is done instead of lingering until interpreter exit.
	with pymongo.MongoClient(MONGO_HOST) as client:
		collection = client[MONGO_DB][MONGO_COLLECTION]
		# `$in` matches the same documents as an `$or` of equality checks on
		# one field, and (unlike `$or`) accepts an empty list of titles.
		# The cursor must be consumed while the client is still open.
		data = list(collection.find({'title': {'$in': titles}}))

	print("\t saving file : ",DATASET_NAMES['datasets',class_group])
	dump(data, DATASET_NAMES['datasets',class_group]+'.joblib')

	return data


Nishtha Jain's avatar
Nishtha Jain committed
20
21
22
def load_data(class_group, from_saved=True):
	"""Return the dataset for ``class_group``.

	Args:
		class_group: dataset key, used both for the Mongo query and for the
			cached file name looked up in ``DATASET_NAMES``.
		from_saved: when true (the default) read the joblib file named by
			``DATASET_NAMES['datasets', class_group]``; otherwise fetch the
			data through ``get_data_from_mongo``.

	Returns:
		Whatever the chosen source yields: the joblib-loaded object or the
		list returned by ``get_data_from_mongo``.
	"""
	print('processing sampling.load_data ...')

	if from_saved:
		return load(DATASET_NAMES['datasets',class_group]+'.joblib')

	return get_data_from_mongo(class_group)

'''
Supported sampling modes: 'random' and 'balanced'.
'balanced' is meant to cover upsampling, downsampling and weighting,
but for now it performs the same split as 'random'.
'''
Nishtha Jain's avatar
Nishtha Jain committed
32
33
34
def data_selection(data, class_group, sampling, test_size, masking=True):
	"""Split ``data`` into train and test sets and save the split to disk.

	Both supported sampling modes currently perform the same seeded
	``train_test_split``; 'balanced' is a placeholder for a future
	class-balancing strategy.

	Args:
		data: the samples to split, passed straight to ``train_test_split``.
		class_group: dataset key, used only to build the output file name.
		sampling: either ``'random'`` or ``'balanced'``.
		test_size: forwarded to ``train_test_split`` as ``test_size``.
		masking: currently unused; kept so existing callers keep working.

	Returns:
		tuple: ``(train, test)`` as produced by ``train_test_split``.

	Raises:
		ValueError: if ``sampling`` is not one of the supported modes.
	"""
	print('processing sampling.data_selection ...')

	supported = ('random', 'balanced')
	# Fail early with a clear message; previously an unknown mode fell through
	# both branches and crashed later on the unbound `train`/`test` names.
	if sampling not in supported:
		raise ValueError(
			"unsupported sampling %r; expected one of %r" % (sampling, supported)
		)

	# 'balanced' is the same as 'random' for now.
	train, test = train_test_split(data, test_size=test_size, random_state=SEED)

	print("\t saving file : ",DATASET_NAMES['datasets',class_group,sampling,test_size])
	dump([train,test], DATASET_NAMES['datasets',class_group,sampling,test_size]+'.joblib')

	return train,test