Skip to content

Commit

Permalink
Fix training data handling (#58)
Browse files Browse the repository at this point in the history
- fixed the issue reported in #56 
- improves unzipping if no to be compiled qasm files are present

Co-authored-by: Lukas Burgholzer <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 26, 2023
1 parent ea9e8c6 commit fe11f36
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/mqt/predictor/ml/Predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,22 +141,22 @@ def generate_compiled_circuits(
if target_path is None:
target_path = str(ml.helper.get_path_training_circuits_compiled())

source_circuits_list = []

for file in Path(source_path).iterdir():
if "qasm" in str(file):
source_circuits_list.append(str(file))

path_zip = Path(source_path) / "mqtbench_training_samples.zip"
if len(source_circuits_list) == 0 and path_zip.exists():
if (
not any(file.suffix == ".qasm" for file in Path(source_path).iterdir())
and path_zip.exists()
):
path_zip = str(path_zip)
import zipfile

with zipfile.ZipFile(path_zip, "r") as zip_ref:
zip_ref.extractall(source_path)

if not Path(source_path).is_dir():
Path(source_path).mkdir()
Path(target_path).mkdir(exist_ok=True)

source_circuits_list = [
file.name for file in Path(source_path).iterdir() if file.suffix == ".qasm"
]

Parallel(n_jobs=-1, verbose=100)(
delayed(self.compile_all_circuits_for_qc)(
Expand All @@ -174,7 +174,7 @@ def generate_trainingdata_from_qasm_files(
Keyword arguments:
source_path -- path to file
target_directory -- path to directory for compiled circuit
target_path -- path to directory for compiled circuit
Return values:
training_data_ML_aggregated -- training data
Expand All @@ -194,7 +194,7 @@ def generate_trainingdata_from_qasm_files(

results = Parallel(n_jobs=-1, verbose=100)(
delayed(self.generate_training_sample)(
str(filename), source_path, target_path
str(filename.name), source_path, target_path
)
for filename in Path(source_path).iterdir()
)
Expand Down
1 change: 1 addition & 0 deletions tests/ml/test_predictor_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def test_generate_compiled_circuits():
qasm_path = Path("compiled_test.qasm")
qc.qasm(filename=str(qasm_path))
predictor.generate_compiled_circuits(source_path, str(target_path))
assert any(file.suffix == ".qasm" for file in target_path.iterdir())

training_sample, circuit_name, scores = predictor.generate_training_sample(
str(qasm_path), source_path, target_path
Expand Down

0 comments on commit fe11f36

Please sign in to comment.