# Source code for gusto.timestepping.timestepper

"""Defines the basic timestepper objects."""

from abc import ABCMeta, abstractmethod, abstractproperty
from firedrake import Function, Projector, split
from firedrake.fml import drop, Term, LabelledForm
from pyop2.profiling import timed_stage
from gusto.equations import PrognosticEquationSet
from gusto.core import TimeLevelFields, StateFields
from gusto.core.io import TimeData
from gusto.core.labels import transport, diffusion, prognostic, transporting_velocity
from gusto.core.logging import logger, DEBUG
from gusto.time_discretisation.time_discretisation import ExplicitTimeDiscretisation
from gusto.spatial_methods.transport_methods import TransportMethod
import ufl
import numpy as np

__all__ = ["BaseTimestepper", "Timestepper", "PrescribedTransport"]


class BaseTimestepper(object, metaclass=ABCMeta):
    """
    Base class for timesteppers.

    Holds the equation and IO objects, drives the time loop in :meth:`run`
    and provides helpers for setting up spatial methods, transporting
    velocities and reference profiles. Child classes must provide
    ``transporting_velocity``, ``setup_fields``, ``setup_scheme`` and
    ``timestep``.
    """

    def __init__(self, equation, io):
        """
        Args:
            equation (:class:`PrognosticEquation`): the prognostic equation.
            io (:class:`IO`): the model's object for controlling input/output.
        """

        self.equation = equation
        self.io = io
        self.dt = self.equation.domain.dt
        self.t = self.equation.domain.t
        self.reference_profiles_initialised = False
        self.last_ref_update_time = None

        # These hooks are implemented by child classes, so any attribute they
        # rely on must be set by the child before calling this constructor
        self.setup_fields()
        self.setup_scheme()

        self.io.log_parameters(equation)

    @property
    @abstractmethod
    def transporting_velocity(self):
        """
        The transporting velocity used by this time loop.

        Must be implemented in child classes. The rest of this class
        compares the value against the string "prognostic" and otherwise
        uses it directly as the velocity expression.
        """
        raise NotImplementedError

    @abstractmethod
    def setup_fields(self):
        """Set up required fields. Must be implemented in child classes"""
        pass

    @abstractmethod
    def setup_scheme(self):
        """Set up required scheme(s). Must be implemented in child classes"""
        pass

    @abstractmethod
    def timestep(self):
        """Defines the timestep. Must be implemented in child classes"""
        raise NotImplementedError

    def set_initial_timesteps(self, num_steps):
        """
        Sets the number of initial time steps for a multi-level scheme.

        Nothing is done if this timestepper has no ``scheme`` attribute, if
        the scheme has no ``initial_timesteps`` attribute, or if `num_steps`
        is None.

        Args:
            num_steps (int or None): the number of initial time steps.
        """
        can_set = (hasattr(self, 'scheme')
                   and hasattr(self.scheme, 'initial_timesteps')
                   and num_steps is not None)
        if can_set:
            self.scheme.initial_timesteps = num_steps

    def get_initial_timesteps(self):
        """
        Gets the number of initial time steps from a multi-level scheme.

        Returns:
            The scheme's ``initial_timesteps`` attribute, or None if this
            timestepper has no scheme with such an attribute.
        """
        can_get = (hasattr(self, 'scheme')
                   and hasattr(self.scheme, 'initial_timesteps'))
        # Return None if this is not applicable
        return self.scheme.initial_timesteps if can_get else None

    def setup_equation(self, equation):
        """
        Sets up the spatial methods for an equation, by the setting the
        forms used for transport/diffusion in the equation.

        A warning is logged for every variable that has a transport or
        diffusion term but no matching spatial method; such terms keep
        their default form.

        Args:
            equation (:class:`PrognosticEquation`): the equation that the
                transport method is to be applied to.
        """

        # Treat "no spatial methods" uniformly, and take a list so that a
        # one-shot iterable can safely be traversed several times below
        if self.spatial_methods is None:
            spatial_methods = []
        else:
            spatial_methods = list(self.spatial_methods)

        # For now, we only have methods for transport and diffusion
        for term_label in [transport, diffusion]:
            # ---------------------------------------------------------------- #
            # Check that appropriate methods have been provided
            # ---------------------------------------------------------------- #
            # Extract all terms corresponding to this type of term
            residual = equation.residual.label_map(
                lambda t: t.has_label(term_label), map_if_false=drop
            )
            variables = [t.get(prognostic) for t in residual.terms]
            method_variables = [method.variable for method in spatial_methods
                                if method.term_label == term_label]
            for variable in variables:
                if variable not in method_variables:
                    message = f'Variable {variable} has a {term_label.label} ' \
                        + 'term but no method for this has been specified. ' \
                        + 'Using default form for this term'
                    logger.warning(message)

        # -------------------------------------------------------------------- #
        # Replace forms in equation
        # -------------------------------------------------------------------- #
        for method in spatial_methods:
            method.replace_form(equation)

    def setup_transporting_velocity(self, scheme):
        """
        Set up the time discretisation by replacing the transporting velocity
        used by the appropriate one for this time loop.

        Args:
            scheme (:class:`TimeDiscretisation`): the time discretisation whose
                transport term should be replaced with the transport term of
                this discretisation.
        """

        if self.transporting_velocity == "prognostic" and "u" in self.fields._field_names:
            # Use the prognostic wind variable as the transporting velocity
            u_idx = self.equation.field_names.index('u')
            uadv = split(self.equation.X)[u_idx]
        else:
            uadv = self.transporting_velocity

        def uses_velocity(term):
            # Select the terms labelled with a transporting velocity
            return term.has_label(transporting_velocity)

        def swap_velocity(term):
            # Rebuild the term with its old transporting velocity replaced
            return Term(
                ufl.replace(term.form, {term.get(transporting_velocity): uadv}),
                term.labels
            )

        scheme.residual = scheme.residual.label_map(
            uses_velocity, map_if_true=swap_velocity
        )

        scheme.residual = transporting_velocity.update_value(scheme.residual, uadv)

        # Now also replace transporting velocity in the terms that are
        # contained in labels
        for idx, t in enumerate(scheme.residual.terms):
            if t.has_label(transporting_velocity):
                # Only values of existing keys are reassigned, so iterating
                # over the keys while doing so is safe
                for label in t.labels.keys():
                    if isinstance(t.labels[label], LabelledForm):
                        t.labels[label] = t.labels[label].label_map(
                            uses_velocity, map_if_true=swap_velocity
                        )

                        scheme.residual.terms[idx].labels[label] = \
                            transporting_velocity.update_value(t.labels[label], uadv)

    def log_timestep(self):
        """
        Logs the start of a time step.
        """
        logger.info('')
        logger.info('='*40)
        logger.info(f'at start of timestep {self.step}, t={float(self.t)}, dt={float(self.dt)}')

    def log_field_stats(self):
        """
        Logs some field stats, which can be useful for debugging.

        Does nothing unless the logger's effective level is DEBUG or lower.
        """
        if logger.getEffectiveLevel() > DEBUG:
            return

        for field_name in self.fields._field_names:
            field_data = self.fields(field_name).dat.data_ro
            # Mixed functions don't have min or max routines, and are less
            # useful, so try to eliminate these by only logging fields with
            # a 1-dimension array of data
            if not (isinstance(field_data, np.ndarray) and field_data.ndim == 1):
                continue
            # min() and max() raise a ValueError on zero-size arrays
            if field_data.size == 0:
                continue
            min_val = field_data.min()
            max_val = field_data.max()
            logger.debug(f'{field_name}, min: {min_val}, max: {max_val}')

    def run(self, t, tmax, pick_up=False):
        """
        Runs the model for the specified time, from t to tmax

        Args:
            t (float): the start time of the run
            tmax (float): the end time of the run
            pick_up: (bool): specify whether to pick_up from a previous run
        """

        # Set up diagnostics, which may set up some fields necessary to pick up
        self.io.setup_diagnostics(self.fields)
        self.io.setup_log_courant(self.fields)
        if self.equation.domain.mesh.extruded:
            self.io.setup_log_courant(self.fields, component='horizontal')
            self.io.setup_log_courant(self.fields, component='vertical')
        if self.transporting_velocity != "prognostic":
            self.io.setup_log_courant(self.fields, name='transporting_velocity',
                                      expression=self.transporting_velocity)

        if pick_up:
            # Pick up fields, and return other info to be picked up
            time_data, reference_profiles = self.io.pick_up_from_checkpoint(self.fields)
            # The checkpointed time overrides the start time that was passed in
            t = time_data.t
            self.step = time_data.step
            initial_timesteps = time_data.initial_steps
            last_ref_update_time = time_data.last_ref_update_time
            self.set_reference_profiles(reference_profiles, last_ref_update_time)
            self.set_initial_timesteps(initial_timesteps)
        else:
            self.step = 1

        # Set up dump, which may also include an initial dump
        with timed_stage("Dump output"):
            logger.debug('Dumping output to disk')
            self.io.setup_dump(self.fields, t, pick_up)
            self.log_field_stats()

        self.t.assign(t)

        # Time loop
        # The half-timestep tolerance avoids taking an extra step when the
        # accumulated floating point time falls marginally short of tmax
        while float(self.t) < tmax - 0.5*float(self.dt):
            self.log_timestep()

            self.x.update()

            self.io.log_courant(self.fields)
            if self.equation.domain.mesh.extruded:
                self.io.log_courant(self.fields, component='horizontal', message='horizontal')
                self.io.log_courant(self.fields, component='vertical', message='vertical')

            self.timestep()

            self.t.assign(float(self.t) + float(self.dt))
            self.step += 1

            with timed_stage("Dump output"):
                time_data = TimeData(
                    t=float(self.t), step=self.step,
                    initial_steps=self.get_initial_timesteps(),
                    last_ref_update_time=self.last_ref_update_time
                )
                self.io.dump(self.fields, time_data)
                self.log_field_stats()

        # Close the checkpoint file once, after the whole time loop
        if self.io.output.checkpoint and self.io.output.checkpoint_method == 'dumbcheckpoint':
            self.io.chkpt.close()

        logger.info(f'TIMELOOP complete. t={float(self.t):.5f}, {tmax=:.5f}')

    def set_reference_profiles(self, reference_profiles, last_ref_update_time=None):
        """
        Initialise the model's reference profiles.

        Args:
            reference_profiles (list): an iterable of pairs:
                (field_name, expr), where 'field_name' is the string giving
                the name of the reference profile field expr is the
                :class:`ufl.Expr` whose value is used to set the reference
                field.
            last_ref_update_time (float, optional): the last time that the
                reference profiles were updated. Defaults to None.

        Raises:
            ValueError: if a reference field does not exist yet and the
                passed profile is not a :class:`Function`.
        """
        for field_name, profile in reference_profiles:
            if field_name+'_bar' in self.fields:
                # For reference profiles already added to state, allow
                # interpolation from expressions
                ref = self.fields(field_name+'_bar')
            elif isinstance(profile, Function):
                # Need to add reference profile to state so profile must be
                # a Function
                ref = self.fields(field_name+'_bar', space=profile.function_space(),
                                  pick_up=True, dump=False, field_type='reference')
            else:
                raise ValueError(f'When initialising reference profile {field_name}'
                                 + ' the passed profile must be a Function')
            ref.interpolate(profile)

            # Assign profile to X_ref belonging to equation
            if isinstance(self.equation, PrognosticEquationSet):
                if field_name in self.equation.field_names:
                    idx = self.equation.field_names.index(field_name)
                    X_ref = self.equation.X_ref.subfunctions[idx]
                    X_ref.assign(ref)
                else:
                    # reference profile of a diagnostic
                    # warn user in case they made a typo
                    # Don't need to do anything else as value in field
                    # container has already been set
                    logger.warning(f'Setting reference profile for diagnostic {field_name}')

        self.reference_profiles_initialised = True

        self.last_ref_update_time = last_ref_update_time


