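Elliptic PDE Example

Multi-fidelity regression demo: the value at x = 0.5 of the solution of a 1D elliptic PDE with a random diffusion coefficient is computed on a coarse and a fine grid. A standard multitask Gaussian process is then compared with a fast digital-net multitask GP in prediction accuracy and fitting time.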
In [1]:
Copied!
import fastgps
import qmcpy as qp
import torch
torch.set_default_dtype(torch.float64)
import numpy as np
from scipy.sparse import diags
from scipy.sparse.linalg import spsolve
from scipy.stats import norm
import itertools
import time
import pandas as pd
import matplotlib
from matplotlib import pyplot
# Plot styling shared by every figure in this notebook.
pyplot.style.use("seaborn-v0_8-whitegrid")
# Colour cycle: xkcd colour names read from a shared text file, in reverse file order.
# The last character of every entry is stripped before the "xkcd:" prefix is added.
COLORS = [
    "xkcd:" + entry[:-1]
    for entry in pd.read_csv("../../../xkcd_colors.txt", comment="#", header=None).iloc[:, 0].tolist()
][::-1]
pyplot.rcParams['axes.prop_cycle'] = matplotlib.cycler(color=COLORS)
LINESTYLES = ['solid', 'dotted', 'dashed', 'dashdot', (0, (1, 1))]
DEFAULTFONTSIZE = 30
# Text sizes, line widths and marker sizes, applied in one call.
pyplot.rcParams.update({
    'xtick.labelsize': DEFAULTFONTSIZE,
    'ytick.labelsize': DEFAULTFONTSIZE,
    'axes.titlesize': DEFAULTFONTSIZE,
    'figure.titlesize': DEFAULTFONTSIZE,
    'axes.labelsize': DEFAULTFONTSIZE,
    'legend.fontsize': DEFAULTFONTSIZE,
    'font.size': DEFAULTFONTSIZE,
    'lines.linewidth': 5,
    'lines.markersize': 15,
})
PW = 30  # reference figure width used for figsize, in inches
import fastgps
import qmcpy as qp
import torch
torch.set_default_dtype(torch.float64)
import numpy as np
from scipy.sparse import diags
from scipy.sparse.linalg import spsolve
from scipy.stats import norm
import itertools
import time
import pandas as pd
import matplotlib
from matplotlib import pyplot
# Plot styling shared by every figure in this notebook.
pyplot.style.use("seaborn-v0_8-whitegrid")
# Colour cycle: xkcd colour names read from a shared text file, in reverse file order.
# The last character of every entry is stripped before the "xkcd:" prefix is added.
COLORS = [
    "xkcd:" + entry[:-1]
    for entry in pd.read_csv("../../../xkcd_colors.txt", comment="#", header=None).iloc[:, 0].tolist()
][::-1]
pyplot.rcParams['axes.prop_cycle'] = matplotlib.cycler(color=COLORS)
LINESTYLES = ['solid', 'dotted', 'dashed', 'dashdot', (0, (1, 1))]
DEFAULTFONTSIZE = 30
# Text sizes, line widths and marker sizes, applied in one call.
pyplot.rcParams.update({
    'xtick.labelsize': DEFAULTFONTSIZE,
    'ytick.labelsize': DEFAULTFONTSIZE,
    'axes.titlesize': DEFAULTFONTSIZE,
    'figure.titlesize': DEFAULTFONTSIZE,
    'axes.labelsize': DEFAULTFONTSIZE,
    'legend.fontsize': DEFAULTFONTSIZE,
    'font.size': DEFAULTFONTSIZE,
    'lines.linewidth': 5,
    'lines.markersize': 15,
})
PW = 30  # reference figure width used for figsize, in inches
In [32]:
Copied!
def _spsolve(main_diag, upper_diag, lower_diag, b):
N = main_diag.shape[-1]
assert main_diag.shape==(N,) and upper_diag.shape==(N-1,) and lower_diag.shape==(N-1,) and b.shape==(N,)
A = diags([main_diag, upper_diag, lower_diag],[0, 1, -1],shape=[N,N],format="csc",)
u = spsolve(A, b)
return u
vec_spsolve = np.vectorize(_spsolve,signature="(n),(m),(m),(n)->(n)")
def solve_elliptic_pde(level=5, coeffs=None):
# Define grid
N = 2 ** (level + 2) # Number of intervals
h = 1.0 / N # Mesh spacing
x = np.linspace(h, 1 - h, N - 1) # Interior points
# Compute diffusion coefficient a(x)
coeffs = np.random.rand(8) if coeffs is None else coeffs
assert isinstance(coeffs,np.ndarray)
coeffs = norm.ppf(coeffs) # Transform from uniform to iid Gaussian
batch_shape = list(coeffs.shape)[:-1]
k = np.arange(1,coeffs.shape[-1]+1)
a_x = np.exp((coeffs[...,None] / k[:,None] *np.sin(np.pi * k[:,None] * x)).sum(-2))
# Compute a at half-grid points (needed for flux terms)
a_half = np.zeros(batch_shape+[N])
a_half[...,1:-1] = (a_x[...,:-1] + a_x[...,1:]) / 2 # Midpoint values for flux approximation
a_half[...,0] = a_x[...,0] # At the first midpoint
a_half[...,-1] = a_x[...,-1] # At the last midpoint
# Construct the finite difference matrix
lower_diag = -a_half[...,1:-1] / h**2
upper_diag = -a_half[...,1:-1] / h**2
main_diag = (a_half[...,:-1] + a_half[...,1:]) / h**2
# Right-hand side (forcing term)
b = np.ones([N-1]) # Constant source term (1)
u = vec_spsolve(main_diag, upper_diag, lower_diag, b)
# Find index closest to x = 0.5
idx = np.argmin(np.abs(x - 0.5))
return u[...,idx], u, x, a_x, x[idx]
def elliptic(level, samples):
return solve_elliptic_pde(level,samples)[0]
# Multilevel setup: a coarse and a fine discretization level of the elliptic solver.
levels = np.array([1, 5], dtype=int)
num_levels = len(levels)
d = 16  # number of random coefficients, i.e. the input dimension
# Relative cost per solve at each level, proportional to the 2**(level+2) grid intervals.
costs = 2.**levels / 2.**levels.max()
# Held-out test inputs shared by all levels (Halton points, seed 7).
ntest = 2**10
xtest = qp.Halton(d, seed=7)(ntest)
qtest = np.zeros((num_levels, ntest))
fig, ax = pyplot.subplots(nrows=2, ncols=num_levels, figsize=(PW, PW/num_levels*2))
Qprev = 0
for l in range(num_levels):
    # Ql: solution at x=0.5, ql: full solutions, ul: grid nodes,
    # al: diffusion coefficient on the grid, u0l: grid node closest to 0.5.
    Ql, ql, ul, al, u0l = solve_elliptic_pde(levels[l], xtest)
    Yl = Ql - Qprev  # difference to the previous (coarser) level
    Qprev = Ql
    qtest[l] = Ql
    print("level = %-5d Ql.mean() = %-10.2f Ql.std() = %-10.1e Yl.mean() = %-10.1e Yl.std() = %-10.1e Ql.shape = %-10s ql.shape = %-15s ul.shape = %-10s al.shape = %-15s u0l = %-10s"%\
        (l, Ql.mean(), Ql.std(ddof=1), Yl.mean(), Yl.std(ddof=1), Ql.shape, ql.shape, ul.shape, al.shape, u0l))
    ax[0, l].plot(ul, ql.T)
    ax[0, l].scatter(u0l*np.ones_like(Ql), Ql)
    ax[1, l].plot(ul, al.T)
    ax[1, l].set_xlabel("x")
ax[0, 0].set_ylabel("solutions")
# The bottom row shows a(x), the random diffusion coefficient (the forcing term is the constant 1).
ax[1, 0].set_ylabel("diffusion coefficient")
print("xtest.shape = %s"%str(xtest.shape))
print("qtest.shape = %s"%str(qtest.shape))
print(costs)
def _spsolve(main_diag, upper_diag, lower_diag, b):
N = main_diag.shape[-1]
assert main_diag.shape==(N,) and upper_diag.shape==(N-1,) and lower_diag.shape==(N-1,) and b.shape==(N,)
A = diags([main_diag, upper_diag, lower_diag],[0, 1, -1],shape=[N,N],format="csc",)
u = spsolve(A, b)
return u
vec_spsolve = np.vectorize(_spsolve,signature="(n),(m),(m),(n)->(n)")
def solve_elliptic_pde(level=5, coeffs=None):
# Define grid
N = 2 ** (level + 2) # Number of intervals
h = 1.0 / N # Mesh spacing
x = np.linspace(h, 1 - h, N - 1) # Interior points
# Compute diffusion coefficient a(x)
coeffs = np.random.rand(8) if coeffs is None else coeffs
assert isinstance(coeffs,np.ndarray)
coeffs = norm.ppf(coeffs) # Transform from uniform to iid Gaussian
batch_shape = list(coeffs.shape)[:-1]
k = np.arange(1,coeffs.shape[-1]+1)
a_x = np.exp((coeffs[...,None] / k[:,None] *np.sin(np.pi * k[:,None] * x)).sum(-2))
# Compute a at half-grid points (needed for flux terms)
a_half = np.zeros(batch_shape+[N])
a_half[...,1:-1] = (a_x[...,:-1] + a_x[...,1:]) / 2 # Midpoint values for flux approximation
a_half[...,0] = a_x[...,0] # At the first midpoint
a_half[...,-1] = a_x[...,-1] # At the last midpoint
# Construct the finite difference matrix
lower_diag = -a_half[...,1:-1] / h**2
upper_diag = -a_half[...,1:-1] / h**2
main_diag = (a_half[...,:-1] + a_half[...,1:]) / h**2
# Right-hand side (forcing term)
b = np.ones([N-1]) # Constant source term (1)
u = vec_spsolve(main_diag, upper_diag, lower_diag, b)
# Find index closest to x = 0.5
idx = np.argmin(np.abs(x - 0.5))
return u[...,idx], u, x, a_x, x[idx]
def elliptic(level, samples):
return solve_elliptic_pde(level,samples)[0]
# Multilevel setup: a coarse and a fine discretization level of the elliptic solver.
levels = np.array([1, 5], dtype=int)
num_levels = len(levels)
d = 16  # number of random coefficients, i.e. the input dimension
# Relative cost per solve at each level, proportional to the 2**(level+2) grid intervals.
costs = 2.**levels / 2.**levels.max()
# Held-out test inputs shared by all levels (Halton points, seed 7).
ntest = 2**10
xtest = qp.Halton(d, seed=7)(ntest)
qtest = np.zeros((num_levels, ntest))
fig, ax = pyplot.subplots(nrows=2, ncols=num_levels, figsize=(PW, PW/num_levels*2))
Qprev = 0
for l in range(num_levels):
    # Ql: solution at x=0.5, ql: full solutions, ul: grid nodes,
    # al: diffusion coefficient on the grid, u0l: grid node closest to 0.5.
    Ql, ql, ul, al, u0l = solve_elliptic_pde(levels[l], xtest)
    Yl = Ql - Qprev  # difference to the previous (coarser) level
    Qprev = Ql
    qtest[l] = Ql
    print("level = %-5d Ql.mean() = %-10.2f Ql.std() = %-10.1e Yl.mean() = %-10.1e Yl.std() = %-10.1e Ql.shape = %-10s ql.shape = %-15s ul.shape = %-10s al.shape = %-15s u0l = %-10s"%\
        (l, Ql.mean(), Ql.std(ddof=1), Yl.mean(), Yl.std(ddof=1), Ql.shape, ql.shape, ul.shape, al.shape, u0l))
    ax[0, l].plot(ul, ql.T)
    ax[0, l].scatter(u0l*np.ones_like(Ql), Ql)
    ax[1, l].plot(ul, al.T)
    ax[1, l].set_xlabel("x")
ax[0, 0].set_ylabel("solutions")
# The bottom row shows a(x), the random diffusion coefficient (the forcing term is the constant 1).
ax[1, 0].set_ylabel("diffusion coefficient")
print("xtest.shape = %s"%str(xtest.shape))
print("qtest.shape = %s"%str(qtest.shape))
print(costs)
level = 0 Ql.mean() = 0.15 Ql.std() = 9.9e-02 Yl.mean() = 1.5e-01 Yl.std() = 9.9e-02 Ql.shape = (1024,) ql.shape = (1024, 7) ul.shape = (7,) al.shape = (1024, 7) u0l = 0.5 level = 1 Ql.mean() = 0.15 Ql.std() = 9.2e-02 Yl.mean() = 6.6e-03 Yl.std() = 1.9e-02 Ql.shape = (1024,) ql.shape = (1024, 127) ul.shape = (127,) al.shape = (1024, 127) u0l = 0.5 xtest.shape = (1024, 16) qtest.shape = (2, 1024) [0.0625 1. ]
In [40]:
Copied!
def run_gpr(name, n, iterations=100):
    """Fit one multitask GP to multilevel elliptic-PDE data and score it on the test set.

    Parameters
    ----------
    name : str
        One of "standard multitask GP", "fast multitask GP - digital net",
        "fast multitask GP - lattice".
    n : numpy.ndarray of shape (num_levels,)
        Number of training points requested per level (task).
    iterations : int, default 100
        Number of hyperparameter-optimization iterations passed to ``fgp.fit``.

    Returns
    -------
    l2rerrors : numpy.ndarray
        Relative L2 error of the posterior mean against ``qtest``, one entry per level.
    tdiff : float
        Wall-clock seconds spent in ``fgp.fit`` only (data generation and
        prediction are not timed).

    Raises
    ------
    ValueError
        If ``name`` is not one of the supported model names.
    """
    assert n.shape == (num_levels,)

    def _multitask(base_kernel):
        # One task per level, with a rank-1 task-covariance factor.
        return qp.KernelMultiTask(base_kernel=base_kernel, num_tasks=num_levels, rank_factor=1)

    # Lazily constructed models, so only the requested one draws random sequences.
    builders = {
        "standard multitask GP": lambda: fastgps.StandardGP(
            kernel=_multitask(qp.KernelSquaredExponential(d=d, torchify=True)),
            seqs=[qp.DigitalNetB2(dimension=d) for _ in range(num_levels)],
        ),
        "fast multitask GP - digital net": lambda: fastgps.FastGPDigitalNetB2(
            kernel=_multitask(qp.KernelDigShiftInvar(d=d, torchify=True)),
            seqs=[qp.DigitalNetB2(dimension=d, randomize="DS") for _ in range(num_levels)],
        ),
        "fast multitask GP - lattice": lambda: fastgps.FastGPLattice(
            kernel=_multitask(qp.KernelShiftInvar(d=d, torchify=True, alpha=3)),
            seqs=[qp.Lattice(dimension=d) for _ in range(num_levels)],
        ),
    }
    if name not in builders:
        raise ValueError("invalid name = %s" % name)
    fgp = builders[name]()

    # Request training inputs for every level from the GP (presumably n[l] points
    # from level l's sequence), then evaluate that level's solver on them.
    xnext = fgp.get_x_next(torch.from_numpy(n))
    ynext = [torch.from_numpy(elliptic(levels[l], xnext_l.numpy())) for l, xnext_l in enumerate(xnext)]
    fgp.add_y_next(ynext)

    # Only hyperparameter optimization is timed.
    t0 = time.perf_counter()
    fgp.fit(loss_metric="MLL", iterations=iterations, verbose=0)
    tdiff = time.perf_counter() - t0

    # Relative L2 error per level over the shared test inputs.
    yhat = fgp.post_mean(torch.from_numpy(xtest)).numpy()
    l2rerrors = np.linalg.norm(yhat - qtest, axis=-1) / np.linalg.norm(qtest, axis=-1)
    return l2rerrors, tdiff
# Models to compare; the lattice variant is supported by run_gpr but disabled here.
names = [
    "standard multitask GP",
    "fast multitask GP - digital net",
    # "fast multitask GP - lattice",
]
# mmin and k parameterize the alternative full-factorial design kept in the comment below.
mmin = 10
k = 1
trials = 1
# ns = (2**(mmin+np.stack([n for n in itertools.product(*[list(range(k)) for l in range(num_levels)])])))[::-1]
# Per-level training sizes: the coarse level grows from 2**8 to 2**12, the fine level stays at 2**6.
ns = np.array([[2**m, 2**6] for m in range(8, 13)])
print("ns.shape = %s"%str(ns.shape))
print(ns)
print()
verbose = 1  # set to 0 to silence per-run progress lines
l2rerrors = np.full((len(names), len(ns), num_levels, trials), np.nan)
times = np.full((len(names), len(ns), trials), np.nan)
for i in range(len(ns)):
    for j, name in enumerate(names):
        for t in range(trials):
            l2rerrors[j, i, :, t], times[j, i, t] = run_gpr(name, ns[i])
            if verbose:
                with np.printoptions(formatter={"float": lambda x: "%.1e"%x}):
                    print("%35s: t = %-5d i = %-5d n = %-15s times[j,i,t] = %-10.1e l2rerrors[j,i,:,t] = %s"%(name, t, i, ns[i], times[j, i, t], l2rerrors[j, i, :, t]))
    if verbose:
        print()
def run_gpr(name, n, iterations=100):
    """Fit one multitask GP to multilevel elliptic-PDE data and score it on the test set.

    Parameters
    ----------
    name : str
        One of "standard multitask GP", "fast multitask GP - digital net",
        "fast multitask GP - lattice".
    n : numpy.ndarray of shape (num_levels,)
        Number of training points requested per level (task).
    iterations : int, default 100
        Number of hyperparameter-optimization iterations passed to ``fgp.fit``.

    Returns
    -------
    l2rerrors : numpy.ndarray
        Relative L2 error of the posterior mean against ``qtest``, one entry per level.
    tdiff : float
        Wall-clock seconds spent in ``fgp.fit`` only (data generation and
        prediction are not timed).

    Raises
    ------
    ValueError
        If ``name`` is not one of the supported model names.
    """
    assert n.shape == (num_levels,)

    def _multitask(base_kernel):
        # One task per level, with a rank-1 task-covariance factor.
        return qp.KernelMultiTask(base_kernel=base_kernel, num_tasks=num_levels, rank_factor=1)

    # Lazily constructed models, so only the requested one draws random sequences.
    builders = {
        "standard multitask GP": lambda: fastgps.StandardGP(
            kernel=_multitask(qp.KernelSquaredExponential(d=d, torchify=True)),
            seqs=[qp.DigitalNetB2(dimension=d) for _ in range(num_levels)],
        ),
        "fast multitask GP - digital net": lambda: fastgps.FastGPDigitalNetB2(
            kernel=_multitask(qp.KernelDigShiftInvar(d=d, torchify=True)),
            seqs=[qp.DigitalNetB2(dimension=d, randomize="DS") for _ in range(num_levels)],
        ),
        "fast multitask GP - lattice": lambda: fastgps.FastGPLattice(
            kernel=_multitask(qp.KernelShiftInvar(d=d, torchify=True, alpha=3)),
            seqs=[qp.Lattice(dimension=d) for _ in range(num_levels)],
        ),
    }
    if name not in builders:
        raise ValueError("invalid name = %s" % name)
    fgp = builders[name]()

    # Request training inputs for every level from the GP (presumably n[l] points
    # from level l's sequence), then evaluate that level's solver on them.
    xnext = fgp.get_x_next(torch.from_numpy(n))
    ynext = [torch.from_numpy(elliptic(levels[l], xnext_l.numpy())) for l, xnext_l in enumerate(xnext)]
    fgp.add_y_next(ynext)

    # Only hyperparameter optimization is timed.
    t0 = time.perf_counter()
    fgp.fit(loss_metric="MLL", iterations=iterations, verbose=0)
    tdiff = time.perf_counter() - t0

    # Relative L2 error per level over the shared test inputs.
    yhat = fgp.post_mean(torch.from_numpy(xtest)).numpy()
    l2rerrors = np.linalg.norm(yhat - qtest, axis=-1) / np.linalg.norm(qtest, axis=-1)
    return l2rerrors, tdiff
# Models to compare; the lattice variant is supported by run_gpr but disabled here.
names = [
    "standard multitask GP",
    "fast multitask GP - digital net",
    # "fast multitask GP - lattice",
]
# mmin and k parameterize the alternative full-factorial design kept in the comment below.
mmin = 10
k = 1
trials = 1
# ns = (2**(mmin+np.stack([n for n in itertools.product(*[list(range(k)) for l in range(num_levels)])])))[::-1]
# Per-level training sizes: the coarse level grows from 2**8 to 2**12, the fine level stays at 2**6.
ns = np.array([[2**m, 2**6] for m in range(8, 13)])
print("ns.shape = %s"%str(ns.shape))
print(ns)
print()
verbose = 1  # set to 0 to silence per-run progress lines
l2rerrors = np.full((len(names), len(ns), num_levels, trials), np.nan)
times = np.full((len(names), len(ns), trials), np.nan)
for i in range(len(ns)):
    for j, name in enumerate(names):
        for t in range(trials):
            l2rerrors[j, i, :, t], times[j, i, t] = run_gpr(name, ns[i])
            if verbose:
                with np.printoptions(formatter={"float": lambda x: "%.1e"%x}):
                    print("%35s: t = %-5d i = %-5d n = %-15s times[j,i,t] = %-10.1e l2rerrors[j,i,:,t] = %s"%(name, t, i, ns[i], times[j, i, t], l2rerrors[j, i, :, t]))
    if verbose:
        print()
ns.shape = (5, 2)
[[ 256 64]
[ 512 64]
[1024 64]
[2048 64]
[4096 64]]
ns.shape = (5, 2)
ns[:5]
[[ 256 64]
[ 512 64]
[1024 64]
[2048 64]
[4096 64]]
standard multitask GP: t = 0 i = 0 n = [256 64] times[j,i,t] = 5.4e-01 l2rerrors[j,i,:,t] = [1.5e-01 2.5e-01]
fast multitask GP - digital net: t = 0 i = 0 n = [256 64] times[j,i,t] = 3.2e-01 l2rerrors[j,i,:,t] = [1.4e-01 1.4e-01]
standard multitask GP: t = 0 i = 1 n = [512 64] times[j,i,t] = 1.5e+00 l2rerrors[j,i,:,t] = [9.7e-02 1.7e-01]
fast multitask GP - digital net: t = 0 i = 1 n = [512 64] times[j,i,t] = 3.4e-01 l2rerrors[j,i,:,t] = [9.9e-02 1.3e-01]
standard multitask GP: t = 0 i = 2 n = [1024 64] times[j,i,t] = 4.3e+00 l2rerrors[j,i,:,t] = [1.4e-01 2.0e-01]
fast multitask GP - digital net: t = 0 i = 2 n = [1024 64] times[j,i,t] = 3.3e-01 l2rerrors[j,i,:,t] = [1.5e-01 1.2e-01]
standard multitask GP: t = 0 i = 3 n = [2048 64] times[j,i,t] = 2.0e+01 l2rerrors[j,i,:,t] = [8.2e-02 1.6e-01]
fast multitask GP - digital net: t = 0 i = 3 n = [2048 64] times[j,i,t] = 7.2e-01 l2rerrors[j,i,:,t] = [6.6e-02 9.8e-02]
standard multitask GP: t = 0 i = 4 n = [4096 64] times[j,i,t] = 8.5e+01 l2rerrors[j,i,:,t] = [6.7e-02 1.7e-01]
fast multitask GP - digital net: t = 0 i = 4 n = [4096 64] times[j,i,t] = 9.4e-01 l2rerrors[j,i,:,t] = [5.7e-02 1.2e-01]
In [44]:
Copied!
# Left panel: 2-D illustration of a "source" and a "target" design drawn from two
# independently seeded digital nets. Right panel: time ratio vs. error ratio of each
# fast model relative to the standard multitask GP (names[0]), one point per row of ns.
fig, axs = pyplot.subplots(nrows=1, ncols=2, figsize=(PW/2, PW/2/2))
ax = axs[0]
x_l0 = qp.DigitalNetB2(2, seed=7)(2**5)
x_l1 = qp.DigitalNetB2(2, seed=11)(2**3)
ax.scatter(x_l0[:, 0], x_l0[:, 1], label="source", s=250, color="red", marker="o")
ax.scatter(x_l1[:, 0], x_l1[:, 1], label="target", s=250, color="blue", marker="^")
unit_ticks = [0, 1/4, 1/2, 3/4, 1]
unit_ticklabels = [r"$0$", r"$1/4$", r"$1/2$", r"$3/4$", r"$1$"]
ax.set_xticks(unit_ticks); ax.set_xticklabels(unit_ticklabels)
ax.set_yticks(unit_ticks); ax.set_yticklabels(unit_ticklabels)
fig.legend(frameon=False, bbox_to_anchor=(.47, 1.1), ncol=3)

ax = axs[1]
for j in range(1, len(names)):
    # Median over trials; ratios below 1 favour names[j] over names[0].
    xtrend = np.median(times[j], axis=-1) / np.median(times[0], axis=-1)
    # Errors are taken on the last (finest) level only.
    ytrend = np.median(l2rerrors[j, :, -1], axis=-1) / np.median(l2rerrors[0, :, -1], axis=-1)
    # Loop variable renamed so the experiment-config global `k` is not overwritten.
    for i_n in range(len(ns)):
        ax.scatter(xtrend[i_n], ytrend[i_n], color="k", zorder=10)
        # Annotate each point with the level-0 sample size 2^m.
        ax.annotate(r"$2^{%d}$"%int(np.log2(ns[i_n, 0])), (xtrend[i_n]*.8, ytrend[i_n]+.015))
ax.set_xlabel("computation time ratio")
ax.set_ylabel("prediction error ratio")
ax.axhline(y=1, color='k', linestyle="dotted", linewidth=2)
ax.axvline(x=1, color='k', linestyle="dotted", linewidth=2)
# Freeze the autoscaled y-range before switching the x-axis to a log scale.
ax.set_ylim(ax.get_ylim())
ax.set_xscale("log", base=10)
ax.set_xticks([.01, .1, 1]); ax.set_xticklabels([.01, .1, 1])
ax.set_xlim([ax.get_xlim()[0], 1.25])
fig.tight_layout()
fig.savefig("./elliptic_viz.pdf", bbox_inches="tight")
# Left panel: 2-D illustration of a "source" and a "target" design drawn from two
# independently seeded digital nets. Right panel: time ratio vs. error ratio of each
# fast model relative to the standard multitask GP (names[0]), one point per row of ns.
fig, axs = pyplot.subplots(nrows=1, ncols=2, figsize=(PW/2, PW/2/2))
ax = axs[0]
x_l0 = qp.DigitalNetB2(2, seed=7)(2**5)
x_l1 = qp.DigitalNetB2(2, seed=11)(2**3)
ax.scatter(x_l0[:, 0], x_l0[:, 1], label="source", s=250, color="red", marker="o")
ax.scatter(x_l1[:, 0], x_l1[:, 1], label="target", s=250, color="blue", marker="^")
unit_ticks = [0, 1/4, 1/2, 3/4, 1]
unit_ticklabels = [r"$0$", r"$1/4$", r"$1/2$", r"$3/4$", r"$1$"]
ax.set_xticks(unit_ticks); ax.set_xticklabels(unit_ticklabels)
ax.set_yticks(unit_ticks); ax.set_yticklabels(unit_ticklabels)
fig.legend(frameon=False, bbox_to_anchor=(.47, 1.1), ncol=3)

ax = axs[1]
for j in range(1, len(names)):
    # Median over trials; ratios below 1 favour names[j] over names[0].
    xtrend = np.median(times[j], axis=-1) / np.median(times[0], axis=-1)
    # Errors are taken on the last (finest) level only.
    ytrend = np.median(l2rerrors[j, :, -1], axis=-1) / np.median(l2rerrors[0, :, -1], axis=-1)
    # Loop variable renamed so the experiment-config global `k` is not overwritten.
    for i_n in range(len(ns)):
        ax.scatter(xtrend[i_n], ytrend[i_n], color="k", zorder=10)
        # Annotate each point with the level-0 sample size 2^m.
        ax.annotate(r"$2^{%d}$"%int(np.log2(ns[i_n, 0])), (xtrend[i_n]*.8, ytrend[i_n]+.015))
ax.set_xlabel("computation time ratio")
ax.set_ylabel("prediction error ratio")
ax.axhline(y=1, color='k', linestyle="dotted", linewidth=2)
ax.axvline(x=1, color='k', linestyle="dotted", linewidth=2)
# Freeze the autoscaled y-range before switching the x-axis to a log scale.
ax.set_ylim(ax.get_ylim())
ax.set_xscale("log", base=10)
ax.set_xticks([.01, .1, 1]); ax.set_xticklabels([.01, .1, 1])
ax.set_xlim([ax.get_xlim()[0], 1.25])
fig.tight_layout()
fig.savefig("./elliptic_viz.pdf", bbox_inches="tight")
In [ ]:
Copied!