Fast Batch Multitask Lattice GP
In [1]:
Copied!
import fastgps
import torch
import numpy as np
import fastgps
import torch
import numpy as np
In [2]:
Copied!
# Use double precision as the default dtype for every tensor created below
# (including the torch.allclose comparisons of posterior covariances later on).
torch.set_default_dtype(torch.float64)
True Function
In [3]:
Copied!
# Problem configuration: input dimension, batch shape and number of tasks.
d = 6
rng = torch.Generator().manual_seed(7)  # dedicated generator so all random draws below are reproducible
shape_batch = [2,3,4]
num_tasks = 5

def f(l, x):
    """Synthetic batched response for task `l` evaluated at the points `x`.

    Args:
        l (int): task index. The added Gaussian noise has standard deviation
            1/(3+l), so later tasks are less noisy.
        x (torch.Tensor): points of shape (n, d).

    Returns:
        torch.Tensor: tensor of shape `shape_batch + [n]`. Batch element b holds
        sum_{j=1..d} c_b * x[:, j-1]**j, where c_b = 0, 1, ..., prod(shape_batch)-1
        is the flat index of b, plus noise drawn from the module-level `rng`.
    """
    consts = torch.arange(torch.prod(torch.tensor(shape_batch))).reshape(shape_batch)
    # consts[...,None,None] broadcasts against the (n, d) powers; sum over d gives (*shape_batch, n)
    y = (consts[...,None,None]*x**torch.arange(1,d+1)).sum(-1)+torch.randn(shape_batch+[x.size(0)],generator=rng)/(3+l)
    return y

x = torch.rand((2**7,d),generator=rng) # random testing locations
y = torch.cat([f(l,x)[...,None,:] for l in range(num_tasks)],-2) # true values at random testing locations
z = torch.rand((2**6,d),generator=rng) # other random locations at which to evaluate covariance
print("x.shape = %s"%str(tuple(x.shape)))
print("y.shape = %s"%str(tuple(y.shape)))
print("z.shape = %s"%str(tuple(z.shape)))
x.shape = (128, 6)
y.shape = (2, 3, 4, 5, 128)
z.shape = (64, 6)
Construct Fast GP
In [4]:
Copied!
# Batched multitask lattice GP. Each hyperparameter is given its own shape. Some of
# these shapes drop leading batch dimensions (see the printed shapes); presumably
# fastgps broadcasts them across the dropped dims -- confirm in the fastgps docs.
fgp = fastgps.FastGPLattice(d,seed_for_seq=7,num_tasks=num_tasks,
    shape_batch=shape_batch,
    shape_scale = shape_batch+[1],                               # full batch shape
    shape_lengthscales = shape_batch[1:]+[d],                    # drops the first batch dim; one entry per input dim
    shape_noise = shape_batch[2:]+[1],                           # keeps only the last batch dim
    shape_factor_task_kernel = shape_batch+[num_tasks,num_tasks],  # (num_tasks, num_tasks) factor per batch element
    shape_noise_task_kernel = shape_batch[1:]+[num_tasks],
)
for name in ["scale","lengthscales","noise","factor_task_kernel","noise_task_kernel"]:
    print("fgp.%s.shape = %s"%(name,str(tuple(getattr(fgp,name).shape))))
fgp.scale.shape = (2, 3, 4, 1)
fgp.lengthscales.shape = (3, 4, 6)
fgp.noise.shape = (4, 1)
fgp.factor_task_kernel.shape = (2, 3, 4, 5, 5)
fgp.noise_task_kernel.shape = (3, 4, 5)
In [5]:
Copied!
# Request 2**6, 2**5, ..., 2**2 points for tasks 0..num_tasks-1 (fewer points for
# later tasks), evaluate the true function there, and add the data to the GP.
x_next = fgp.get_x_next(n=2**torch.arange(num_tasks+1,1,-1))
y_next = [f(l,x_next[l]) for l in range(num_tasks)]
fgp.add_y_next(y_next)
for i,(x_i,y_i) in enumerate(zip(x_next,y_next)):
    print("i = %d"%i)
    print("\tx_next[%d].shape = %s"%(i,str(tuple(x_i.shape))))
    print("\ty_next[%d].shape = %s"%(i,str(tuple(y_i.shape))))
i = 0
	x_next[0].shape = (64, 6)
	y_next[0].shape = (2, 3, 4, 64)
i = 1
	x_next[1].shape = (32, 6)
	y_next[1].shape = (2, 3, 4, 32)
