Newer
Older
# -*- coding: utf-8 -*-
"""
Data Manager Class.
Controls data source selection.
Author: Jason Merlo
Maintainer: Jason Merlo (merlojas@msu.edu)
"""
import h5py # Used for hdf5 database
from pyratk.acquisition.mux_buffer import MuxBuffer
from pyratk.acquisition.virtual_daq import VirtualDAQ
from pyratk.acquisition import daq # Extention of DAQ object
from pyqtgraph import QtCore
from pyratk.formatting import warning
# COMPRESSION OPTIONS
# Passed as `compression=` / `compression_opts=` whenever a sample or
# trajectory dataset is created in the HDF5 database.
# NOTE(review): for the gzip filter h5py interprets compression_opts as the
# compression level (9 presumably being the strongest/slowest) - confirm.
COMPRESSION = "gzip"
COMPRESSION_OPTS = 9
class DataManager(MuxBuffer):
    """
    DataManager class; extends MuxBuffer class.

    Handles multiple data input sources and aggregates them into one data
    "mux" which can select from the various input sources added.
    """

    # Emitted by reset(); __init__ also forwards source_reset_signal to it.
    reset_signal = QtCore.pyqtSignal()
def __init__(self, db="default.hdf5", daq=None):
    """
    Initialize DataManager Class.

    Arguments:
        db (optional)
            path of the database to save/load from
        daq (optional)
            stored as ``self.daq``; not otherwise used by this method

    Raises:
        RuntimeError
            if the database cannot be opened; the underlying exception is
            attached as the cause
    """
    # NOTE(review): the base-class initialiser is not invoked here - confirm
    # MuxBuffer does not need it before add_source()/signals are used.
    self.db = None
    self.samples = None  # Pointer to the samples dataset
    self.labels = None  # Pointer to the labels dataset

    # Open virtual daq for playback
    self.virt_daq = VirtualDAQ()
    self.add_source(self.virt_daq)

    self.daq = daq

    # Re-broadcast source resets on this object's own reset_signal
    self.source_reset_signal.connect(self.reset_signal.emit)

    # DEVEL/DEBUG
    try:
        self.open_database(db)
    except Exception as e:
        print("(DataManager) Error opening debug database:", e)
        # Chain the cause so the real failure stays visible in tracebacks
        raise RuntimeError("Cannot open specified dataset - is it already"
                           " opened elsewhere?") from e
# === DATABASE ============================================================
def open_database(self, db_path, mode='a'):
    """
    Accept file path of a database to load and open it.

    Any database that is already open is closed first. After the file is
    opened, the 'samples', 'trajectories', 'labels' and 'subjects' groups
    are looked up (and created when missing) and stored on the attributes
    of the same names.

    db_path: string
        Path to file to open
    mode: string
        File access mode (default: 'a')

    Raises:
        RuntimeError
            if the file cannot be opened or created (chained from the
            original OSError)
    """
    # If a database is open, close it then open a new one
    if self.db is not None:
        self.db.close()
        # Do not keep a handle to the closed file if the reopen fails
        self.db = None

    # Attempt to open or create a new file with the specified name
    try:
        self.db = h5py.File(db_path, mode)
    except OSError as e:
        print("(DataManager) Error opening HDF5 File:", e)
        raise RuntimeError("Cannot open or create specified dataset - is "
                           "it already opened elsewhere?") from e

    # (group / attribute name, message printed when the group is created)
    group_specs = (
        ("samples",
         "No 'samples' group found. Creating 'samples'..."),
        ("trajectories",
         "No 'trajectories' group found. Creating 'trajectories'..."),
        ("labels",
         "No 'labels' group found. Creating 'labels'"),
        ("subjects",
         "No 'subjects' group found. Creating 'subjects'"),
    )

    # Open each group, creating it on KeyError
    for group_name, created_notice in group_specs:
        try:
            group = self.db[group_name]
        except KeyError:
            print(created_notice)
            group = self.db.create_group(group_name)
        except Exception as e:
            # Any other lookup failure is reported and the attribute is
            # left untouched, as before
            print(e)
            continue
        setattr(self, group_name, group)
# === DATASET CONTROL =====================================================
def load_dataset(self, ds):
    """
    Load the ground-truth trajectory referenced by dataset ``ds``, if any.

    When ``ds`` carries a 'trajectory' attribute, its decoded value is
    stored in ``self.trajectory_label`` and the entry of ``self.trajectories``
    named after the last path component of ``ds.name`` is handed to the
    virtual DAQ. A missing trajectory entry only produces a warning.

    NOTE(review): only trajectory handling is visible in this method -
    confirm whether the sample data itself is loaded elsewhere.
    """
    if 'trajectory' not in ds.attrs:
        # Nothing to load for datasets without ground truth
        return

    self.trajectory_label = ds.attrs['trajectory'].decode('utf-8')

    # Trajectories are stored under the same leaf name as the sample
    try:
        leaf_name = ds.name.rsplit('/', 1)[-1]
        trajectory_ds = self.trajectories[leaf_name]
        self.virt_daq.load_trajectory(trajectory_ds)
    except KeyError as err:
        warning('Error loading ground-truth trajectory: {:}'.format(err))
