
Commit

do not cache model outputs to save memory
kylesayrs committed Dec 13, 2024
1 parent 19e4f97 · commit f95b77f
Showing 1 changed file with 0 additions and 5 deletions.
src/llmcompressor/pipelines/sequential/pipeline.py: 5 changes (0 additions, 5 deletions)
@@ -48,7 +48,6 @@ def run_pipeline(
         # prepare intermediates cache
         model_device = get_execution_device(model)
         intermediates = IntermediatesCache.from_dataloader(dataloader, model_device)
-        model_outputs = [dict() for _ in range(len(dataloader))]
 
         num_subgraphs = len(subgraphs)
         for subgraph_index, subgraph in enumerate(subgraphs):
@@ -64,7 +63,6 @@ def run_pipeline(
                 for batch_index in tqdm.tqdm(range(len(dataloader)), desc=calib_desc):
                     inputs = intermediates.fetch(batch_index, subgraph.input_names)
                     forward_function(model, **inputs)
-                    del inputs
 
             # if using propagate_error, then this pass does not trigger modifier hooks
             # and is only used for capturing intermediates
@@ -74,10 +72,7 @@
                 for batch_index in tqdm.tqdm(range(len(dataloader)), desc=desc):
                     inputs = intermediates.fetch(batch_index, subgraph.input_names)
                     output = forward_function(model, **inputs)
-                    del inputs
 
                     if subgraph_index < len(subgraphs) - 1:
                         intermediates.update(batch_index, output)
                         intermediates.delete(batch_index, subgraph.consumed_names)
-                    else:
-                        model_outputs[batch_index] = output
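
For context, here is a minimal sketch of the caching pattern this change leaves in place: per-batch intermediates are kept only while a later subgraph still needs them, and the final subgraph's outputs are discarded rather than accumulated in a list. `SimpleCache` and `run_sequential` below are illustrative stand-ins written for this sketch, not the actual `IntermediatesCache` or subgraph APIs from llm-compressor.

```python
# Hypothetical, simplified stand-in for the sequential pipeline's caching
# pattern (not the llm-compressor API): only activations needed by a later
# subgraph are kept per batch; final outputs are never accumulated.
from typing import Any, Dict, List


class SimpleCache:
    """Per-batch key/value store, loosely modeled on IntermediatesCache."""

    def __init__(self, num_batches: int):
        self.batches: List[Dict[str, Any]] = [{} for _ in range(num_batches)]

    def fetch(self, batch_index: int, names: List[str]) -> Dict[str, Any]:
        return {name: self.batches[batch_index][name] for name in names}

    def update(self, batch_index: int, values: Dict[str, Any]) -> None:
        self.batches[batch_index].update(values)

    def delete(self, batch_index: int, names: List[str]) -> None:
        for name in names:
            self.batches[batch_index].pop(name, None)


def run_sequential(
    subgraphs: List[Dict[str, Any]], cache: SimpleCache, num_batches: int
) -> None:
    for subgraph_index, subgraph in enumerate(subgraphs):
        for batch_index in range(num_batches):
            inputs = cache.fetch(batch_index, subgraph["input_names"])
            output = subgraph["forward"](**inputs)

            if subgraph_index < len(subgraphs) - 1:
                # keep only what later subgraphs still need
                cache.update(batch_index, output)
                cache.delete(batch_index, subgraph["consumed_names"])
            # else: the last subgraph's outputs are simply dropped,
            # mirroring the removal of model_outputs in this commit


if __name__ == "__main__":
    cache = SimpleCache(num_batches=2)
    cache.update(0, {"x": 1.0})
    cache.update(1, {"x": 2.0})
    subgraphs = [
        {"input_names": ["x"], "consumed_names": ["x"],
         "forward": lambda x: {"y": x + 1}},
        {"input_names": ["y"], "consumed_names": ["y"],
         "forward": lambda y: {"logits": y * 2}},
    ]
    run_sequential(subgraphs, cache, num_batches=2)
    # final subgraph outputs were never stored; prints [{'y': 2.0}, {'y': 3.0}]
    print(cache.batches)
```

With the `else` branch gone, peak memory is bounded by the live intermediates of a single subgraph rather than growing with the number of calibration batches.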
