Source code for gusto.core.io

"""Provides the model's IO, which controls input, output and diagnostics."""

from os import path, makedirs
import itertools
from netCDF4 import Dataset
import sys
import time
from gusto.diagnostics import Diagnostics, CourantNumber
from gusto.core.meshes import get_flat_latlon_mesh
from firedrake import (Function, functionspaceimpl, Constant,
                       DumbCheckpoint, FILE_CREATE, FILE_READ, CheckpointFile)
from firedrake.output import VTKFile
from pyop2.mpi import MPI
import numpy as np
from gusto.core.logging import logger, update_logfile_location

__all__ = ["pick_up_mesh", "IO"]


class GustoIOError(IOError):
    pass


def pick_up_mesh(output, mesh_name):
    """
    Picks up a checkpointed mesh. This must be the first step of any model
    being picked up from a checkpointing run.

    Args:
        output (:class:`OutputParameters`): holds and describes the options
            for outputting.
        mesh_name (str): the name of the mesh to be picked up. The default
            names used by Firedrake are "firedrake_default" for non-extruded
            meshes, or "firedrake_default_extruded" for extruded meshes.

    Returns:
        :class:`Mesh`: the mesh to be used by the model.
    """

    # Open the checkpointing file for reading
    dumpdir = None
    if output.checkpoint_pickup_filename is not None:
        chkfile = output.checkpoint_pickup_filename
    else:
        dumpdir = path.join("results", output.dirname)
        chkfile = path.join(dumpdir, "chkpt.h5")
    with CheckpointFile(chkfile, 'r') as chk:
        mesh = chk.load_mesh(mesh_name)
    if dumpdir:
        update_logfile_location(dumpdir, mesh.comm)

    return mesh
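
# A minimal pick-up sketch (the OutputParameters arguments shown are
# illustrative, not prescriptive): the checkpointed mesh must be reloaded
# before the Domain and IO objects are rebuilt, for example
#
#     output = OutputParameters(dirname='my_run', checkpoint=True)
#     mesh = pick_up_mesh(output, "firedrake_default")
#
# after which the rest of the model can be constructed on `mesh` and its
# state recovered via `IO.pick_up_from_checkpoint`.
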
class PointDataOutput(object):
    """Object for outputting field point data."""
    def __init__(self, filename, field_points, description,
                 field_creator, comm, tolerance=None, create=True):
        """
        Args:
            filename (str): name of file to output to.
            field_points (list): some iterable of pairs, matching fields with
                arrays of evaluation points: (field_name, evaluation_points).
            description (str): a description of the simulation to be included
                in the output.
            field_creator (:class:`FieldCreator`): the field creator, used to
                determine the datatype and shape of fields.
            comm (:class:`MPI.Comm`): MPI communicator.
            tolerance (float, optional): tolerance to use for the evaluation
                of fields at points. Defaults to None.
            create (bool, optional): whether the output file needs creating,
                or if it already exists. Defaults to True.
        """
        # Overwrite on creation.
        self.dump_count = 0
        self.filename = filename
        self.field_points = field_points
        self.tolerance = tolerance
        self.comm = comm
        if self.comm.size > 1:
            raise GustoIOError("PointDataOutput does not work in parallel")
        if not create:
            return
        if self.comm.rank == 0:
            with Dataset(filename, "w") as dataset:
                dataset.description = "Point data for simulation {desc}".format(desc=description)
                dataset.history = "Created {t}".format(t=time.ctime())
                # FIXME add versioning information.
                dataset.source = "Output from Gusto model"

                # Appendable dimension, timesteps in the model
                dataset.createDimension("time", None)
                var = dataset.createVariable("time", np.float64, ("time"))
                var.units = "seconds"

                # Now create the variable group for each field
                for field_name, points in field_points:
                    group = dataset.createGroup(field_name)
                    npts, dim = points.shape
                    group.createDimension("points", npts)
                    group.createDimension("geometric_dimension", dim)
                    var = group.createVariable("points", points.dtype,
                                               ("points", "geometric_dimension"))
                    var[:] = points

                    # Get the UFL shape of the field
                    field_shape = field_creator(field_name).ufl_shape
                    # Number of geometric dimension occurrences should be the
                    # same as the length of the UFL shape
                    field_len = len(field_shape)
                    field_count = field_shape.count(dim)
                    assert field_len == field_count, "Geometric dimension occurrences do not match UFL shape"
                    # Create the variable with the required shape
                    dimensions = ("time", "points") + field_count*("geometric_dimension",)
                    group.createVariable(field_name, field_creator(field_name).dat.dtype, dimensions)

    def dump(self, field_creator, t):
        """
        Evaluate and output field data at points.

        Args:
            field_creator (:class:`FieldCreator`): gives access to the fields.
            t (float): simulation time at which the output occurs.
        """
        val_list = []
        for field_name, points in self.field_points:
            val_list.append((field_name, np.asarray(field_creator(field_name).at(points, tolerance=self.tolerance))))

        if self.comm.rank == 0:
            with Dataset(self.filename, "a") as dataset:
                # Add new time index
                dataset.variables["time"][self.dump_count] = t
                for field_name, vals in val_list:
                    group = dataset.groups[field_name]
                    var = group.variables[field_name]
                    var[self.dump_count, :] = vals

        self.dump_count += 1


class DiagnosticsOutput(object):
    """Object for outputting global diagnostic data."""

    def __init__(self, filename, diagnostics, description, comm, create=True):
        """
        Args:
            filename (str): name of file to output to.
            diagnostics (:class:`Diagnostics`): the object holding and
                controlling the diagnostic evaluation.
            description (str): a description of the simulation to be included
                in the output.
            comm (:class:`MPI.Comm`): MPI communicator.
            create (bool, optional): whether the output file needs creating,
                or if it already exists. Defaults to True.
        """
        self.filename = filename
        self.diagnostics = diagnostics
        self.comm = comm
        if not create:
            return
        if self.comm.rank == 0:
            with Dataset(filename, "w") as dataset:
                dataset.description = "Diagnostics data for simulation {desc}".format(desc=description)
                dataset.history = "Created {t}".format(t=time.ctime())
                dataset.source = "Output from Gusto model"
                dataset.createDimension("time", None)
                var = dataset.createVariable("time", np.float64, ("time", ))
                var.units = "seconds"
                for name in diagnostics.fields:
                    group = dataset.createGroup(name)
                    for diagnostic in diagnostics.available_diagnostics:
                        group.createVariable(diagnostic, np.float64, ("time", ))

    def dump(self, state_fields, t):
        """
        Output the global diagnostics.

        Args:
            state_fields (:class:`StateFields`): the model's field container.
            t (float): simulation time at which the output occurs.
        """
        diagnostics = []
        for fname in self.diagnostics.fields:
            field = state_fields(fname)
            for dname in self.diagnostics.available_diagnostics:
                diagnostic = getattr(self.diagnostics, dname)
                diagnostics.append((fname, dname, diagnostic(field)))

        if self.comm.rank == 0:
            with Dataset(self.filename, "a") as dataset:
                idx = dataset.dimensions["time"].size
                dataset.variables["time"][idx:idx + 1] = t
                for fname, dname, value in diagnostics:
                    group = dataset.groups[fname]
                    var = group.variables[dname]
                    var[idx:idx + 1] = value
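
# The writers above produce flat netCDF layouts: "diagnostics.nc" has one
# group per field and one variable per registered diagnostic, each indexed by
# the unlimited "time" dimension. A minimal read-back sketch (the path and
# variable names below are illustrative, not part of this module):
#
#     from netCDF4 import Dataset
#     with Dataset("results/my_run/diagnostics.nc", "r") as data:
#         times = data.variables["time"][:]
#         u_max = data.groups["u"].variables["max"][:]
#
# "point_data.nc" is laid out similarly, with one group per field holding the
# evaluation points and the field values at each dump.
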
class IO(object):
    """Controls the model's input, output and diagnostics."""

    def __init__(self, domain, output, diagnostics=None, diagnostic_fields=None):
        """
        Args:
            domain (:class:`Domain`): the model's domain object, containing
                the mesh and the compatible function spaces.
            output (:class:`OutputParameters`): holds and describes the
                options for outputting.
            diagnostics (:class:`Diagnostics`, optional): object holding and
                controlling the model's diagnostics. Defaults to None.
            diagnostic_fields (list, optional): an iterable of
                `DiagnosticField` objects. Defaults to None.
        """
        self.domain = domain
        self.mesh = domain.mesh
        self.output = output

        if diagnostics is not None:
            self.diagnostics = diagnostics
        else:
            self.diagnostics = Diagnostics()

        if diagnostic_fields is not None:
            self.diagnostic_fields = diagnostic_fields
        else:
            self.diagnostic_fields = []

        if self.output.dumplist is None:
            self.output.dumplist = []

        self.dumpdir = None
        self.dumpfile = None
        self.to_pick_up = None

        if output.log_courant:
            self.courant_max = Constant(0.0)

    def log_parameters(self, equation):
        """
        Logs an equation's physical parameters that take non-default values.

        Args:
            equation (:class:`PrognosticEquation`): the model's equation which
                contains any physical parameters used in the model run.
        """
        if hasattr(equation, 'parameters') and equation.parameters is not None:
            logger.info("Physical parameters that take non-default values:")
            logger.info(", ".join("%s: %s" % (k, float(v)) for (k, v) in vars(equation.parameters).items()))

    def setup_log_courant(self, state_fields, name='u', component="whole",
                          expression=None):
        """
        Sets up Courant number diagnostics to be logged.

        Args:
            state_fields (:class:`StateFields`): the model's field container.
            name (str, optional): the name of the field to log the Courant
                number of. Defaults to 'u'.
            component (str, optional): the component of the velocity to use
                for calculating the Courant number. Valid values are "whole",
                "horizontal" or "vertical". Defaults to "whole".
            expression (:class:`ufl.Expr`, optional): expression of velocity
                field to take Courant number of. Defaults to None, in which
                case the "name" argument must correspond to an existing field.
        """
        if self.output.log_courant:
            diagnostic_names = [diagnostic.name for diagnostic in self.diagnostic_fields]
            courant_name = None if name == 'u' else name

            # Set up diagnostic if it hasn't already been
            if courant_name not in diagnostic_names and 'u' in state_fields._field_names:
                if expression is None:
                    diagnostic = CourantNumber(to_dump=False, component=component)
                elif expression is not None:
                    diagnostic = CourantNumber(velocity=expression, component=component,
                                               name=courant_name, to_dump=False)

                self.diagnostic_fields.append(diagnostic)
                diagnostic.setup(self.domain, state_fields)
                self.diagnostics.register(diagnostic.name)

    def log_courant(self, state_fields, name='u', component="whole", message=None):
        """
        Logs the maximum Courant number value.

        Args:
            state_fields (:class:`StateFields`): the model's field container.
            name (str, optional): the name of the field to log the Courant
                number of. Defaults to 'u'.
            component (str, optional): the component of the velocity to use
                for calculating the Courant number. Valid values are "whole",
                "horizontal" or "vertical". Defaults to "whole".
            message (str, optional): an extra message to be logged. Defaults
                to None.
        """
        if self.output.log_courant and 'u' in state_fields._field_names:
            diagnostic_names = [diagnostic.name for diagnostic in self.diagnostic_fields]
            courant_name = 'CourantNumber' if name == 'u' else 'CourantNumber_'+name
            if component != 'whole':
                courant_name += '_'+component
            courant_idx = diagnostic_names.index(courant_name)
            courant_diagnostic = self.diagnostic_fields[courant_idx]
            courant_diagnostic.compute()
            courant_field = state_fields(courant_name)
            courant_max = self.diagnostics.max(courant_field)

            if message is None:
                logger.info(f'Max Courant: {courant_max:.2e}')
            else:
                logger.info(f'Max Courant {message}: {courant_max:.2e}')

            if component == 'whole':
                # TODO: this will update the Courant number more than we need to
                # and possibly with the wrong Courant number
                # we could make self.courant_max a dict with keys depending on
                # the field to take the Courant number of
                self.courant_max.assign(courant_max)

    def setup_diagnostics(self, state_fields):
        """
        Prepares the I/O for computing the model's global diagnostics and
        diagnostic fields.

        Args:
            state_fields (:class:`StateFields`): the model's field container.
        """
        diagnostic_names = [diagnostic.name for diagnostic in self.diagnostic_fields]
        non_diagnostics = [fname for fname in state_fields._field_names
                           if state_fields.field_type(fname) != "diagnostic" or fname not in diagnostic_names]

        # Set up any reference or initial fields that are necessary for diagnostics
        all_required_fields = {r for d in self.diagnostic_fields for r in d.required_fields}
        ref_fields = list(filter(lambda fname: fname[-4:] == '_bar', all_required_fields))
        init_fields = list(filter(lambda fname: fname[-5:] == '_init', all_required_fields))
        non_diagnostics = non_diagnostics + ref_fields + init_fields

        # Set up order for diagnostic fields -- filter out non-diagnostic fields
        field_deps = [(d, sorted(set(d.required_fields).difference(non_diagnostics),)) for d in self.diagnostic_fields]
        schedule = topo_sort(field_deps)
        self.diagnostic_fields = schedule

        # Set up and register all diagnostic fields
        for diagnostic in self.diagnostic_fields:
            diagnostic.setup(self.domain, state_fields)
            self.diagnostics.register(diagnostic.name)

        # Register fields for global diagnostics
        # TODO: it should be possible to specify which global diagnostics are used
        for fname in state_fields._field_names:
            if fname in state_fields.to_dump:
                self.diagnostics.register(fname)

    def setup_dump(self, state_fields, t, pick_up=False):
        """
        Sets up a series of things used for outputting.

        This prepares the model for outputting. First it checks for the
        existence of the specified output directory, to prevent it being
        overwritten unintentionally. It then sets up the output files and
        the checkpointing file.

        Args:
            state_fields (:class:`StateFields`): the model's field container.
            t (float): the current model time.
            pick_up (bool, optional): whether to pick up the model's initial
                state from a checkpointing file. Defaults to False.

        Raises:
            GustoIOError: if the results directory already exists, and the
                model is not picking up or running in test mode.
        """
        # Use 0 for okay, 1 for internal exception, 2 for external exception
        raise_parallel_exception = 0
        error = None

        if any([self.output.dump_vtus, self.output.dump_nc,
                self.output.dumplist_latlon, self.output.dump_diagnostics,
                self.output.point_data, self.output.checkpoint and not pick_up]):

            # setup output directory and check that it does not already exist
            self.dumpdir = path.join("results", self.output.dirname)
            running_tests = '--running-tests' in sys.argv or "pytest" in self.output.dirname

            # Raising exceptions needs to be done in parallel
            if self.mesh.comm.Get_rank() == 0:
                # Create results directory if it doesn't already exist
                if not path.exists(self.dumpdir):
                    try:
                        makedirs(self.dumpdir)
                    except OSError as e:
                        error = e
                        raise_parallel_exception = 2
                elif not (running_tests or pick_up):
                    # Throw an error if directory already exists, unless we
                    # are picking up or running tests
                    raise_parallel_exception = 1

            # Gather errors from each rank and raise appropriate error everywhere
            # This allreduce also ensures that all ranks are in sync wrt the results dir
            raise_exception = self.mesh.comm.allreduce(raise_parallel_exception, op=MPI.MAX)

            if raise_exception == 1:
                raise GustoIOError(f'results directory {self.dumpdir} already exists')
            elif raise_exception == 2:
                if error:
                    raise error
                else:
                    raise OSError('Check error message on rank 0')

            update_logfile_location(self.dumpdir, self.mesh.comm)

        if self.output.dump_vtus or self.output.dump_nc:
            # make list of fields to dump
            self.to_dump = [f for f in state_fields.fields if f.name() in state_fields.to_dump]

            # make dump counter
            self.dumpcount = itertools.count()
            # if picking-up, don't do initial dump
            if pick_up:
                next(self.dumpcount)

        if self.output.dump_vtus:
            # setup pvd output file
            outfile_pvd = path.join(self.dumpdir, "field_output.pvd")
            self.pvd_dumpfile = VTKFile(
                outfile_pvd, project_output=self.output.project_fields,
                comm=self.mesh.comm)

        if self.output.dump_nc:
            self.nc_filename = path.join(self.dumpdir, "field_output.nc")
            space_names = sorted(set([field.function_space().name for field in self.to_dump]))
            for space_name in space_names:
                self.domain.coords.register_space(self.domain, space_name)

            if pick_up:
                # Pick up t idx
                if self.mesh.comm.Get_rank() == 0:
                    nc_field_file = Dataset(self.nc_filename, 'r')
                    self.field_t_idx = len(nc_field_file['time'][:])
                    nc_field_file.close()
                else:
                    self.field_t_idx = None
                # Send information to other processors
                self.field_t_idx = self.mesh.comm.bcast(self.field_t_idx, root=0)
            else:
                # File needs creating
                self.create_nc_dump(self.nc_filename, space_names)

        # if there are fields to be dumped in latlon coordinates,
        # setup the latlon coordinate mesh and make output file
        if len(self.output.dumplist_latlon) > 0:
            mesh_ll = get_flat_latlon_mesh(self.mesh)
            outfile_ll = path.join(self.dumpdir, "field_output_latlon.pvd")
            self.dumpfile_ll = VTKFile(outfile_ll,
                                       project_output=self.output.project_fields,
                                       comm=self.mesh.comm)

            # make functions on latlon mesh, as specified by dumplist_latlon
            self.to_dump_latlon = []
            for name in self.output.dumplist_latlon:
                f = state_fields(name)
                field = Function(
                    functionspaceimpl.WithGeometry.create(
                        f.function_space(), mesh_ll),
                    val=f.topological, name=name+'_ll')
                self.to_dump_latlon.append(field)

        # we create new netcdf files to write to, unless pick_up=True and they
        # already exist, in which case we just need the filenames
        if self.output.dump_diagnostics:
            diagnostics_filename = self.dumpdir+"/diagnostics.nc"
            to_create = not (path.isfile(diagnostics_filename) and pick_up)
            self.diagnostic_output = DiagnosticsOutput(diagnostics_filename,
                                                       self.diagnostics,
                                                       self.output.dirname,
                                                       self.mesh.comm,
                                                       create=to_create)

            # if picking-up, don't do initial dump
            self.diagcount = itertools.count()
            if pick_up:
                next(self.diagcount)

        if len(self.output.point_data) > 0:
            # set up point data output
            pointdata_filename = self.dumpdir+"/point_data.nc"
            to_create = not (path.isfile(pointdata_filename) and pick_up)
            self.pointdata_output = PointDataOutput(pointdata_filename,
                                                    self.output.point_data,
                                                    self.output.dirname,
                                                    state_fields,
                                                    self.mesh.comm,
                                                    self.output.tolerance,
                                                    create=to_create)

            # make point data dump counter
            self.pddumpcount = itertools.count()
            # if picking-up, don't do initial dump
            if pick_up:
                next(self.pddumpcount)

            # set frequency of point data output - defaults to
            # dumpfreq if not set by user
            if self.output.pddumpfreq is None:
                self.output.pddumpfreq = self.output.dumpfreq

        # if we want to checkpoint, set up the checkpointing
        if self.output.checkpoint:
            if self.output.checkpoint_method == 'dumbcheckpoint':
                # should have already picked up, so can create a new file
                self.chkpt = DumbCheckpoint(path.join(self.dumpdir, "chkpt"),
                                            mode=FILE_CREATE)
            elif self.output.checkpoint_method == 'checkpointfile':
                # should have already picked up, so can create a new file
                self.chkpt_path = path.join(self.dumpdir, "chkpt.h5")
            else:
                raise ValueError(f'checkpoint_method {self.output.checkpoint_method} not supported')

            # make list of fields to pick_up (this doesn't include
            # diagnostic fields)
            self.to_pick_up = [fname for fname in state_fields.to_pick_up]

            # make a checkpoint counter
            self.chkptcount = itertools.count()
            # if picking-up, don't do initial dump
            if pick_up:
                next(self.chkptcount)

        # dump initial fields
        if not pick_up:
            self.dump(state_fields, t, step=1)
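
    # Note on the resulting layout: with the settings handled above, a run
    # with, for example, dirname 'my_run' writes its output under
    # results/my_run/: field_output.pvd (with field_output_latlon.pvd if
    # latlon fields are requested), field_output.nc, diagnostics.nc,
    # point_data.nc and chkpt.h5 (or "chkpt" for the 'dumbcheckpoint'
    # method), plus the log file relocated there by update_logfile_location.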

    def pick_up_from_checkpoint(self, state_fields):
        """
        Picks up the model's variables from a checkpoint file.

        Args:
            state_fields (:class:`StateFields`): the model's field container.

        Returns:
            tuple: the checkpointed model time, the list of reference
                profiles, the checkpointed step number and the number of
                initial steps completed by a multi-level scheme (or None).
        """
        # -------------------------------------------------------------------- #
        # Preparation for picking up
        # -------------------------------------------------------------------- #

        # Make list of fields that must be picked up
        if self.to_pick_up is None:
            self.to_pick_up = [fname for fname in state_fields.to_pick_up]

        # Set dumpdir if has not been done already
        if self.dumpdir is None:
            self.dumpdir = path.join("results", self.output.dirname)
            update_logfile_location(self.dumpdir, self.mesh.comm)

        # Need to pick up reference profiles, but don't know which are stored
        possible_ref_profiles = []
        reference_profiles = []
        for field_name, field_type in zip(state_fields._field_names, state_fields._field_types):
            if field_type != 'reference':
                possible_ref_profiles.append(field_name)

        # -------------------------------------------------------------------- #
        # Pick up fields
        # -------------------------------------------------------------------- #

        if self.output.checkpoint:
            # Open the checkpointing file for reading
            if self.output.checkpoint_pickup_filename is not None:
                chkfile = self.output.checkpoint_pickup_filename
            elif self.output.checkpoint_method == 'dumbcheckpoint':
                chkfile = path.join(self.dumpdir, "chkpt")
            elif self.output.checkpoint_method == 'checkpointfile':
                chkfile = path.join(self.dumpdir, "chkpt.h5")

            if self.output.checkpoint_method == 'dumbcheckpoint':
                with DumbCheckpoint(chkfile, mode=FILE_READ) as chk:
                    # Recover compulsory fields from the checkpoint
                    for field_name in self.to_pick_up:
                        chk.load(state_fields(field_name), name=field_name)

                    # Read in reference profiles -- failures are allowed here
                    for field_name in possible_ref_profiles:
                        ref_name = f'{field_name}_bar'
                        ref_field = Function(state_fields(field_name).function_space(), name=ref_name)
                        try:
                            chk.load(ref_field, name=ref_name)
                            reference_profiles.append((field_name, ref_field))
                            # Field exists, so add to to_pick_up
                            self.to_pick_up.append(ref_name)
                        except RuntimeError:
                            pass

                    # Try to pick up number of initial steps for multi level scheme
                    # Not compulsory so errors allowed
                    try:
                        initial_steps = chk.read_attribute("/", "initial_steps")
                    except AttributeError:
                        initial_steps = None

                    # Finally pick up time and step number
                    t = chk.read_attribute("/", "time")
                    step = chk.read_attribute("/", "step")

            else:
                with CheckpointFile(chkfile, 'r') as chk:
                    mesh = self.domain.mesh
                    # Recover compulsory fields from the checkpoint
                    for field_name in self.to_pick_up:
                        field = chk.load_function(mesh, field_name)
                        state_fields(field_name).assign(field)

                    # Read in reference profiles -- failures are allowed here
                    for field_name in possible_ref_profiles:
                        ref_name = f'{field_name}_bar'
                        try:
                            ref_field = chk.load_function(mesh, ref_name)
                            reference_profiles.append((field_name, ref_field))
                            # Field exists, so add to to_pick_up
                            self.to_pick_up.append(ref_name)
                        except RuntimeError:
                            pass

                    # Try to pick up number of initial steps for multi level scheme
                    # Not compulsory so errors allowed
                    if chk.has_attr("/", "initial_steps"):
                        initial_steps = chk.get_attr("/", "initial_steps")
                    else:
                        initial_steps = None

                    # Finally pick up time
                    t = chk.get_attr("/", "time")
                    step = chk.get_attr("/", "step")

            # If we have picked up from a non-standard file, reset this name
            # so that we will checkpoint using normal file name from now on
            self.output.checkpoint_pickup_filename = None

        else:
            raise ValueError("Must set checkpoint True if picking up")

        # Prevent any steady-state diagnostics overwriting their original fields
        for diagnostic_field in self.diagnostic_fields:
            if hasattr(diagnostic_field, "init_field_set"):
                diagnostic_field.init_field_set = True

        return t, reference_profiles, step, initial_steps

    def dump(self, state_fields, t, step, initial_steps=None):
        """
        Dumps all of the required model output.

        This includes point data, global diagnostics and general field data,
        written to netCDF and Paraview files. It also writes the model's
        prognostic variables to a checkpoint file if specified.

        Args:
            state_fields (:class:`StateFields`): the model's field container.
            t (float): the simulation's current time.
            step (int): the number of time steps.
            initial_steps (int, optional): the number of initial time steps
                completed by a multi-level time scheme. Defaults to None.
        """
        output = self.output

        # Diagnostics:
        # Compute diagnostic fields
        for field in self.diagnostic_fields:
            field.compute()

        if output.dump_diagnostics and (next(self.diagcount) % output.diagfreq) == 0:
            # Output diagnostic data
            self.diagnostic_output.dump(state_fields, t)

        if len(output.point_data) > 0 and (next(self.pddumpcount) % output.pddumpfreq) == 0:
            # Output pointwise data
            self.pointdata_output.dump(state_fields, t)

        # Dump all the fields to the checkpointing file (backup version)
        if output.checkpoint and (next(self.chkptcount) % output.chkptfreq) == 0:
            if self.output.checkpoint_method == 'dumbcheckpoint':
                for field_name in self.to_pick_up:
                    self.chkpt.store(state_fields(field_name), name=field_name)
                self.chkpt.write_attribute("/", "time", t)
                self.chkpt.write_attribute("/", "step", step)
                if initial_steps is not None:
                    self.chkpt.write_attribute("/", "initial_steps", initial_steps)
            else:
                with CheckpointFile(self.chkpt_path, 'w') as chk:
                    chk.save_mesh(self.domain.mesh)
                    for field_name in self.to_pick_up:
                        chk.save_function(state_fields(field_name), name=field_name)
                    chk.set_attr("/", "time", t)
                    chk.set_attr("/", "step", step)
                    if initial_steps is not None:
                        chk.set_attr("/", "initial_steps", initial_steps)

        if (next(self.dumpcount) % output.dumpfreq) == 0:
            if output.dump_nc:
                # dump fields
                self.write_nc_dump(t)

            if output.dump_vtus:
                # dump fields
                self.pvd_dumpfile.write(*self.to_dump)

            # dump fields on latlon mesh
            if len(output.dumplist_latlon) > 0:
                self.dumpfile_ll.write(*self.to_dump_latlon)
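
    # Each output stream above advances its own itertools counter, so the
    # diagnostic, point-data, checkpoint and field dumps are throttled
    # independently by diagfreq, pddumpfreq, chkptfreq and dumpfreq
    # respectively; a stream is written only on calls where its counter is a
    # multiple of its frequency.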

    def create_nc_dump(self, filename, space_names):
        """
        Creates the netCDF file for field output, writing the mesh metadata,
        the coordinates of each output function space and a group for each
        field to be dumped.

        Args:
            filename (str): name of the netCDF file to create.
            space_names (list): names of the function spaces of the fields to
                be output, used to store the associated coordinates.
        """
        my_rank = self.mesh.comm.Get_rank()
        self.field_t_idx = 0

        if my_rank == 0:
            nc_field_file = Dataset(filename, 'w')
            nc_field_file.createDimension('time', None)
            nc_field_file.createVariable('time', float, ('time',))

            # Add mesh metadata
            for metadata_key, metadata_value in self.domain.metadata.items():
                # If the metadata is None or a Boolean, try converting to string
                # This is because netCDF can't take these types as options
                if type(metadata_value) in [type(None), type(True)]:
                    output_value = str(metadata_value)
                else:
                    output_value = metadata_value

                # Get the type from the metadata itself
                nc_field_file.createVariable(metadata_key, type(output_value), [])
                nc_field_file.variables[metadata_key][0] = output_value

            # Add coordinates if they are not already in the file
            for space_name in space_names:
                if space_name not in self.domain.coords.chi_coords.keys():
                    # Space not registered
                    # TODO: we should fail here, but currently there are some spaces
                    # that we can't output for so instead just skip outputting
                    pass
                else:
                    coord_fields = self.domain.coords.global_chi_coords[space_name]
                    num_points = len(self.domain.coords.global_chi_coords[space_name][0])
                    nc_field_file.createDimension('coords_'+space_name, num_points)

                    for (coord_name, coord_field) in zip(self.domain.coords.coords_name, coord_fields):
                        nc_field_file.createVariable(coord_name+'_'+space_name, float, 'coords_'+space_name)
                        nc_field_file.variables[coord_name+'_'+space_name][:] = coord_field[:]

            # Create variable for storing the field values
            for field in self.to_dump:
                field_name = field.name()
                space_name = field.function_space().name
                if space_name not in self.domain.coords.chi_coords.keys():
                    # Space not registered
                    # TODO: we should fail here, but currently there are some spaces
                    # that we can't output for so instead just skip outputting
                    logger.warning(f'netCDF outputting for space {space_name} '
                                   + 'not yet implemented, so unable to output '
                                   + f'{field_name} field')
                else:
                    nc_field_file.createGroup(field_name)
                    nc_field_file[field_name].createVariable('field_values', float, ('coords_'+space_name, 'time'))

            nc_field_file.close()

    def write_nc_dump(self, t):
        """
        Appends the current field data to the netCDF output file, gathering
        the data from all processors onto rank 0 before writing.

        Args:
            t (float): the simulation's current time.
        """
        comm = self.mesh.comm
        my_rank = comm.Get_rank()
        comm_size = comm.Get_size()

        # Open file to add time
        if my_rank == 0:
            nc_field_file = Dataset(self.nc_filename, 'a')
            nc_field_file['time'][self.field_t_idx] = t

        # Loop through output field data here
        num_fields = len(self.to_dump)
        for i, field in enumerate(self.to_dump):
            field_name = field.name()
            space_name = field.function_space().name

            if space_name not in self.domain.coords.chi_coords.keys():
                # Space not registered
                # TODO: we should fail here, but currently there are some spaces
                # that we can't output for so instead just skip outputting
                pass

            # -------------------------------------------------------- #
            # Scalar elements
            # -------------------------------------------------------- #
            else:
                j = 0
                # For most processors send data to first processor
                if my_rank != 0:
                    # Make a tag to uniquely identify this call
                    my_tag = comm_size*(num_fields*j + i) + my_rank
                    comm.send(field.dat.data_ro[:], dest=0, tag=my_tag)

                else:
                    # Set up array to store full data in
                    total_data_size = self.domain.coords.parallel_array_lims[space_name][comm_size-1][1]+1
                    single_proc_data = np.zeros(total_data_size)
                    # Get data for this processor first
                    (low_lim, up_lim) = self.domain.coords.parallel_array_lims[space_name][my_rank][:]
                    single_proc_data[low_lim:up_lim+1] = field.dat.data_ro[:]
                    # Receive data from other processors
                    for procid in range(1, comm_size):
                        my_tag = comm_size*(num_fields*j + i) + procid
                        incoming_data = comm.recv(source=procid, tag=my_tag)
                        (low_lim, up_lim) = self.domain.coords.parallel_array_lims[space_name][procid][:]
                        single_proc_data[low_lim:up_lim+1] = incoming_data[:]

                    # Store whole field data
                    nc_field_file[field_name].variables['field_values'][:, self.field_t_idx] = single_proc_data[:]

        if my_rank == 0:
            nc_field_file.close()

        self.field_t_idx += 1
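
# A minimal wiring sketch (the OutputParameters arguments and the use of
# CourantNumber below are illustrative, not a complete model setup): an IO
# object is normally built once alongside the Domain and handed to a
# timestepper, e.g.
#
#     output = OutputParameters(dirname='my_run', dumpfreq=10)
#     io = IO(domain, output, diagnostic_fields=[CourantNumber()])
#
# The timestepper is then responsible for calling setup_diagnostics,
# setup_dump and dump (or pick_up_from_checkpoint when restarting a run).
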

def topo_sort(field_deps):
    """
    Perform a topological sort to determine the order to evaluate diagnostics.

    Args:
        field_deps (list): a list of tuples, pairing diagnostic fields with
            the fields that they are to be evaluated from.

    Raises:
        RuntimeError: if there is a cyclic dependency in the diagnostic fields.

    Returns:
        list: a list specifying the order in which to evaluate the diagnostics.
    """
    name2field = dict((f.name, f) for f, _ in field_deps)
    # map node: (input_deps, output_deps)
    graph = dict((f.name, (list(deps), [])) for f, deps in field_deps)
    roots = []
    for f, input_deps in field_deps:
        if len(input_deps) == 0:
            # No dependencies, candidate for evaluation
            roots.append(f.name)
        for d in input_deps:
            # add f as output dependency
            graph[d][1].append(f.name)

    schedule = []
    while roots:
        n = roots.pop()
        schedule.append(n)
        output_deps = list(graph[n][1])
        for m in output_deps:
            # Remove edge
            graph[m][0].remove(n)
            graph[n][1].remove(m)
            # If m now has no input deps, candidate for evaluation
            if len(graph[m][0]) == 0:
                roots.append(m)
    if any(len(i) for i, _ in graph.values()):
        cycle = "\n".join("%s -> %s" % (f, i) for f, (i, _) in graph.items()
                          if f not in schedule)
        raise RuntimeError("Field dependencies have a cycle:\n\n%s" % cycle)
    return list(map(name2field.__getitem__, schedule))
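
# A small worked example of the sort above (the names are purely illustrative;
# any objects with a `name` attribute work): if `enstrophy` has name
# "enstrophy" and depends on the field produced by `vorticity` (name
# "vorticity"), which itself has no diagnostic dependencies, then
#
#     field_deps = [(enstrophy, ["vorticity"]), (vorticity, [])]
#     schedule = topo_sort(field_deps)
#
# returns the vorticity diagnostic before the enstrophy diagnostic, so each
# diagnostic is only evaluated after its dependencies; a cyclic dependency
# between diagnostics raises a RuntimeError instead.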