class Timestepper(BaseTimestepper):
    """
    Implements a timeloop by applying a scheme to a prognostic equation.
    """

    def __init__(self, equation, scheme, io, spatial_methods=None,
                 physics_parametrisations=None):
        """
        Args:
            equation (:class:`PrognosticEquation`): the prognostic equation
            scheme (:class:`TimeDiscretisation`): the scheme to use to timestep
                the prognostic equation
            io (:class:`IO`): the model's object for controlling input/output.
            spatial_methods (iter, optional): a list of objects describing the
                methods to use for discretising transport or diffusion terms
                for each transported/diffused variable. Defaults to None,
                in which case the terms follow the original discretisation in
                the equation.
            physics_parametrisations: (iter, optional): an iterable of
                :class:`PhysicsParametrisation` objects that describe physical
                parametrisations to be included to add to the equation. They can
                only be used when the time discretisation `scheme` is explicit.
                Defaults to None.

        Raises:
            AssertionError: if any physics parametrisations are given but
                `scheme` is not an :class:`ExplicitTimeDiscretisation`.
        """
        # These attributes must exist before the parent constructor runs, as
        # it calls setup_fields and setup_scheme which read them
        self.scheme = scheme

        if spatial_methods is not None:
            self.spatial_methods = spatial_methods
        else:
            self.spatial_methods = []

        if physics_parametrisations is not None:
            self.physics_parametrisations = physics_parametrisations
            # Even a single parametrisation requires an explicit scheme. The
            # check is raised explicitly (keeping the AssertionError type) so
            # that it is not stripped when Python runs with -O
            if (len(self.physics_parametrisations) > 0
                    and not isinstance(scheme, ExplicitTimeDiscretisation)):
                raise AssertionError(
                    'Physics parametrisations can only be used with the '
                    + 'basic TimeStepper when the time discretisation is '
                    + 'explicit. If you want to use an implicit scheme, the '
                    + 'SplitPhysicsTimestepper is more appropriate.')
        else:
            self.physics_parametrisations = []

        super().__init__(equation=equation, io=io)

    @property
    def transporting_velocity(self):
        """The string "prognostic": transport uses the prognostic wind."""
        return "prognostic"

    def setup_fields(self):
        """
        Creates the time-level fields for the equation, with as many levels
        as the scheme's ``nlevels``, and the state fields built from them,
        the equation's prescribed fields and the output dumplist.
        """
        self.x = TimeLevelFields(self.equation, self.scheme.nlevels)
        self.fields = StateFields(self.x, self.equation.prescribed_fields,
                                  *self.io.output.dumplist)

    def setup_scheme(self):
        """
        Applies the spatial methods to the equation, sets up the scheme with
        the equation and replaces the scheme's transporting velocity.
        """
        self.setup_equation(self.equation)
        self.scheme.setup(self.equation)
        self.setup_transporting_velocity(self.scheme)
        if self.io.output.log_courant:
            self.scheme.courant_max = self.io.courant_max

    def timestep(self):
        """
        Implement the timestep
        """
        xnp1 = self.x.np1
        name = self.equation.field_name
        # Pass the scheme the last `nlevels` entries of the previous fields
        x_in = [x(name) for x in self.x.previous[-self.scheme.nlevels:]]

        self.scheme.apply(xnp1(name), *x_in)


