diff --git a/biosppy/storage.py b/biosppy/storage.py index adbd7b26..4fbff053 100644 --- a/biosppy/storage.py +++ b/biosppy/storage.py @@ -278,7 +278,7 @@ def load_h5(path, label): def store_txt(path, data, sampling_rate=1000., resolution=None, date=None, - precision=6): + labels=None, precision=6): """Store data to a simple text file. Parameters @@ -286,37 +286,63 @@ def store_txt(path, data, sampling_rate=1000., resolution=None, date=None, path : str Path to file. data : array - Data to store. + Data to store (up to 2 dimensions). sampling_rate : int, float, optional Sampling frequency (Hz). resolution : int, optional Sampling resolution. date : datetime, str, optional Datetime object, or an ISO 8601 formatted date-time string. + labels : list, optional + Labels for each column of `data`. precision : int, optional Precision for string conversion. + Raises + ------ + ValueError + If the number of data dimensions is greater than 2. + ValueError + If the number of labels is inconsistent with the data. + """ # ensure numpy data = np.array(data) + # check dimension + if data.ndim > 2: + raise ValueError("Number of data dimensions cannot be greater than 2.") + # build header header = "Simple Text Format\n" - header += "Sampling Rate (Hz): %0.2f\n" % sampling_rate + header += "Sampling Rate (Hz):= %0.2f\n" % sampling_rate if resolution is not None: - header += "Resolution: %d\n" % resolution + header += "Resolution:= %d\n" % resolution if date is not None: if isinstance(date, basestring): - header += "Date: %s\n" % date + header += "Date:= %s\n" % date elif isinstance(date, datetime.datetime): - header += "Date: %s\n" % date.isoformat() + header += "Date:= %s\n" % date.isoformat() else: ct = datetime.datetime.utcnow().isoformat() - header += "Date: %s\n" % ct + header += "Date:= %s\n" % ct # data type - header += "Data Type: %s" % data.dtype + header += "Data Type:= %s\n" % data.dtype + + # labels + if data.ndim == 1: + ncols = 1 + elif data.ndim == 2: + ncols = data.shape[1] + + if labels is None: + labels = ['%d' % i for i in xrange(ncols)] + elif len(labels) != ncols: + raise ValueError("Inconsistent number of labels.") + + header += "Labels:= %s" % '\t'.join(labels) # normalize path path = utils.normpath(path) @@ -361,14 +387,14 @@ def load_txt(path): # extract header mdata_tmp = {} - fields = ['Sampling Rate', 'Resolution', 'Date', 'Data Type'] + fields = ['Sampling Rate', 'Resolution', 'Date', 'Data Type', 'Labels'] values = [] for item in lines: if '#' in item: # parse comment for f in fields: if f in item: - mdata_tmp[f] = item.split(': ')[1].strip() + mdata_tmp[f] = item.split(':= ')[1].strip() fields.remove(f) break else: @@ -394,6 +420,11 @@ def load_txt(path): mdata['date'] = d except (KeyError, ValueError): pass + try: + labels = mdata_tmp['Labels'].split('\t') + mdata['labels'] = labels + except KeyError: + pass # load array data = np.genfromtxt(values, dtype=dtype, delimiter='\t')