def get_datasets(self):
    """
    Return a list with every object stored in the 'samples' group.

    When no database is loaded a message is printed and None is returned.
    """
    if not self.db:
        print("(DataManager) Database must be loaded before datasets read")
        return None
    return [self.samples[sample_key] for sample_key in self.samples]
def get_database(self):
    """Return the currently stored database handle (``self.db``)."""
    database = self.db
    return database
def delete_dataset(self, ds):
    """
    Remove dataset ``ds`` from the database.

    The entry is addressed by ``ds.name``; any failure is printed rather
    than raised.
    """
    try:
        target_path = ds.name
        del self.db[target_path]
    except Exception as err:
        print("Error deleting dataset: ", err)
def get_trajectories(self):
    """
    Return a list with every object stored in the 'trajectories' group.

    When no database is loaded a message is printed and None is returned.
    """
    if not self.db:
        print("(DataManager) Database must be loaded before datasets read")
        return None
    return [self.trajectories[traj_key] for traj_key in self.trajectories]
def delete_trajectory(self, ds):
    """
    Remove trajectory ``ds`` from the database.

    The entry is addressed by ``ds.name``; any failure is printed rather
    than raised.
    """
    try:
        target_path = ds.name
        del self.db[target_path]
    except Exception as err:
        print("Error deleting dataset: ", err)
def save_buffer(self, name, labels, subject, notes):
    """
    Write buffer contents to dataset with specified 'name'.

    An existing '/samples/<name>' dataset is reused; otherwise a new one is
    created from ``self.source.ts_buffer``. If the source has a
    ``trajectory_samples`` attribute it is saved as ground truth under
    '/trajectories/<name>'. The sample is then hard-linked into each label
    group and the subject group, and descriptive attributes are written.

    Arguments:
        name
            dataset name inside the 'samples' group
        labels
            iterable of label names (str) or objects with a ``.name``,
            or None for no labels
        subject
            subject name (str) or object with a ``.name``, or None
        notes
            free-text notes (str), or None

    NOTE(review): no default 'sample_n' name is generated here; ``name``
    must be supplied by the caller.
    """
    if "/samples/{:}".format(name) in self.db:
        ds = self.samples[name]
    else:
        # Does not exist, create new entry
        try:
            # Save buffer data
            ds = self.samples.create_dataset(
                name, data=self.source.ts_buffer,
                compression=COMPRESSION,
                compression_opts=COMPRESSION_OPTS)
        except Exception as e:
            print("(DataManager) Error saving dataset: ", e)
            # Without a dataset there is nothing to link or annotate
            return

    # Create ground truth data if available
    has_trajectory = hasattr(self.source, 'trajectory_samples')
    if has_trajectory and "/trajectories/{:}".format(name) not in self.db:
        # Does not exist, create new entry
        try:
            self.trajectories.create_dataset(
                name, data=self.source.trajectory_samples,
                compression=COMPRESSION,
                compression_opts=COMPRESSION_OPTS)
            self.trajectories[name].attrs.create(
                "coordinate_type",
                self.source.coordinate_type.encode('utf-8'))
        except Exception as e:
            print("(DataManager) Error saving trajectory: ", e)

    # Create hard links to class label group and subject group
    label_names = []  # leaf names, used for saving labels to attribute
    for label in labels or ():
        label_name = label if isinstance(label, str) else label.name
        if "{:}/{:}".format(label_name, name) not in self.labels:
            self.labels[label_name][name] = ds
        label_names.append(label_name.split('/')[-1])
    # Attributes are stored as encoded text, so flatten the list first
    labels_str = ','.join(label_names)

    if subject is None or subject == '':
        # No subject given: skip the link and store an empty attribute
        subject_name = ''
    else:
        subject_name = subject if isinstance(subject, str) else subject.name
        if "{:}/{:}".format(subject_name, name) not in self.subjects:
            self.subjects[subject_name][name] = ds

    if notes is None:
        notes = ''

    try:
        # Save attribute data
        attrs = self.samples[name].attrs
        attrs.create("sample_rate", self.source.sample_rate)
        attrs.create("sample_size", self.source.sample_chunk_size)
        attrs.create("daq_type", self.source.daq_type.encode('utf8'))
        attrs.create("num_channels", self.source.num_channels)
        attrs.create("label", labels_str.encode('utf8'))
        if has_trajectory:
            attrs.create("trajectory", name.encode('utf8'))
        # NOTE(review): attribute key "subject" reconstructed - confirm
        attrs.create("subject",
                     subject_name.split('/')[-1].encode('utf8'))
        attrs.create("notes", notes.encode('utf8'))
    except Exception as e:
        print("(DataManager) Error saving attributes: ", e)
