Newer
Older
# -*- coding: utf-8 -*-
"""
Data Manager Class.
Controls data source selection.
Author: Jason Merlo
Maintainer: Jason Merlo (merlojas@msu.edu)
"""
import h5py # Used for hdf5 database
from pyratk.acquisition.mux_buffer import MuxBuffer
from pyratk.acquisition.virtual_daq import VirtualDAQ
from pyratk.acquisition import daq # Extention of DAQ object
# COMPRESSION OPTIONS
COMPRESSION = "gzip"
COMPRESSION_OPTS = 9
class DataManager(MuxBuffer):
"""
DataManager class; extends MuxBuffer class.
Handles multiple data input sources and aggrigates them into one data
"mux" which can select from the various input sources added.
"""
def __init__(self, db="default.hdf5", daq=None):
"""
Initialize DataManager Class.
Arguments:
db (optional)
database to save/load from
"""
self.db = None
self.samples = None # Pointer to the samples dataset
self.labels = None # Pointer to the labels dataset
# Open virtual daq for playback
self.virt_daq = VirtualDAQ()
self.add_source(self.virt_daq)
self.daq = daq
# DEVEL/DEBUG
try:
self.open_database(db)
except Exception as e:
print("(DataManager) Error opening debug database:", e)
raise RuntimeError("Cannot open specified dataset - is it already"
" opened elsewhere?")
# === DATABASE ============================================================
def open_database(self, db_path, mode='a'):
"""
Accept file path of a database to load and open it.
db_path: string
Path to file to open
mode: string
File access mode (default: 'a')
"""
# If a database is open, close it then open a new one
if self.db is not None:
self.db.close()
# Attempt to open or create a new file with the specified name
try:
self.db = h5py.File(db_path, mode)
except OSError as e:
print("(DataManager) Error opening HDF5 File:", e)
raise RuntimeError("Cannot open or create specified dataset - is "
"it already opened elsewhere?")
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# Attempt to open 'samples', 'labels', 'subjects' datasets
# Create new datsets on keyerror
try:
self.samples = self.db["samples"]
except KeyError:
print("No 'samples' group found. Creating 'samples'...")
self.samples = self.db.create_group("samples")
except Exception as e:
# Handle all other exceptions
print(e)
try:
self.labels = self.db["labels"]
except KeyError:
print("No 'labels' group found. Creating 'labels'")
self.labels = self.db.create_group("labels")
except Exception as e:
# Handle all other exceptions
print(e)
try:
self.subjects = self.db["subjects"]
except KeyError:
print("No 'subjects' group found. Creating 'subjects'")
self.subjects = self.db.create_group("subjects")
except Exception as e:
# Handle all other exceptions
print(e)
# === DATASET CONTROL =====================================================
def load_dataset(self, ds):
"""Load dataset into virtualDAQ and set to virtualDAQ source."""
self.virt_daq.load_dataset(ds)
self.set_source(self.virt_daq)
def get_datasets(self):
"""Return list of all dataset objects in 'samples' dataset."""
keys = []
if self.db:
for key in self.samples:
keys.append(self.samples[key])
return keys
else:
print("(DataManager) Database must be loaded before datasets read")
def delete_dataset(self, ds):
"""Remove dataset from database."""
try:
del self.db[ds.name]
except Exception as e:
print("Error deleting dataset: ", e)
def save_buffer(self, name, labels, subject, notes):
"""
Write buffer contents to dataset with specified 'name'.
If no name is provided, 'sample_n' will be used, where n is the index
of the sample relative to 0.
"""
if "/samples/{:}".format(name) in self.db:
ds = self.samples[name]
else:
# Does not exist, create new entry
try:
# Save buffer data
Jason Merlo
committed
ds = self.samples.create_dataset(
name, data=self.source.ts_buffer,
compression=COMPRESSION,
compression_opts=COMPRESSION_OPTS)
except Exception as e:
print("(DataManager) Error saving dataset: ", e)
# Create hard links to class label group and subject group
if labels is None:
labels_str = ''
else:
labels_str = [] # used for saving labels to attribute
for label in labels:
Jason Merlo
committed
# Get type of label
if type(label) is str:
label_name = label
else:
label_name = label.name
# print('Checking if',
# "labels/{:}/{:}".format(label_name, name), 'exists')
if "{:}/{:}".format(label_name, name) not in self.labels:
self.labels[label_name][name] = ds
labels_str.append(label_name.split('/')[-1])
Jason Merlo
committed
if type(subject) is str:
subject_name = subject
else:
subject_name = subject.name
if "{:}/{:}".format(subject_name, name) not in self.subjects:
self.subjects[subject_name][name] = ds
if notes is None:
notes = ''
try:
# Save attribute data
attrs = self.samples[name].attrs
attrs.create("sample_rate", self.source.sample_rate)
Jason Merlo
committed
attrs.create("sample_size", self.source.sample_chunk_size)
attrs.create("daq_type", self.source.daq_type.encode('utf8'))
attrs.create("num_channels", self.source.num_channels)
attrs.create("label", labels_str.encode('utf8'))
if subject:
Jason Merlo
committed
if type(subject) is str:
subject_name = subject
else:
subject_name = subject.name
Jason Merlo
committed
subject_name.split('/')[-1].encode('utf8'))
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
attrs.create("notes", notes.encode('utf8'))
except Exception as e:
print("(DataManager) Error saving attributes: ", e)
def remove_attributes(self, name, labels, subject):
"""
Remove hard-linked created in attribute folders.
Currently the only hard-linked attribues are: labels, subject
"""
# Attempt to open dataset
if "/samples/{:}".format(name) in self.db:
ds = self.samples[name]
else:
print("(DataManager) Error attempting to remove attributes.")
print("(DataManager) Dataset does not exist.")
# Attempt to delete sample from label dataset
for label in labels:
try:
self.labels[label].pop(name)
except KeyError:
print("(DataManager) Error attempting to remove sample \
hard-link in label dataset.")
print("(DataManager) sample does not exist in dataset.")
print("Name: {:}".format(name))
# Attempt to delete sample from subjects dataset
if subject:
try:
self.subjects[subject].pop(name)
except KeyError:
print("(DataManager) Error attempting to remove sample \
hard-link in subject dataset.")
print("(DataManager) sample does not exist in dataset.")
print("Name: {:}".format(name))
# --- labels --- #
def get_labels(self):
"""Return list of all label objects in 'labels' dataset."""
keys = []
if self.db:
for key in self.labels:
keys.append(self.labels[key])
return keys
else:
print("(DataManager) Database must be loaded before datasets read")
def add_label(self, label):
"""Add a label to dataset."""
try:
self.labels.create_group(label)
except Exception as e:
# Handle all exceptions
print('(DataManager)', e)
def remove_label(self):
"""TODO: implement remove_label."""
pass
# --- subjects --- #
def get_subjects(self):
"""Return list of all subject objects in dataset."""
keys = []
if self.db:
for key in self.subjects:
keys.append(self.subjects[key])
return keys
else:
print("(DataManager) Database must be loaded before datasets read")
def add_subject(self, subject):
"""Add a subject to dataset."""
try:
self.subjects.create_group(subject)
except Exception as e:
# Handle all exceptions
print('(DataManager)', e)
def remove_subject():
# TODO: implement remove_subject
pass
# === DATA CONTROL ========================================================
def reset(self):
"""Reset DAQ manager, clear all data and graphs."""
self.source.reset()
def pause_toggle(self):
"""Pauses the DAQ manager."""
# Virtual DAQ needs a dataset loaded before running
if self.source is not self.virt_daq or self.virt_daq.ds is not None:
if self.paused is True:
# self.reset() # used for arbitrary dt
self.paused = False
else:
self.paused = True
def close(self):
"""Close the selected object in the DAQ manager."""
for source in self.source_list:
source.close()
# === PROPERTIES ====================================================
@property
def reset_flag(self):
return self.source.reset_flag
@reset_flag.setter
def reset_flag(self, a):
self.source.reset_flag = a
@property
def ts_buffer(self):
return self.source.ts_buffer