Skip to content

Commit

Permalink
Merge branch 'master' into migrate_reshape
Browse files Browse the repository at this point in the history
  • Loading branch information
jongwonyang authored Sep 12, 2024
2 parents 82d2717 + 3a3e50b commit d10a8bf
Show file tree
Hide file tree
Showing 86 changed files with 3,132 additions and 2,450 deletions.
68 changes: 50 additions & 18 deletions compiler/fm-equalize/fm-equalize
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,18 @@ def _get_parser():
help="Allow to create duplicate operations when a feature map matches "
"with multiple equalization patterns. This can increase the size of "
"the model. Default is false.")
parser.add_argument("--fme_detect",
type=str,
help="Path to fme-detect driver.",
required=False)
parser.add_argument("--dalgona",
type=str,
help="Path to dalgona driver.",
required=False)
parser.add_argument("--fme_apply",
type=str,
help="Path to fme-apply driver.",
required=False)
parser.add_argument('--verbose', action='store_true', help='Print logs')

return parser
Expand All @@ -78,12 +90,9 @@ def _run_cmd(cmd: str, verbose: bool):
raise


def _run_dalgona(model: str, data: Optional[str], analysis: str, save_dir: str,
verbose: bool):
dir_path = os.getenv('ONE_BIN_PATH')
assert dir_path != None
dalgona_path = os.path.join(dir_path, 'dalgona')
cmd = [dalgona_path]
def _run_dalgona(driver_path: str, model: str, data: Optional[str], analysis: str,
save_dir: str, verbose: bool):
cmd = [driver_path]
cmd += ['--input_model', model]
cmd += ['--analysis', analysis]
if data != None:
Expand All @@ -94,11 +103,9 @@ def _run_dalgona(model: str, data: Optional[str], analysis: str, save_dir: str,
_run_cmd(cmd, verbose)


def _run_fme_detect(input_model: str, fme_patterns: str, verbose: bool,
def _run_fme_detect(driver_path: str, input_model: str, fme_patterns: str, verbose: bool,
allow_dup_op: bool):
dir_path = Path(__file__).parent.resolve()
fme_detect_path = os.path.join(dir_path, 'fme-detect')
cmd = [fme_detect_path]
cmd = [driver_path]
cmd += ['--input', input_model]
cmd += ['--output', fme_patterns]
if allow_dup_op:
Expand All @@ -107,10 +114,9 @@ def _run_fme_detect(input_model: str, fme_patterns: str, verbose: bool,
_run_cmd(cmd, verbose)


def _run_fme_apply(input_model: str, fme_patterns: str, output_model: str, verbose: bool):
dir_path = Path(__file__).parent.resolve()
fme_apply_path = os.path.join(dir_path, 'fme-apply')
cmd = [fme_apply_path]
def _run_fme_apply(driver_path: str, input_model: str, fme_patterns: str,
output_model: str, verbose: bool):
cmd = [driver_path]
cmd += ['--input', input_model]
cmd += ['--fme_patterns', fme_patterns]
cmd += ['--output', output_model]
Expand All @@ -128,14 +134,34 @@ def main():
data = args.data
verbose = args.verbose
allow_dup_op = args.allow_dup_op
fme_detect_path = args.fme_detect
fme_apply_path = args.fme_apply
dalgona_path = args.dalgona

curr_dir = Path(__file__).parent.resolve()
dump_fme_param_py = curr_dir / 'fmelib' / 'DumpFMEParams.py'
if dump_fme_param_py.exists() == False:
raise FileNotFoundError('Error: DumpFMEParams.py not found')

if not fme_detect_path:
dir_path = Path(__file__).parent.resolve()
fme_detect_path = os.path.join(dir_path, 'fme-detect')
if not dalgona_path:
dir_path = os.getenv('ONE_BIN_PATH')
assert dir_path != None
dalgona_path = os.path.join(dir_path, 'dalgona')
if not fme_apply_path:
dir_path = Path(__file__).parent.resolve()
fme_apply_path = os.path.join(dir_path, 'fme-apply')

with tempfile.TemporaryDirectory() as tmp_dir:
fme_patterns = os.path.join(
tmp_dir,
Path(output_model).with_suffix('.fme_patterns.json').name)

# Step 1. Run fme-detect to find equalization patterns
_run_fme_detect(str(input_model),
_run_fme_detect(fme_detect_path,
str(input_model),
str(fme_patterns),
verbose=verbose,
allow_dup_op=allow_dup_op)
Expand All @@ -144,16 +170,22 @@ def main():
if args.fme_patterns != None:
os.system(f'cp {fme_patterns} {args.fme_patterns}')

# TODO Step 2. Run dalgona
# _run_dalgona
# Step 2. Run dalgona
_run_dalgona(dalgona_path,
str(input_model),
data,
str(dump_fme_param_py),
str(fme_patterns),
verbose=verbose)

# Copy fme_patterns to the given path
# Why copy twice? To observe the result of fme-detect too
if args.fme_patterns != None:
os.system(f'cp {fme_patterns} {args.fme_patterns}')

# Step 3. Run fme-apply
_run_fme_apply(str(input_model),
_run_fme_apply(fme_apply_path,
str(input_model),
str(fme_patterns),
str(output_model),
verbose=verbose)
Expand Down
10 changes: 0 additions & 10 deletions compiler/fme-apply/src/FMEqualizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,8 @@
#include "FMEqualizer.h"
#include "InsertScaleShift.h"
#include "EqualizePatternCheck.h"
#include "pass/ForwardPreScalePass.h"
#include "pass/ForwardPreShiftPass.h"
#include "pass/FusePostScalePass.h"
#include "pass/FusePostShiftPass.h"
#include "pass/FusePreScalePass.h"
#include "pass/FusePreShiftPass.h"
#include "ProgressReporter.h"

#include <luci/IR/CircleNode.h>
Expand Down Expand Up @@ -82,15 +78,9 @@ void FMEqualizer::equalize(loco::Graph *g, const std::vector<EqualizePattern> &p
phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());

// Forward PreScale/PreShift
phase.emplace_back(std::make_unique<fme_apply::ForwardPreScalePass>());
phase.emplace_back(std::make_unique<fme_apply::ForwardPreShiftPass>());

// Fuse Pre/Post Scale/Shift
phase.emplace_back(std::make_unique<fme_apply::FusePreScalePass>());
phase.emplace_back(std::make_unique<fme_apply::FusePostScalePass>());
phase.emplace_back(std::make_unique<fme_apply::FusePreShiftPass>());
phase.emplace_back(std::make_unique<fme_apply::FusePostShiftPass>());

ProgressReporter prog(g, logo::PhaseStrategy::Restart);
logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
Expand Down
1 change: 1 addition & 0 deletions compiler/fme-apply/src/RandomString.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#ifndef __FME_APPLY_RANDOM_STRING_H__
#define __FME_APPLY_RANDOM_STRING_H__

#include <cstdint>
#include <string>

namespace fme_apply
Expand Down
95 changes: 0 additions & 95 deletions compiler/fme-apply/src/pass/ForwardPreScalePass.cpp

This file was deleted.

53 changes: 0 additions & 53 deletions compiler/fme-apply/src/pass/ForwardPreScalePass.h

This file was deleted.

Loading

0 comments on commit d10a8bf

Please sign in to comment.