i = 2
	x_next[2].shape = (16, 6)
	y_next[2].shape = (2, 3, 4, 16)
i = 3
	x_next[3].shape = (8, 6)
	y_next[3].shape = (2, 3, 4, 8)
i = 4
	x_next[4].shape = (4, 6)
	y_next[4].shape = (2, 3, 4, 4)
In [6]:
Copied!
# Posterior mean at the test points before any hyperparameter fitting, and its
# relative L2 error over the test points (last dim) for every batch element and task.
pmean = fgp.post_mean(x)
print("pmean.shape = %s"%str(tuple(pmean.shape)))
l2rerror = torch.linalg.norm(y-pmean,dim=-1)/torch.linalg.norm(y,dim=-1)
print("l2rerror.shape = %s"%str(tuple(l2rerror.shape)))
pmean.shape = (2, 3, 4, 5, 128)
l2rerror.shape = (2, 3, 4, 5)
In [7]:
Copied!
# Optimize the GP hyperparameters (run once). stop_crit_improvement_threshold=1e3
# presumably loosens the stopping rule; the log shows fitting stopped at iteration
# 108 of at most 5e3.
data = fgp.fit(stop_crit_improvement_threshold=1e3)
list(data.keys())
iter of 5.0e+03 | loss     | term1    | term2
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
       0.00e+00 | 1.58e+04 | 2.71e+02 | 2.58e+04
       5.00e+00 | 1.22e+04 | 5.22e+03 | 1.36e+04
       1.00e+01 | 1.00e+04 | 3.58e+03 | 1.10e+04
       1.50e+01 | 9.29e+03 | 3.54e+03 | 9.56e+03
       2.00e+01 | 8.91e+03 | 3.39e+03 | 8.95e+03
       2.50e+01 | 8.64e+03 | 3.31e+03 | 8.50e+03
       3.00e+01 | 8.45e+03 | 3.30e+03 | 8.14e+03
       3.50e+01 | 8.37e+03 | 3.17e+03 | 8.10e+03
       4.00e+01 | 8.31e+03 | 3.06e+03 | 8.10e+03
       4.50e+01 | 8.28e+03 | 3.09e+03 | 8.00e+03
       5.00e+01 | 8.25e+03 | 3.08e+03 | 7.95e+03
       5.50e+01 | 8.24e+03 | 3.07e+03 | 7.93e+03
       6.00e+01 | 8.22e+03 | 3.07e+03 | 7.90e+03
       6.50e+01 | 8.20e+03 | 3.03e+03 | 7.90e+03
       7.00e+01 | 8.19e+03 | 3.07e+03 | 7.85e+03
       7.50e+01 | 8.18e+03 | 3.06e+03 | 7.84e+03
       8.00e+01 | 8.18e+03 | 3.04e+03 | 7.85e+03
       8.50e+01 | 8.17e+03 | 3.05e+03 | 7.82e+03
       9.00e+01 | 8.17e+03 | 3.02e+03 | 7.84e+03
       9.50e+01 | 8.16e+03 | 3.00e+03 | 7.86e+03
       1.00e+02 | 8.16e+03 | 3.00e+03 | 7.85e+03
       1.05e+02 | 8.16e+03 | 3.01e+03 | 7.84e+03
       1.08e+02 | 8.16e+03 | 3.00e+03 | 7.84e+03
