-
-
Notifications
You must be signed in to change notification settings - Fork 988
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add more quantization strategies to contrib.epi #2440
Conversation
pyro/contrib/epidemiology/util.py
Outdated
elif num_quant_bins == 16: | ||
global w16 | ||
if w16.device != s.device: | ||
w16 = w16.to(s.device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is there a better way to deal with this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah this is really bad, you're updating the global variable. How about renaming the global variable to W16
and then using w16 = s.new_tensor(W16)
.
@@ -33,8 +33,9 @@ | |||
'contrib/autoname/mixture.py --num-epochs=1', | |||
'contrib/autoname/tree_data.py --num-epochs=1', | |||
'contrib/cevae/synthetic.py --num-epochs=1', | |||
'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2', | |||
'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2 --dct=1', | |||
'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i changed some of the args to reduce test time to below a minute
i still intend to add a basic sampler test |
""" | ||
|
||
def __init__(self, compartments, duration, population): | ||
def __init__(self, compartments, duration, population, num_quant_bins=4): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a *,
separator to force use as a kwarg rather than arg?
def __init__(self, compartments, duration, population, *,
num_quant_bins=4):
This gives us flexibility to later (1) reorder all kwargs after the *,
, or even (2) lump them into a **kwargs
dict.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@@ -91,6 +93,10 @@ def __init__(self, compartments, duration, population): | |||
assert len(compartments) == len(set(compartments)) | |||
self.compartments = compartments | |||
|
|||
if num_quant_bins not in [4, 8, 12, 16]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please weaken this check and let the other file decide what values are allowed:
assert isininstance(num_quant_bins, int)
assert num_quant_bins > 0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
pyro/contrib/epidemiology/sir.py
Outdated
""" | ||
|
||
def __init__(self, population, recovery_time, data): | ||
def __init__(self, population, recovery_time, data, num_quant_bins=4): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto: Add *,
separator.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
pyro/contrib/epidemiology/util.py
Outdated
arange_min = - (num_quant_bins // 2 - 1) | ||
arange_max = num_quant_bins // 2 + 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nitnit: I guess you could write these as
arange_min = 1 - num_quant_bins // 2
arange_max = 1 + num_quant_bins // 2
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Generally looks good after my minor comments.
Could you published the notebook deriving spline weights and to link to it in PR description? Feel free to push directly to my notebooks
repo or whatever. We can move that to a new repo like pyro-ppl/derivations
or something if you want.
my notebooks are a mess and i'd rather do this once i finalize which schemes i'd like to keep user facing |
y = torch.min(y, 2 * max + 1 - y) | ||
probs.scatter_add_(0, y, bin_probs[:, k] / num_samples) | ||
|
||
max_deviation = (probs - 1.0 / (max + 1.0)).abs().max().item() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in total these tests take about a second
Addresses #2426