Skip to content

Commit

Permalink
integrated evidence fuser to calculate likelihoods
Browse files Browse the repository at this point in the history
  • Loading branch information
sreekaroo committed Oct 23, 2023
1 parent 251cf43 commit bcd3a2c
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 18 deletions.
2 changes: 1 addition & 1 deletion bcipy/simulator/helpers/decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def __init__(self, threshold: float):
self.tau = threshold

def decide(self, series: List[InquiryResult]):
current_distribution = series[-1].evidence_likelihoods
current_distribution = series[-1].fused_likelihood
if np.max(current_distribution) > self.tau:
log.info("Committing to decision: posterior exceeded threshold.")
return True
Expand Down
45 changes: 45 additions & 0 deletions bcipy/simulator/helpers/evidence_fuser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from abc import abstractmethod, ABC
from typing import Optional, Dict

import numpy as np


class EvidenceFuser(ABC):

@abstractmethod
def fuse(self, prior_likelihood: Optional[np.ndarray], evidence: Dict) -> np.ndarray:
...


class MultipyFuser(EvidenceFuser):

def __init__(self):
pass

def fuse(self, prior_likelihood, evidence) -> np.ndarray:

len_dist = len(list(evidence.values())[0])
prior_likelihood = prior_likelihood if prior_likelihood is not None else self.__make_prior(len_dist)
ret_likelihood = prior_likelihood.copy()

for value in evidence.values():
ret_likelihood *= value[:]
ret_likelihood = self.__clean_likelihood(ret_likelihood)

return ret_likelihood

def __make_prior(self, len_dist):
return np.ones(len_dist) / len_dist

def __clean_likelihood(self, likelihood):

cleaned_likelihood = likelihood.copy()
if np.isinf(np.sum(likelihood)):
tmp = np.zeros(len(likelihood))
tmp[np.where(likelihood == np.inf)[0][0]] = 1
cleaned_likelihood = tmp

if not np.isnan(np.sum(likelihood)):
cleaned_likelihood = likelihood / np.sum(likelihood)

return cleaned_likelihood
41 changes: 27 additions & 14 deletions bcipy/simulator/helpers/sim_state.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import copy
from abc import ABC
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional
from typing import List, Optional, Dict

import numpy as np

from bcipy.helpers.exceptions import FieldException
from bcipy.helpers.parameters import Parameters
from bcipy.helpers.symbols import alphabet
from bcipy.simulator.helpers.decision import SimDecisionCriteria, MaxIterationsSim, ProbThresholdSim
from bcipy.simulator.helpers.evidence_fuser import MultipyFuser, EvidenceFuser
from bcipy.simulator.helpers.types import InquiryResult
from bcipy.task.control.criteria import DecisionCriteria, MaxIterationsCriteria, ProbThresholdCriteria
from bcipy.task.control.handler import EvidenceFusion


@dataclass
Expand All @@ -23,7 +25,7 @@ class SimState:
inquiry_n: int
series_n: int
series_results: List[List[InquiryResult]]

# TODO store the fused results
decision_criterion: List[SimDecisionCriteria]

def total_inquiry_count(self):
Expand All @@ -36,7 +38,7 @@ def total_inquiry_count(self):

class StateManager(ABC):

def update(self, evidence: np.ndarray):
def update(self, evidence: np.ndarray): # TODO change evidence type to dictionary or some dataclass
raise NotImplementedError()

def is_done(self) -> bool:
Expand All @@ -51,32 +53,40 @@ def mutate_state(self, state_field, state_value):

class StateManagerImpl(StateManager):

def __init__(self, parameters: Parameters):
def __init__(self, parameters: Parameters, fuser_class=MultipyFuser):
self.state: SimState = self.initial_state()
self.parameters = parameters
self.fuser_class: EvidenceFuser.__class__ = fuser_class

self.stop_inq = 50 # TODO pull from parameters
self.max_inq_len = self.parameters.get('max_inq_len', 50)
# TODO add stoppage criterion, Stoppage criterion is seperate from decision. Decision should we go on to next letter or not

def is_done(self) -> bool:

return self.state.total_inquiry_count() > self.stop_inq or self.state.target_sentence == self.state.current_sentence or self.state.series_n > 50
return self.state.total_inquiry_count() > self.max_inq_len or self.state.target_sentence == self.state.current_sentence or self.state.series_n > 50

def update(self, evidence: np.ndarray) -> InquiryResult:
def update(self, evidence) -> InquiryResult:

