from firedrake import (SpatialCoordinate, TrialFunction,
TestFunction, Function, DirichletBC,
LinearVariationalProblem, LinearVariationalSolver,
FunctionSpace, lhs, rhs, inner, div, dx, grad, dot,
as_vector, as_matrix, dS_h, dS_v, Constant, avg,
sqrt, jump, FacetNormal)
from gusto import thermodynamics
from gusto.diagnostics import DiagnosticField, Energy

class KineticEnergyY(Energy):
    """
    Diagnostic for the kinetic energy of the y component of velocity.
    """
    name = "KineticEnergyY"

    def compute(self, state):
        """
        Computes the kinetic energy of the y component

        :arg state: the model state; must provide the velocity field ``u``.
        :returns: the return value of ``self.field.interpolate`` applied to
            ``self.kinetic(u[1])``.
        """
        u = state.fields("u")
        # u[1] is the y component of the velocity vector.
        energy = self.kinetic(u[1])
        return self.field.interpolate(energy)

class CompressibleKineticEnergyY(Energy):
    """
    Diagnostic for the kinetic energy of the y component of velocity,
    weighted by the density field ``rho``.
    """
    name = "CompressibleKineticEnergyY"

    def compute(self, state):
        """
        Computes the density-weighted kinetic energy of the y component

        :arg state: the model state; must provide the fields ``u`` and
            ``rho``.
        :returns: the return value of ``self.field.interpolate`` applied to
            ``self.kinetic(u[1], rho)``.
        """
        u = state.fields("u")
        rho = state.fields("rho")
        # rho is passed as the second argument to Energy.kinetic, presumably
        # as a multiplicative density factor -- confirm against Energy.kinetic.
        energy = self.kinetic(u[1], rho)
        return self.field.interpolate(energy)

# NOTE(review): the `class ...:` header (and its `name` attribute) that should
# own this method appears to have been lost from this file; as written this
# `compute` is not attached to its own class. TODO restore the header -- the
# class name cannot be determined from here.
[docs]    def compute(self, state):
"""
Computes the quantity -(z - H/2)*(b - bbar) and interpolates it into
``self.field``.

:arg state: the model state; must provide the fields ``b`` and ``bbar`` and
    the parameter ``state.parameters.H``.
:returns: the return value of ``self.field.interpolate(potential)``.
"""
# z is the third spatial coordinate of the mesh; x and y are unused.
x, y, z = SpatialCoordinate(state.mesh)
b = state.fields("b")
bbar = state.fields("bbar")
H = state.parameters.H
# Height measured relative to H/2 (presumably mid-depth -- confirm), times
# the deviation of b from bbar.
potential = -(z-H/2)*(b-bbar)
return self.field.interpolate(potential)

# NOTE(review): the `class ...:` header (and its `name` attribute) that should
# own this method appears to have been lost from this file; as written this
# `compute` is not attached to its own class. TODO restore the header -- the
# class name cannot be determined from here.
[docs]    def compute(self, state):
"""
Computes rho*(g*z + cv*Pi*theta - cp*Pi0*theta) and interpolates it into
``self.field``.

:arg state: the model state; must provide the fields ``rho`` and ``theta``
    and the parameters ``g``, ``cp``, ``cv`` and ``Pi0``.
:returns: the return value of ``self.field.interpolate(potential)``.
"""
# Only the vertical coordinate z is used; x and y are unused.
x, y, z = SpatialCoordinate(state.mesh)
g = state.parameters.g
cp = state.parameters.cp
cv = state.parameters.cv
Pi0 = state.parameters.Pi0

rho = state.fields("rho")
theta = state.fields("theta")
# Presumably the Exner pressure computed from rho and theta -- confirm
# against gusto.thermodynamics.pi.
Pi = thermodynamics.pi(state.parameters, rho, theta)

potential = rho*(g*z + cv*Pi*theta - cp*Pi0*theta)
return self.field.interpolate(potential)

class GeostrophicImbalance(DiagnosticField):
    """
    Diagnostic measuring the imbalance between the pressure gradient and the
    Coriolis / buoyancy terms.

    ``setup`` builds a mass-matrix projection onto the velocity space whose
    right-hand side is ``div(w)*p + inner(w, (f*u[1], 0, b))``; ``compute``
    solves it and returns the x component of the result divided by ``f``.
    """
    name = "GeostrophicImbalance"

    def setup(self, state):
        """
        Builds the linear solver for the imbalance vector.

        :arg state: the model state; must provide the fields ``u``, ``b`` and
            ``p`` and the parameter ``state.parameters.f``.
        """
        super(GeostrophicImbalance, self).setup(state)
        u = state.fields("u")
        b = state.fields("b")
        p = state.fields("p")
        f = state.parameters.f
        Vu = u.function_space()

        v = TrialFunction(Vu)
        w = TestFunction(Vu)
        a = inner(w, v)*dx
        # div(w)*p is the integrated-by-parts form of -inner(w, grad(p));
        # it is balanced against f*v in x and the buoyancy b in z.
        L = (div(w)*p+inner(w, as_vector([f*u[1], 0.0, b])))*dx

        # Zero velocity-space boundary values on the top and bottom surfaces.
        bcs = [DirichletBC(Vu, 0.0, "bottom"),
               DirichletBC(Vu, 0.0, "top")]

        self.imbalance = Function(Vu)
        imbalanceproblem = LinearVariationalProblem(a, L, self.imbalance,
                                                    bcs=bcs)
        self.imbalance_solver = LinearVariationalSolver(
            imbalanceproblem, solver_parameters={'ksp_type': 'cg'})

    def compute(self, state):
        """
        Solves for the imbalance and interpolates its scaled x component.

        :arg state: the model state; must provide ``state.parameters.f``.
        :returns: the return value of ``self.field.interpolate`` applied to
            ``self.imbalance[0]/f``.
        """
        f = state.parameters.f
        self.imbalance_solver.solve()
        # NOTE(review): divides by f, so this diagnostic assumes f != 0.
        geostrophic_imbalance = self.imbalance[0]/f
        return self.field.interpolate(geostrophic_imbalance)