class PrescribedTransport(Timestepper):
    """
    Implements a timeloop with a prescribed transporting velocity.
    """

    def __init__(self, equation, scheme, io, prescribed_transporting_velocity,
                 transport_method, physics_parametrisations=None):
        """
        Args:
            equation (:class:`PrognosticEquation`): the prognostic equation
            scheme (:class:`TimeDiscretisation`): the scheme to use to timestep
                the prognostic equation.
            io (:class:`IO`): the model's object for controlling input/output.
            prescribed_transporting_velocity: (bool): Whether a time-varying
                transporting velocity will be defined. If True, this will
                require the transporting velocity to be setup by calling either
                the `setup_prescribed_expr` or `setup_prescribed_apply` methods.
            transport_method (:class:`TransportMethod`): describes the method
                used for discretising the transport term. Anything that is
                not a :class:`TransportMethod` is assumed to be an iterable
                of them.
            physics_parametrisations: (iter, optional): an iterable of
                :class:`PhysicsParametrisation` objects that describe physical
                parametrisations to be included to add to the equation. They can
                only be used when the time discretisation `scheme` is explicit.
                Defaults to None.
        """

        if isinstance(transport_method, TransportMethod):
            transport_methods = [transport_method]
        else:
            # Assume an iterable has been provided
            transport_methods = transport_method

        super().__init__(equation, scheme, io, spatial_methods=transport_methods,
                         physics_parametrisations=physics_parametrisations)

        self.prescribed_transport_velocity = prescribed_transporting_velocity
        # Without a time-varying velocity there is nothing left to set up, so
        # the setup_prescribed_* methods will then refuse to run
        self.is_velocity_setup = not self.prescribed_transport_velocity
        self.velocity_projection = None
        self.velocity_apply = None

    @property
    def transporting_velocity(self):
        """The state field named 'u', used as the transporting velocity."""
        return self.fields('u')

    def setup_fields(self):
        """
        Sets up the time-level and state fields exactly as the parent
        :class:`Timestepper` does.
        """
        super().setup_fields()

    def setup_prescribed_expr(self, expr_func):
        """
        Sets up the prescribed transporting velocity, through a python function
        which has time as an argument, and returns a `ufl.Expr`. This allows the
        velocity to be updated with time.

        Args:
            expr_func (func): a python function with a single argument that
                represents the model time, and returns a `ufl.Expr`.

        Raises:
            RuntimeError: if the prescribed velocity is already set up.
        """

        if self.is_velocity_setup:
            raise RuntimeError('Prescribed velocity already set up!')

        project_params = {
            'quadrature_degree': self.equation.domain.max_quad_degree
        }
        # The expression is built once from the time object self.t; the
        # projector is then re-run whenever the velocity must be updated
        self.velocity_projection = Projector(
            expr_func(self.t), self.fields('u'),
            form_compiler_parameters=project_params
        )

        self.is_velocity_setup = True

    def setup_prescribed_apply(self, apply_method):
        """
        Sets up the prescribed transporting velocity, through a python function
        which has time as an argument. This function will perform the evaluation
        of the transporting velocity.

        Args:
            apply_method (func): a python function with a single argument that
                represents the model time, and performs the evaluation of the
                transporting velocity.

        Raises:
            RuntimeError: if the prescribed velocity is already set up.
        """

        if self.is_velocity_setup:
            raise RuntimeError('Prescribed velocity already set up!')
        self.velocity_apply = apply_method
        self.is_velocity_setup = True

    def run(self, t, tmax, pick_up=False):
        """
        Runs the model for the specified time, from t to tmax

        Args:
            t (float): the start time of the run
            tmax (float): the end time of the run
            pick_up: (bool): specify whether to pick_up from a previous run

        Raises:
            RuntimeError: if a time-varying prescribed velocity was requested
                but has not been set up.
        """

        # Throw an error if no transporting velocity has been set up
        if self.prescribed_transport_velocity and not self.is_velocity_setup:
            raise RuntimeError(
                'A time-varying prescribed velocity is required. This must be '
                + 'set up through calling either the setup_prescribed_expr or '
                + 'setup_prescribed_apply routines.')

        # It's best to have evaluated the velocity before we start
        # NOTE(review): self.t is only assigned the start time inside the
        # parent's run(), so this first evaluation uses whatever time self.t
        # currently holds -- confirm this is intended when t is non-zero or
        # when picking up from a checkpoint
        if self.velocity_projection is not None:
            self.velocity_projection.project()
        if self.velocity_apply is not None:
            self.velocity_apply(self.t)

        super().run(t, tmax, pick_up=pick_up)

    def timestep(self):
        """
        Implements the time step, which possibly involves evaluation of the
        prescribed transporting velocity.
        """

        # Refresh the velocity at the current time before stepping
        if self.velocity_projection is not None:
            self.velocity_projection.project()
        if self.velocity_apply is not None:
            self.velocity_apply(self.t)

        super().timestep()