Newer
Older
# -*- coding: utf-8 -*-
"""
Data Manager Class.
Controls data source selection.
Author: Jason Merlo
Maintainer: Jason Merlo (merlojas@msu.edu)
"""
import h5py # Used for hdf5 database
from pyratk.acquisition.mux_buffer import MuxBuffer
from pyratk.acquisition.virtual_daq import VirtualDAQ
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# COMPRESSION OPTIONS
# Passed as the `compression` / `compression_opts` keyword arguments when
# DataManager.save_buffer() creates a new dataset in the 'samples' group.
COMPRESSION = "gzip"
# NOTE(review): for the gzip filter h5py documents `compression_opts` as the
# compression level (0-9), so 9 would be the maximum -- confirm against the
# h5py version in use.
COMPRESSION_OPTS = 9
class DataManager(MuxBuffer):
    """
    DataManager class; extends MuxBuffer class.

    Handles multiple data input sources and aggregates them into one data
    "mux" which can select from the various input sources added.

    It also owns an HDF5 database with three top-level groups:
    'samples' (recorded buffers), 'labels' and 'subjects'.  The latter two
    hold hard links back to the datasets stored under 'samples'.
    """

    def __init__(self, db="default.hdf5", daq=None):
        """
        Initialize DataManager Class.

        Arguments:
            db (optional)
                database to save/load from
            daq (optional)
                stored on the instance as `self.daq`; not otherwise used
                by this class
        """
        self.db = None
        self.samples = None   # Pointer to the samples group
        self.labels = None    # Pointer to the labels group
        self.subjects = None  # Pointer to the subjects group

        # MuxBuffer attribute initialization
        # NOTE(review): MuxBuffer.__init__ is not called here; confirm that
        # setting `source_list` is all the base class needs.
        self.source_list = []

        # Open virtual daq for playback
        self.virt_daq = VirtualDAQ()
        self.add_source(self.virt_daq)

        self.daq = daq

        # DEVEL/DEBUG
        try:
            self.open_database(db)
        except Exception as e:
            print("(DataManager) Error opening debug database:", e)

    # === DATABASE ============================================================
    def open_database(self, db_path, mode='a'):
        """
        Accept file path of a database to load and open it.

        Any database that is already open is closed first.  If the file
        cannot be opened the error is printed and the manager is left with
        no database loaded.

        db_path: string
            Path to file to open
        mode: string
            File access mode (default: 'a')
        """
        # If a database is open, close it and drop the now-invalid handles
        if self.db is not None:
            self.db.close()
            self.db = None
            self.samples = None
            self.labels = None
            self.subjects = None

        # Attempt to open or create a new file with the specified name
        try:
            self.db = h5py.File(db_path, mode)
        except OSError as e:
            print("(DataManager) Error opening HDF5 File:", e)
            return  # nothing to read groups from

        # Attempt to open 'samples', 'labels', 'subjects' groups
        # Create new groups on KeyError
        self.samples = self._open_group(
            "samples", "No 'samples' group found. Creating 'samples'...")
        self.labels = self._open_group(
            "labels", "No 'labels' group found. Creating 'labels'")
        self.subjects = self._open_group(
            "subjects", "No 'subjects' group found. Creating 'subjects'")

    def _open_group(self, group_name, missing_msg):
        """Return top-level group `group_name`, creating it if it is absent."""
        try:
            return self.db[group_name]
        except KeyError:
            print(missing_msg)
            return self.db.create_group(group_name)
        except Exception as e:
            # Handle all other exceptions
            print(e)
            return None

    def _list_members(self, group):
        """Return the objects stored directly in `group` ([] if no database)."""
        if not self.db or group is None:
            print("(DataManager) Database must be loaded before datasets read")
            return []
        return [group[key] for key in group]

    @staticmethod
    def _node_name(node):
        """Return `node` itself if it is a string, else its `.name`."""
        return node if isinstance(node, str) else node.name

    # === DATASET CONTROL =====================================================
    def load_dataset(self, ds):
        """Load dataset into virtualDAQ and set to virtualDAQ source."""
        self.virt_daq.load_dataset(ds)
        self.set_source(self.virt_daq)

    def get_datasets(self):
        """
        Return list of all dataset objects in the 'samples' group.

        Returns an empty list (after printing a message) when no database
        is loaded.
        """
        return self._list_members(self.samples)

    def delete_dataset(self, ds):
        """Remove dataset from database; errors are printed, not raised."""
        try:
            del self.db[ds.name]
        except Exception as e:
            print("Error deleting dataset: ", e)

    def save_buffer(self, name, labels, subject, notes):
        """
        Write buffer contents to dataset with specified 'name'.

        If '/samples/<name>' already exists it is reused, otherwise a new
        dataset is created from the current source's `ts_buffer`.  The
        dataset is then hard-linked into each label group and the subject
        group, and the acquisition settings, labels, subject and notes are
        stored as attributes on it.

        Arguments:
            name
                dataset name inside the 'samples' group
            labels
                iterable of label names or label group objects, or None
            subject
                subject name or subject group object, or None
            notes
                free-form string, or None
        """
        if not self.db:
            print("(DataManager) Database must be loaded before saving buffer")
            return

        if "/samples/{:}".format(name) in self.db:
            ds = self.samples[name]
        else:
            # Does not exist, create new entry
            try:
                # Save buffer data
                ds = self.samples.create_dataset(
                    name, data=self.source.ts_buffer,
                    compression=COMPRESSION,
                    compression_opts=COMPRESSION_OPTS)
            except Exception as e:
                print("(DataManager) Error saving dataset: ", e)
                return  # nothing to link or annotate

        # Create hard links to class label group and subject group
        label_names = []  # used for saving labels to attribute
        for label in labels or ():
            label_name = self._node_name(label)
            if "{:}/{:}".format(label_name, name) not in self.labels:
                self.labels[label_name][name] = ds
            label_names.append(label_name.split('/')[-1])
        # Attributes are stored as encoded strings, so flatten the list
        labels_str = ','.join(label_names)

        subject_name = None
        if subject:
            subject_name = self._node_name(subject)
            if "{:}/{:}".format(subject_name, name) not in self.subjects:
                self.subjects[subject_name][name] = ds

        if notes is None:
            notes = ''

        try:
            # Save attribute data
            attrs = ds.attrs
            attrs.create("sample_rate", self.source.sample_rate)
            attrs.create("sample_size", self.source.sample_chunk_size)
            attrs.create("daq_type", self.source.daq_type.encode('utf8'))
            attrs.create("num_channels", self.source.num_channels)
            attrs.create("label", labels_str.encode('utf8'))
            if subject_name is not None:
                attrs.create(
                    "subject", subject_name.split('/')[-1].encode('utf8'))
            attrs.create("notes", notes.encode('utf8'))
        except Exception as e:
            print("(DataManager) Error saving attributes: ", e)

    def remove_attributes(self, name, labels, subject):
        """
        Remove hard-links created in attribute folders.

        Currently the only hard-linked attributes are: labels, subject

        A missing dataset is reported but the links are still removed, so
        stale links can be cleaned up.
        """
        # Check that the dataset exists
        if "/samples/{:}".format(name) not in self.db:
            print("(DataManager) Error attempting to remove attributes.")
            print("(DataManager) Dataset does not exist.")

        # Attempt to delete sample from label groups
        for label in labels or ():
            try:
                del self.labels[label][name]
            except KeyError:
                print("(DataManager) Error attempting to remove sample "
                      "hard-link in label dataset.")
                print("(DataManager) sample does not exist in dataset.")
                print("Name: {:}".format(name))

        # Attempt to delete sample from subjects group
        if subject:
            try:
                del self.subjects[subject][name]
            except KeyError:
                print("(DataManager) Error attempting to remove sample "
                      "hard-link in subject dataset.")
                print("(DataManager) sample does not exist in dataset.")
                print("Name: {:}".format(name))

    # --- labels --- #
    def get_labels(self):
        """
        Return list of all label objects in the 'labels' group.

        Returns an empty list (after printing a message) when no database
        is loaded.
        """
        return self._list_members(self.labels)

    def add_label(self, label):
        """Add a label group to the database; errors are printed."""
        try:
            self.labels.create_group(label)
        except Exception as e:
            # Handle all exceptions
            print('(DataManager)', e)

    def remove_label(self):
        """TODO: implement remove_label."""
        pass

    # --- subjects --- #
    def get_subjects(self):
        """
        Return list of all subject objects in the 'subjects' group.

        Returns an empty list (after printing a message) when no database
        is loaded.
        """
        return self._list_members(self.subjects)

    def add_subject(self, subject):
        """Add a subject group to the database; errors are printed."""
        try:
            self.subjects.create_group(subject)
        except Exception as e:
            # Handle all exceptions
            print('(DataManager)', e)

    def remove_subject(self):
        """TODO: implement remove_subject."""
        pass

    # === DATA CONTROL ========================================================
    def reset(self):
        """Reset DAQ manager, clear all data and graphs."""
        self.source.reset()

    def pause_toggle(self):
        """Toggle the paused state of the DAQ manager."""
        # Virtual DAQ needs a dataset loaded before running
        if self.source is self.virt_daq and self.virt_daq.ds is None:
            return
        # self.reset() could be called on un-pause (used for arbitrary dt)
        self.paused = self.paused is not True

    def close(self):
        """Close every source in the DAQ manager and the open database."""
        for source in self.source_list:
            source.close()
        # Release the HDF5 file handle as well
        if self.db is not None:
            self.db.close()
            self.db = None
            self.samples = None
            self.labels = None
            self.subjects = None

    # === PROPERTIES ====================================================
    @property
    def reset_flag(self):
        """Reset flag of the currently selected source."""
        return self.source.reset_flag

    @reset_flag.setter
    def reset_flag(self, a):
        self.source.reset_flag = a

    @property
    def ts_buffer(self):
        """`ts_buffer` of the currently selected source."""
        return self.source.ts_buffer