class TrueResidualV(DiagnosticField):
    """
    Diagnostic for the residual of the y-momentum (v) equation, evaluated
    with time-centred velocities and projected onto piecewise constants.
    """
    name = "TrueResidualV"

    def setup(self, state):
        """
        Builds the DG0 projection of the v-equation residual.

        :arg state: the model state; ``state.xn`` and ``state.xb`` must split
            into three components (used here as u, p, b), and
            ``state.parameters`` must provide ``H``, ``f`` and ``dbdy``.
        """
        super(TrueResidualV, self).setup(state)
        # Only the velocity components of the new (xn) and old (xb) states
        # enter the residual.
        unew, _, _ = state.xn.split()
        uold, _, _ = state.xb.split()
        ubar = 0.5*(unew+uold)
        H = state.parameters.H
        f = state.parameters.f
        dbdy = state.parameters.dbdy
        dt = state.timestepping.dt
        # Only the vertical coordinate is needed.
        z = SpatialCoordinate(state.mesh)[2]
        V = FunctionSpace(state.mesh, "DG", 0)

        wv = TestFunction(V)
        v = TrialFunction(V)
        vlhs = wv*v*dx
        # (v_new - v_old)/dt + ubar*d(vbar)/dx + wbar*d(vbar)/dz
        #   + f*ubar + dbdy*(z - H/2)
        vrhs = wv*((unew[1]-uold[1])/dt + ubar[0]*ubar[1].dx(0)
                   + ubar[2]*ubar[1].dx(2)
                   + f*ubar[0] + dbdy*(z-H/2))*dx
        self.vtres = Function(V)
        vtresproblem = LinearVariationalProblem(vlhs, vrhs, self.vtres)
        self.v_residual_solver = LinearVariationalSolver(
            vtresproblem, solver_parameters={'ksp_type': 'cg'})

    def compute(self, state):
        """
        Solves for the residual and interpolates it into ``self.field``.

        :returns: the return value of
            ``self.field.interpolate(self.vtres)``.
        """
        self.v_residual_solver.solve()
        return self.field.interpolate(self.vtres)

[docs]class SawyerEliassenU(DiagnosticField):
"""
Diagnostic velocity derived from a stream function in the x-z plane.

``setup`` builds four linear solves, which ``compute`` runs in order:
projection of ``b`` and of ``v = inner(u, (0, 1, 0))`` onto continuous P2
("CG", 2); a solve for the stream function ``self.stm`` with zero Dirichlet
conditions on "bottom" and "top"; and a projection of
``(-stm.dx(2), 0, stm.dx(0))`` onto the velocity space. The result is
projected into ``self.field``, which lives in the "HDiv" space.

NOTE(review): the stream-function weak form ``Equ`` has an empty integrand
in this file -- see the notes in ``setup``.
"""
name = "SawyerEliassenU"

[docs]    def setup(self, state):
"""
Builds the projection, stream-function and velocity solvers.

:arg state: the model state; must provide the fields ``u`` and ``b``, the
    "HDiv" space, and the parameters ``f``, ``H``, ``L``, ``dbdy`` and
    ``fourthorder`` (plus ``deltax`` and ``deltaz`` when ``fourthorder`` is
    true).
"""

space = state.spaces("HDiv")
super(SawyerEliassenU, self).setup(state, space=space)

u = state.fields("u")
b = state.fields("b")
# y component of velocity, extracted as a dot product with the unit y vector.
v = inner(u, as_vector([0., 1., 0.]))

# spaces
V0 = FunctionSpace(state.mesh, "CG", 2)
Vu = u.function_space()

# project b to V0
self.b_v0 = Function(V0)
btri = TrialFunction(V0)
btes = TestFunction(V0)
a = inner(btes, btri) * dx
L = inner(btes, b) * dx
projectbproblem = LinearVariationalProblem(a, L, self.b_v0)
self.project_b_solver = LinearVariationalSolver(
projectbproblem, solver_parameters={'ksp_type': 'cg'})