fuser = self.fuser_class()
current_series: List[InquiryResult] = self.state.series_results[self.state.series_n]
prior_likelihood: Optional[np.ndarray] = current_series.pop().fused_likelihood if current_series else None # most recent likelihood
evidence_dict = {"SM": evidence} # TODO create wrapper object for Evidences
fused_likelihood = fuser.fuse(prior_likelihood, evidence_dict)

temp_inquiry_result = InquiryResult(target=self.state.target_symbol, time_spent=0, stimuli=self.state.display_alphabet,
evidence_likelihoods=list(evidence), decision=None)

# finding out whether max iterations is hit or prob threshold is hit
temp_inquiry_result = InquiryResult(target=self.state.target_symbol, time_spent=0, stimuli=self.state.display_alphabet,
evidence_likelihoods=list(evidence), fused_likelihood=fused_likelihood, # TODO change to use evidence_dict
decision=None)

temp_series = copy.deepcopy(self.get_state().series_results)
temp_series[-1].append(temp_inquiry_result)
is_decidable = any([decider.decide(temp_series[-1]) for decider in self.state.decision_criterion])
decision = None
# TODO what to do when max inquiry count is reached?

new_state = self.get_state().__dict__
if is_decidable:
decision = alphabet()[np.argmax(evidence)] # deciding the maximum probability symbol TODO abstract

if decision == self.state.target_symbol: # correct decision
new_state['series_n'] += 1 # TODO abstract out into reset function
new_state['series_results'].append([])
Expand All @@ -99,7 +109,7 @@ def update(self, evidence: np.ndarray) -> InquiryResult:
new_state['inquiry_n'] += 1

new_inquiry_result = InquiryResult(target=self.state.target_symbol, time_spent=0, stimuli=self.state.display_alphabet,
evidence_likelihoods=list(evidence), decision=decision)
evidence_likelihoods=list(evidence), decision=decision, fused_likelihood=fused_likelihood)

new_state['series_results'][self.state.series_n].append(new_inquiry_result)

Expand All @@ -122,8 +132,11 @@ def mutate_state(self, state_field, state_value) -> SimState:
@staticmethod
def initial_state(parameters: Parameters = None) -> SimState:
sentence = "HELLO_WORLD" # TODO abstract out with sim_parameters.json
target_symbol = sentence[0]
target_symbol = sentence[0] # TODO use parameters.get('spelled_letters_count')
default_criterion: List[SimDecisionCriteria] = [MaxIterationsSim(50), ProbThresholdSim(0.8)]

evidence_types = parameters.get(
'evidence_types') if parameters else None # TODO make new parameter and create default series_likelihoods object based off that

return SimState(target_symbol=target_symbol, current_sentence="", target_sentence=sentence, display_alphabet=[], inquiry_n=0, series_n=0,
series_results=[[]], decision_criterion=default_criterion)
5 changes: 4 additions & 1 deletion bcipy/simulator/helpers/types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from dataclasses import dataclass
from typing import Optional, List

import numpy as np


@dataclass
class InquiryResult:
target: Optional[str]
time_spent: int # TODO what does time_spent mean?
stimuli: List
evidence_likelihoods: List
evidence_likelihoods: List # TODO make this into a dictionary to support multimodal. e.g {SignalModel: evidence_list, LanguageModel:evidence_list}
fused_likelihood: np.ndarray
decision: Optional[str]
7 changes: 5 additions & 2 deletions bcipy/simulator/sim_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,18 @@ def __init__(self, data_engine: DataEngine, model_handler: ModelHandler, sampler

def run(self):
while not self.state_manager.is_done():
print(f"Series {self.state_manager.get_state().series_n} | Inquiry {self.state_manager.get_state().inquiry_n}")
print(f"Series {self.state_manager.get_state().series_n} | Inquiry {self.state_manager.get_state().inquiry_n} | Target {self.state_manager.get_state().target_symbol}")
self.state_manager.mutate_state('display_alphabet', self.__get_inquiry_alp_subset(self.state_manager.get_state()))
sampled_data = self.sampler.sample(self.state_manager.get_state())
evidence = self.model_handler.generate_evidence(self.state_manager.get_state(), sampled_data)
evidence = self.model_handler.generate_evidence(self.state_manager.get_state(),
sampled_data) # TODO make this evidence be a dict (mapping of evidence type to evidence)

print(f"Evidence for stimuli {self.state_manager.get_state().display_alphabet} \n {evidence}")

inq_record: InquiryResult = self.state_manager.update(evidence)

print(f"Fused Likelihoods {[str(round(p, 3)) for p in inq_record.fused_likelihood]}")

if inq_record.decision:
print(f"Decided {inq_record.decision} for target {inq_record.target} for sentence {self.state_manager.get_state().target_sentence}")

Expand Down

0 comments on commit bcd3a2c

Please sign in to comment.