Out[7]:
['iterations']
In [8]:
Copied!
# Posterior mean/variance at the test points with 99% confidence intervals;
# q is the interval multiplier (2.58 for 99% in the recorded output).
pmean,pvar,q,ci_low,ci_high = fgp.post_ci(x,confidence=0.99)
print("pmean.shape = %s"%str(tuple(pmean.shape)))
print("pvar.shape = %s"%str(tuple(pvar.shape)))
print("q = %.2f"%q)
print("ci_low.shape = %s"%str(tuple(ci_low.shape)))
print("ci_high.shape = %s"%str(tuple(ci_high.shape)))
l2rerror = torch.linalg.norm(y-pmean,dim=-1)/torch.linalg.norm(y,dim=-1)
print("l2rerror.shape = %s"%str(tuple(l2rerror.shape)))
# Full posterior covariance: (*batch, task, task, point, point).
pcov = fgp.post_cov(x,x)
print("pcov.shape = %s"%str(tuple(pcov.shape)))
# Sanity check: the diagonal over matching tasks and matching points recovers pvar,
# and all variances are non-negative.
task_idx,point_idx = torch.arange(pcov.size(-3)),torch.arange(pcov.size(-1))
assert torch.allclose(pcov[...,task_idx,task_idx,:,:][...,point_idx,point_idx],pvar) and (pvar>=0).all()
# Cross covariance between the test points x and the extra points z.
pcov2 = fgp.post_cov(x,z)
print("pcov2.shape = %s"%str(tuple(pcov2.shape)))
pmean.shape = (2, 3, 4, 5, 128)
pvar.shape = (2, 3, 4, 5, 128)
q = 2.58
ci_low.shape = (2, 3, 4, 5, 128)
ci_high.shape = (2, 3, 4, 5, 128)
l2rerror.shape = (2, 3, 4, 5)
pcov.shape = (2, 3, 4, 5, 5, 128, 128)
pcov2.shape = (2, 3, 4, 5, 5, 128, 64)
In [9]:
Copied!
# Posterior of the cubature estimate (presumably the integral of each task's response)
# for every batch element and task, with 99% confidence intervals, plus the
# task-by-task cubature covariance.
pcmean,pcvar,q,cci_low,cci_high = fgp.post_cubature_ci(confidence=0.99)
print("pcmean.shape = %s"%str(tuple(pcmean.shape)))
print("pcvar.shape = %s"%str(tuple(pcvar.shape)))
print("cci_low.shape = %s"%str(tuple(cci_low.shape)))
print("cci_high.shape = %s"%str(tuple(cci_high.shape)))
pccov = fgp.post_cubature_cov()
print("pccov.shape = %s"%str(tuple(pccov.shape)))
pcmean.shape = (2, 3, 4, 5)
pcvar.shape = (2, 3, 4, 5)
cci_low.shape = (2, 3, 4, 5)
cci_high.shape = (2, 3, 4, 5)
pccov.shape = (2, 3, 4, 5, 5)
Project and Increase Sample Size
In [10]:
Copied!
# Project: compute the posterior covariance, variance and cubature variance the GP
# would have with twice the current sample size per task, before collecting any
# new samples. These projections are checked after the samples are added below.
n_new = fgp.n*2
pcov_future = fgp.post_cov(x,z,n=n_new)
pvar_future = fgp.post_var(x,n=n_new)
pcvar_future = fgp.post_cubature_var(n=n_new)
In [11]:
Copied!
# Collect the projected samples (once), then confirm that the realized posterior
# covariance, variance and cubature variance match the projections computed beforehand.
x_next = fgp.get_x_next(n_new)
y_next = [f(l,x_next[l]) for l in range(num_tasks)]
for _y in y_next:
    print(_y.shape)
fgp.add_y_next(y_next)
l2rerror = torch.linalg.norm(y-fgp.post_mean(x),dim=-1)/torch.linalg.norm(y,dim=-1)
print("l2rerror.shape = %s"%str(tuple(l2rerror.shape)))
assert torch.allclose(fgp.post_cov(x,z),pcov_future)
assert torch.allclose(fgp.post_var(x),pvar_future)
assert torch.allclose(fgp.post_cubature_var(),pcvar_future)
torch.Size([2, 3, 4, 64])
torch.Size([2, 3, 4, 32])
torch.Size([2, 3, 4, 16])
torch.Size([2, 3, 4, 8])
torch.Size([2, 3, 4, 4])
l2rerror.shape = (2, 3, 4, 5)
In [12]:
Copied!
# Refit the hyperparameters for a few iterations on the enlarged data set, then
# recompute the relative L2 error at the test points.
data = fgp.fit(iterations=5,verbose=False)
l2rerror = torch.linalg.norm(y-fgp.post_mean(x),dim=-1)/torch.linalg.norm(y,dim=-1)
print("l2rerror.shape = %s"%str(tuple(l2rerror.shape)))
l2rerror.shape = (2, 3, 4, 5)
In [13]:
Copied!
# Repeat the project-then-sample check with the refitted hyperparameters:
# project to double the current sample size, add exactly those samples,
# and verify that the projections were exact.
n_new = fgp.n*2
pcov_new = fgp.post_cov(x,z,n=n_new)
pvar_new = fgp.post_var(x,n=n_new)
pcvar_new = fgp.post_cubature_var(n=n_new)
x_next = fgp.get_x_next(n_new)
y_next = [f(l,x_next[l]) for l in range(num_tasks)]
fgp.add_y_next(y_next)
assert torch.allclose(fgp.post_cov(x,z),pcov_new)
assert torch.allclose(fgp.post_var(x),pvar_new)
assert torch.allclose(fgp.post_cubature_var(),pcvar_new)