# project v to V0
self.v_v0 = Function(V0)
vtri = TrialFunction(V0)
vtes = TestFunction(V0)
a = inner(vtes, vtri) * dx
L = inner(vtes, v) * dx
projectvproblem = LinearVariationalProblem(a, L, self.v_v0)
self.project_v_solver = LinearVariationalSolver(
projectvproblem, solver_parameters={'ksp_type': 'cg'})

# stm/psi is a stream function
self.stm = Function(V0)
psi = TrialFunction(V0)
xsi = TestFunction(V0)

f = state.parameters.f
H = state.parameters.H
# NOTE: rebinds L (previously the projection right-hand side, already
# captured by projectvproblem) to the length parameter; it is rebound again
# to a form further down.
L = state.parameters.L
dbdy = state.parameters.dbdy
# x, y, z and dbdy are not used anywhere in the visible code -- presumably
# they belonged to the missing integrand of Equ below.
x, y, z = SpatialCoordinate(state.mesh)

# Zero stream function on the bottom and top boundaries.
bcs = [DirichletBC(V0, 0., "bottom"),
DirichletBC(V0, 0., "top")]

# Coefficient matrix in (x, y, z) ordering with the y row/column zeroed.
# NOTE(review): entry (0,0) differentiates the raw field b, while entry (2,0)
# uses its CG2 projection self.b_v0 -- intentional? confirm.
# NOTE(review): Mat is not symmetric in general (entries (0,2) and (2,0)
# differ unless f*v_z == b_x), yet the stream-function system below is solved
# with ksp_type 'cg', which assumes a symmetric operator -- confirm.
Mat = as_matrix([[b.dx(2), 0., -f*self.v_v0.dx(2)],
[0., 0., 0.],
[-self.b_v0.dx(0), 0., f**2+f*self.v_v0.dx(0)]])

Equ = (
# NOTE(review): the integrand of this weak form is missing in this file.
# An empty tuple is not a UFL expression, so this line is expected to fail
# when evaluated. The module imports grad and dot, which no visible code
# uses, so the lost expression presumably involved Mat, grad(psi), grad(xsi)
# and the dbdy forcing. TODO restore from the upstream source.
)*dx

# fourth-order terms
if state.parameters.fourthorder:
# Judging by the names and the interior-facet integrals over dS_h + dS_v,
# presumably a Brenner-style C0 interior-penalty stabilisation -- confirm.
eps = Constant(0.0001)
brennersigma = Constant(10.0)
n = FacetNormal(state.mesh)
deltax = Constant(state.parameters.deltax)
deltaz = Constant(state.parameters.deltaz)

# Diagonal penalty weights sqrt(sigma/delta) in x and z.
# NOTE(review): deltax/deltaz are already Constants; wrapping them in
# Constant() again below looks redundant.
nn = as_matrix([[sqrt(brennersigma/Constant(deltax)), 0., 0.],
[0., 0., 0.],
[0., 0., sqrt(brennersigma/Constant(deltaz))]])

# Anisotropic scaling diag(1, 0, H/L) applied in the fourth-order form.
mu = as_matrix([[1., 0., 0.],
[0., 0., 0.],
[0., 0., H/L]])

# anisotropic form
Equ += eps*(
- (
# NOTE(review): the facet integrand is missing here as well (and any cell
# term that preceded "- ("). n, nn, mu and the imported avg/jump are
# otherwise unused, suggesting they belonged to the lost expression.
# TODO restore from the upstream source.
)*(dS_h + dS_v)
)

# Split the combined form into bilinear (trial-dependent) and linear parts.
Au = lhs(Equ)
Lu = rhs(Equ)
stmproblem = LinearVariationalProblem(Au, Lu, self.stm, bcs=bcs)
self.stream_function_solver = LinearVariationalSolver(
stmproblem, solver_parameters={'ksp_type': 'cg'})

# solve for sawyer_eliassen u
# Velocity (-d(stm)/dz, 0, d(stm)/dx) projected onto the velocity space.
self.u = Function(Vu)
utrial = TrialFunction(Vu)
w = TestFunction(Vu)
a = inner(w, utrial)*dx
L = (w[0]*(-self.stm.dx(2))+w[2]*(self.stm.dx(0)))*dx
ugproblem = LinearVariationalProblem(a, L, self.u)
self.sawyer_eliassen_u_solver = LinearVariationalSolver(
ugproblem, solver_parameters={'ksp_type': 'cg'})

[docs]    def compute(self, state):
"""
Runs the solvers built in ``setup`` in dependency order and projects the
resulting velocity into ``self.field``.

:returns: the return value of ``self.field.project(self.u)``.
"""
# The projections of b and v must be refreshed before the stream-function
# solve, because Mat depends on self.b_v0 and self.v_v0.
self.project_b_solver.solve()
self.project_v_solver.solve()
self.stream_function_solver.solve()
self.sawyer_eliassen_u_solver.solve()
sawyer_eliassen_u = self.u
return self.field.project(sawyer_eliassen_u)