5656from gpytorch .means import ConstantMean
5757
5858EXPECTED_KEYS = [
59- "latent_features" ,
6059 "mean_module.raw_constant" ,
6160 "covar_module.kernels.1.raw_var" ,
6261 "covar_module.kernels.1.active_dims" ,
@@ -112,7 +111,7 @@ def _get_data_and_model(
112111 )
113112 return train_X , train_Y , train_Yvar , model
114113
115- def _get_unnormalized_data (self , ** tkwargs ):
114+ def _get_unnormalized_data (self , infer_noise : bool = False , ** tkwargs ):
116115 with torch .random .fork_rng ():
117116 torch .manual_seed (0 )
118117 train_X = torch .rand (10 , 4 , ** tkwargs )
@@ -122,9 +121,28 @@ def _get_unnormalized_data(self, **tkwargs):
122121 )
123122 train_X = torch .cat ([5 + 5 * train_X , task_indices ], dim = 1 )
124123 test_X = 5 + 5 * torch .rand (5 , 4 , ** tkwargs )
125- train_Yvar = 0.1 * torch .arange (10 , ** tkwargs ).unsqueeze (- 1 )
124+ if infer_noise :
125+ train_Yvar = None
126+ else :
127+ train_Yvar = 0.1 * torch .arange (10 , ** tkwargs ).unsqueeze (- 1 )
126128 return train_X , train_Y , train_Yvar , test_X
127129
130+ def _get_unnormalized_condition_data (
131+ self , num_models : int , num_cond : int , dim : int , infer_noise : bool , ** tkwargs
132+ ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor | None ]:
133+ with torch .random .fork_rng ():
134+ torch .manual_seed (0 )
135+ cond_X = 5 + 5 * torch .rand (num_models , num_cond , dim , ** tkwargs )
136+ cond_Y = 10 + torch .sin (cond_X [..., :1 ])
137+ cond_Yvar = (
138+ None if infer_noise else 0.1 * torch .ones (cond_Y .shape , ** tkwargs )
139+ )
140+ # adding the task dimension
141+ cond_X = torch .cat (
142+ [cond_X , torch .zeros (num_models , num_cond , 1 , ** tkwargs )], dim = - 1
143+ )
144+ return cond_X , cond_Y , cond_Yvar
145+
128146 def _get_mcmc_samples (self , num_samples : int , dim : int , task_rank : int , ** tkwargs ):
129147 mcmc_samples = {
130148 "lengthscale" : torch .rand (num_samples , 1 , dim , ** tkwargs ),
@@ -604,6 +622,110 @@ def test_acquisition_functions(self):
604622 )
605623 self .assertEqual (acqf (test_X ).shape , torch .Size (batch_shape ))
606624
625+ def test_condition_on_observation (self ) -> None :
626+ # The following conditioned data shapes should work (output describes):
627+ # training data shape after cond(batch shape in output is req. in gpytorch)
628+ # X: num_models x n x d, Y: num_models x n x d --> num_models x n x d
629+ # X: n x d, Y: n x d --> num_models x n x d
630+ # X: n x d, Y: num_models x n x d --> num_models x n x d
631+ num_models = 3
632+ num_cond = 2
633+ task_rank = 2
634+ for infer_noise , dtype in itertools .product (
635+ (True , False ), (torch .float , torch .double )
636+ ):
637+ tkwargs = {"device" : self .device , "dtype" : dtype }
638+ train_X , _ , _ , model = self ._get_data_and_model (
639+ task_rank = task_rank ,
640+ infer_noise = infer_noise ,
641+ ** tkwargs ,
642+ )
643+ num_dims = train_X .shape [1 ] - 1
644+ mcmc_samples = self ._get_mcmc_samples (
645+ num_samples = 3 ,
646+ dim = num_dims ,
647+ task_rank = task_rank ,
648+ ** tkwargs ,
649+ )
650+ model .load_mcmc_samples (mcmc_samples )
651+
652+ num_train = train_X .shape [0 ]
653+ test_X = torch .rand (num_models , num_dims , ** tkwargs )
654+
655+ cond_X , cond_Y , cond_Yvar = self ._get_unnormalized_condition_data (
656+ num_models = num_models ,
657+ num_cond = num_cond ,
658+ infer_noise = infer_noise ,
659+ dim = num_dims ,
660+ ** tkwargs ,
661+ )
662+
663+ # need to forward pass before conditioning
664+ model .posterior (train_X )
665+ cond_model = model .condition_on_observations (
666+ cond_X , cond_Y , noise = cond_Yvar
667+ )
668+ posterior = cond_model .posterior (test_X )
669+ self .assertEqual (
670+ posterior .mean .shape , torch .Size ([num_models , len (test_X ), 2 ])
671+ )
672+
673+ # since the data is not equal for the conditioned points, a batch size
674+ # is added to the training data
675+ self .assertEqual (
676+ cond_model .train_inputs [0 ].shape ,
677+ torch .Size ([num_models , num_train + num_cond , num_dims + 1 ]),
678+ )
679+
680+ # the batch shape of the condition model is added during conditioning
681+ self .assertEqual (cond_model .batch_shape , torch .Size ([num_models ]))
682+
683+ # condition on identical sets of data (i.e. one set) for all models
684+ # i.e, with no batch shape. This infers the batch shape.
685+ cond_X_nobatch , cond_Y_nobatch = cond_X [0 ], cond_Y [0 ]
686+
687+ # conditioning without a batch size - the resulting conditioned model
688+ # will still have a batch size
689+ model .posterior (train_X )
690+ cond_model = model .condition_on_observations (
691+ cond_X_nobatch , cond_Y_nobatch , noise = cond_Yvar
692+ )
693+ self .assertEqual (
694+ cond_model .train_inputs [0 ].shape ,
695+ torch .Size ([num_models , num_train + num_cond , num_dims + 1 ]),
696+ )
697+
698+ # With batch size only on Y.
699+ cond_model = model .condition_on_observations (
700+ cond_X_nobatch , cond_Y , noise = cond_Yvar
701+ )
702+ self .assertEqual (
703+ cond_model .train_inputs [0 ].shape ,
704+ torch .Size ([num_models , num_train + num_cond , num_dims + 1 ]),
705+ )
706+
707+ # test repeated conditioning
708+ repeat_cond_X = cond_X .clone ()
709+ repeat_cond_X [..., 0 :- 1 ] += 2
710+ repeat_cond_model = cond_model .condition_on_observations (
711+ repeat_cond_X , cond_Y , noise = cond_Yvar
712+ )
713+ self .assertEqual (
714+ repeat_cond_model .train_inputs [0 ].shape ,
715+ torch .Size ([num_models , num_train + 2 * num_cond , num_dims + 1 ]),
716+ )
717+
718+ # test repeated conditioning without a batch size
719+ repeat_cond_X_nobatch = cond_X_nobatch .clone ()
720+ repeat_cond_X_nobatch [..., 0 :- 1 ] += 2
721+ repeat_cond_model2 = repeat_cond_model .condition_on_observations (
722+ repeat_cond_X_nobatch , cond_Y_nobatch , noise = cond_Yvar
723+ )
724+ self .assertEqual (
725+ repeat_cond_model2 .train_inputs [0 ].shape ,
726+ torch .Size ([num_models , num_train + 3 * num_cond , num_dims + 1 ]),
727+ )
728+
607729 def test_load_samples (self ):
608730 for task_rank , dtype , use_outcome_transform in itertools .product (
609731 [1 , 2 ], [torch .float , torch .double ], (False , True )
@@ -671,18 +793,6 @@ def test_load_samples(self):
671793 train_Yvar_tf .clamp (MIN_INFERRED_NOISE_LEVEL ),
672794 )
673795 )
674- self .assertTrue (
675- torch .allclose (
676- model .task_covar_module .lengthscale ,
677- mcmc_samples ["task_lengthscale" ],
678- )
679- )
680- self .assertTrue (
681- torch .allclose (
682- model .latent_features ,
683- mcmc_samples ["latent_features" ],
684- )
685- )
686796
687797 def test_construct_inputs (self ):
688798 for dtype , infer_noise in [(torch .float , False ), (torch .double , True )]:
0 commit comments