diff --git a/boututils/datafile.py b/boututils/datafile.py index 00bb987..f2fb5ba 100644 --- a/boututils/datafile.py +++ b/boututils/datafile.py @@ -470,6 +470,13 @@ def dimlen(d): def _bout_type_from_dimensions(self, varname): dims = self.dimensions(varname) + + if any("char" in d for d in dims): + if 't' in dims: + return "string_t" + else: + return "string" + dims_dict = { ('t', 'x', 'y', 'z'): "Field3D_t", ('t', 'x', 'y'): "Field2D_t", @@ -484,7 +491,24 @@ def _bout_type_from_dimensions(self, varname): return dims_dict.get(dims, None) - def _bout_dimensions_from_type(self, bout_type): + def _bout_dimensions_from_var(self, data): + try: + bout_type = data.attributes["bout_type"] + except AttributeError: + defdims_list = [(), + ('t',), + ('x', 'y'), + ('x', 'y', 'z'), + ('t', 'x', 'y', 'z')] + return defdims_list[len(np.shape(data))] + + if bout_type == "string_t": + nt, string_length = data.shape + return ('t', "char" + str(string_length),) + elif bout_type == "string": + string_length = len(data) + return ("char" + str(string_length),) + dims_dict = { "Field3D_t": ('t', 'x', 'y', 'z'), "Field2D_t": ('t', 'x', 'y'), @@ -541,15 +565,7 @@ def write(self, name, data, info=False): # Not found, so add. # Get dimensions - try: - defdims = self._bout_dimensions_from_type(data.attributes['bout_type']) - except AttributeError: - defdims_list = [(), - ('t',), - ('x', 'y'), - ('x', 'y', 'z'), - ('t', 'x', 'y', 'z')] - defdims = defdims_list[len(s)] + defdims = self._bout_dimensions_from_var(data) def find_dim(dim): # Find a dimension with given name and size @@ -786,11 +802,13 @@ def dimensions(self, varname): "FieldPerp_t": ('t', 'x', 'z'), "Field2D_t": ('t', 'x', 'y'), "scalar_t": ('t',), + "string_t": ('t', 'char'), "Field3D": ('x', 'y', 'z'), "FieldPerp": ('x', 'z'), "Field2D": ('x', 'y'), "ArrayX": ('x',), "scalar": (), + "string": ('char',), } try: return dims_dict[bout_type] @@ -892,7 +910,7 @@ def write(self, name, data, info=False): print("Creating variable '" + name + "' with bout_type '" + bout_type + "'") - if bout_type in ["Field3D_t", "Field2D_t", "FieldPerp_t", "scalar_t"]: + if bout_type[-2:] == "_t": # time evolving fields shape = list(data.shape) # set time dimension to None to make unlimited