def remove_attributes(self, name, labels, subject):
    """
    Remove hard-links created in attribute folders.

    Currently the only hard-linked attributes are: labels, subject

    Arguments:
        name
            name of the sample whose links are removed
        labels
            iterable of label group names, or None for no labels
        subject
            subject group name; falsy values skip subject cleanup

    Missing links are reported on stdout and never raised.
    """
    # Report a missing dataset; link removal is still attempted afterwards,
    # as before
    if "/samples/{:}".format(name) not in self.db:
        print("(DataManager) Error attempting to remove attributes.")
        print("(DataManager) Dataset does not exist.")

    # Attempt to delete sample from label dataset
    for label in labels or ():
        try:
            self.labels[label].pop(name)
        except KeyError:
            print("(DataManager) Error attempting to remove sample "
                  "hard-link in label dataset.")
            print("(DataManager) sample does not exist in dataset.")
            print("Name: {:}".format(name))

    # Attempt to delete sample from subjects dataset
    if subject:
        try:
            self.subjects[subject].pop(name)
        except KeyError:
            print("(DataManager) Error attempting to remove sample "
                  "hard-link in subject dataset.")
            print("(DataManager) sample does not exist in dataset.")
            print("Name: {:}".format(name))
# --- labels --- #
def get_labels(self):
    """
    Return a list with every object stored in the 'labels' group.

    When no database is loaded a message is printed and None is returned.
    """
    if not self.db:
        print("(DataManager) Database must be loaded before datasets read")
        return None
    return [self.labels[label_key] for label_key in self.labels]
def add_label(self, label):
    """
    Create a group named ``label`` inside the 'labels' group.

    Any failure is printed with a '(DataManager)' prefix instead of being
    raised.
    """
    try:
        self.labels.create_group(label)
    except Exception as err:
        # Handle all exceptions
        print('(DataManager)', err)
def remove_label(self):
    """Placeholder; does nothing yet. TODO: implement remove_label."""
    return None
# --- subjects --- #
def get_subjects(self):
    """
    Return a list with every object stored in the 'subjects' group.

    When no database is loaded a message is printed and None is returned.
    """
    if not self.db:
        print("(DataManager) Database must be loaded before datasets read")
        return None
    return [self.subjects[subject_key] for subject_key in self.subjects]
def add_subject(self, subject):
    """
    Create a group named ``subject`` inside the 'subjects' group.

    Any failure is printed with a '(DataManager)' prefix instead of being
    raised.
    """
    try:
        self.subjects.create_group(subject)
    except Exception as err:
        # Handle all exceptions
        print('(DataManager)', err)
def remove_subject(self):
    """Placeholder; does nothing yet. TODO: implement remove_subject."""
    # `self` added so the stub is callable on an instance, matching
    # remove_label(self)
    pass
# === DATA CONTROL ========================================================
def reset(self):
    """
    Pause and reset the active source, then announce the reset.

    The source is left paused afterwards; ``reset_signal`` is emitted once
    the source's own ``reset()`` has returned.
    """
    active_source = self.source
    active_source.paused = True
    active_source.reset()
    self.reset_signal.emit()
    # self.source.paused = False
def pause_toggle(self):
    """
    Toggle the paused state of the active source.

    Nothing happens while the virtual DAQ is the active source and has no
    dataset loaded (``virt_daq.ds`` is None).
    """
    # Virtual DAQ needs a dataset loaded before running
    if self.source is self.virt_daq and self.virt_daq.ds is None:
        return

    # Flip the flag; previously the state was always forced back to True,
    # so a paused source could never be resumed
    if self.source.paused is True:
        self.source.paused = False
    else:
        self.source.paused = True
def close(self):
    """
    Close every source registered in ``self.source_list``.

    NOTE(review): the HDF5 database handle is not closed here - confirm
    that is intended.
    """
    for registered_source in self.source_list:
        registered_source.close()
# === PROPERTIES ====================================================
# @property
# def reset_flag(self):
# return self.source.reset_flag
#
# @reset_flag.setter
# def reset_flag(self, a):
# self.source.reset_flag = a
@property
def ts_buffer(self):
    """Return the ``ts_buffer`` of the currently selected source."""
    active_source = self.source
    return active_source.ts_buffer