mmscan-devkit v1
rbler1234 committed Nov 23, 2024
1 parent 6c76334 commit 47092fc
Showing 28 changed files with 669 additions and 3,079 deletions.
13 changes: 6 additions & 7 deletions data_preparation/README.md
@@ -1,4 +1,4 @@
### Prepare point clouds info files.
### Prepare MMscan info files.

Given the licenses of the respective raw datasets, we recommend users download the raw data from their official websites and then organize it following the guide below.
Detailed steps are shown as follows.
@@ -9,8 +9,8 @@ Detailed steps are shown as follows.

3. Download Matterport3D data [HERE](https://github.com/niessner/Matterport). Link or move the folder to this directory level.

4. Organize the file structure. Under `mmscan_data/embodiedscan-split/embodiedscan-v1`, the directory structure should be as below,
You are recommended to create a soft link to the raw data folder under `mmscan_data/embodiedscan-split/embodiedscan-v1`.
4. Organize the file structure. Under `mmscan_data/embodiedscan-split`, the directory structure should be as below,
You are recommended to create a soft link to the raw data folder under `mmscan_data/embodiedscan-split`.

```
data/
@@ -29,16 +29,15 @@ Detailed steps are shown as follows.
Additionally, create a `process_pcd` folder in the same directory to store the results. Similarly, we recommend using a symbolic link, as the total file size is fairly large (approximately 21 GB).

PS: If you have followed the embodiedscan tutorial to organize the data, you can skip these steps and link or copy the `data` folder to
`mmscan_data/embodiedscan-split/embodiedscan-v1`.
`mmscan_data/embodiedscan-split`.

After all the raw data is organized, the directory structure should be as below:

```
embodiedscan-v1/
embodiedscan-split/
├── data/
├── process_pcd/
├── embodiedscan_infos_train.pkl
├── embodiedscan_infos_val.pkl
├── embodiedscan-v1/
```
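For convenience, the soft link recommended in step 4 can be created with a short snippet along these lines (a minimal sketch; the source path is illustrative and must point at wherever your raw data actually lives):

```python
import os

raw_data_root = '/path/to/raw_data'            # illustrative: scannet / matterport3d / 3rscan live here
split_root = 'mmscan_data/embodiedscan-split'  # target directory from step 4

os.makedirs(split_root, exist_ok=True)
link = os.path.join(split_root, 'data')
if not os.path.exists(link):
    os.symlink(raw_data_root, link)            # soft link instead of a copy
```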

5. Read the raw files and generate processed point cloud files by running the following scripts.
67 changes: 33 additions & 34 deletions data_preparation/process_all_scan.py
@@ -14,7 +14,7 @@
from utils.scannet_process import process_scannet
from utils.trscan_process import process_trscan

dict_1 = {}
es_anno = {}


def create_scene_pcd(es_anno, pcd_result):
@@ -23,23 +23,18 @@ def create_scene_pcd(es_anno, pcd_result):
Args:
es_anno (dict): The embodiedscan annotation of
the target scan.
pcd_result (tuple) :
(1) aligned point cloud coordinates
shape (n,3)
(2) point cloud colors ([0,1])
shape (n,3)
(3) labels (not needed here)
pcd_result (tuple) : The raw point cloud data of the scan,
consisting of:
(1) aligned point cloud coordinates with shape (n,3).
(2) point cloud colors ([0,1]) with shape (n,3).
(3) labels (not needed here).
Returns:
tuple :
(1) aligned point cloud coordinates
shape (n,3)
(2) point cloud colors ([0,1])
shape (n,3)
(3) point cloud labels (int)
shape (n,1)
(4) point cloud object ids (int)
shape (n,1)
tuple : The processed point cloud data of the scan, consisting of:
(1) aligned point cloud coordinates with shape (n,3).
(2) point cloud colors ([0,1]) with shape (n,3).
(3) point cloud labels with shape (n,1).
(4) point cloud object ids (int) with shape (n,1).
"""
pc, color, label = pcd_result
label = np.ones_like(label) * -100
@@ -86,29 +81,33 @@ def process_one_scan(
):
"""Process the point clouds of one scan and save in a pth file.
The pth file is a tuple of:
(1) aliged point clouds coordinates
shape (n,3)
(2) point clouds color ([0,1])
shape (n,3)
(3) point clouds label (int)
shape (n,1)
(4) point clouds object id (int)
shape (n,1)
The pth file is a tuple of nd.array, consisting of:
(1) aliged point clouds coordinates with shape (n,3).
(2) point clouds color ranging in [0,1] with shape (n,3).
(3) point clouds label with shape (n,1).
(4) point clouds object id with shape (n,1).
Args:
scan_id (str): the scan id
scan_id (str): The scan id.
save_root (str): The root path to save the pth file.
scannet_root (str): The path of scannet.
mp3d_root (str): The path of mp3d.
trscan_root (str): The path of 3rscan.
scannet_matrix (np.ndarray): The axis-align matrix of scannet.
mp3d_matrix (np.ndarray): The axis-align matrix of mp3d.
trscan_matrix (np.ndarray): The axis-align matrix of 3rscan.
mp3d_mapping (dict): The mapping dict for mp3d scan ids.
"""

if os.path.exists(f'{save_root}/{scan_id}.pth'):
return

try:
if 'scene' in scan_id:
if 'scannet/' + scan_id not in dict_1:
if 'scannet/' + scan_id not in es_anno:
return

pcd_info = create_scene_pcd(
dict_1['scannet/' + scan_id],
es_anno['scannet/' + scan_id],
process_scannet(scan_id, scannet_root, scannet_matrix),
)

@@ -118,19 +117,19 @@
'region' + scan_id.split('_region')[1],
)
mapping_name = f'matterport3d/{raw_scan_id}/{region_id}'
if mapping_name not in dict_1:
if mapping_name not in es_anno:
return

pcd_info = create_scene_pcd(
dict_1[mapping_name],
es_anno[mapping_name],
process_mp3d(scan_id, mp3d_root, mp3d_matrix, mp3d_mapping),
)

else:
if '3rscan/' + scan_id not in dict_1:
if '3rscan/' + scan_id not in es_anno:
return
pcd_info = create_scene_pcd(
dict_1['3rscan/' + scan_id],
es_anno['3rscan/' + scan_id],
process_trscan(scan_id, trscan_root, trscan_matrix),
)

@@ -182,8 +181,8 @@ def process_one_scan(

TYPE2INT = np.load(args.train_pkl_path,
allow_pickle=True)['metainfo']['categories']
dict_1.update(read_annotation_pickle(args.train_pkl_path))
dict_1.update(read_annotation_pickle(args.val_pkl_path))
es_anno.update(read_annotation_pickle(args.train_pkl_path))
es_anno.update(read_annotation_pickle(args.val_pkl_path))

# load the required scan ids
with open(f'{args.meta_path}/all_scan.json', 'r') as f:
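For orientation, here is a minimal sketch of reading one of the processed files this script writes, assuming the tuple layout documented in `process_one_scan` and that the `.pth` files are written with `torch.save` (the path is illustrative):

```python
import torch

# Illustrative path; files are written under `save_root` as `{scan_id}.pth`.
pc, color, label, obj_id = torch.load('process_pcd/scene0000_00.pth')

print(pc.shape)      # (n, 3) aligned point coordinates
print(color.shape)   # (n, 3) RGB colors in [0, 1]
print(label.shape)   # (n, 1) per-point category labels
print(obj_id.shape)  # (n, 1) per-point object ids
```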
2 changes: 1 addition & 1 deletion data_preparation/utils/scannet_process.py
@@ -18,7 +18,7 @@ def process_scannet(scan_id, data_root, scannet_matrix):
r = np.asarray(data_color.elements[0].data['red'])
g = np.asarray(data_color.elements[0].data['green'])
b = np.asarray(data_color.elements[0].data['blue'])
pc_color = (np.stack([r, g, b], axis=1) / 256.0).astype(np.float32)
pc_color = (np.stack([r, g, b], axis=1) / 255.0).astype(np.float32)
axis_align_matrix = scannet_matrix[scan_id]
pts = np.ones((pc.shape[0], 4), dtype=pc.dtype)
pts[:, :3] = pc
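The one-character fix above matters because the PLY color channels are 8-bit values in [0, 255], and dividing by 256 can never reach 1.0 exactly. A standalone illustration (not from the repo):

```python
import numpy as np

rgb = np.array([0, 128, 255], dtype=np.uint8)
print(rgb / 256.0)  # [0.      0.5     0.99609...] -- pure white falls short of 1.0
print(rgb / 255.0)  # [0.      0.50196 1.       ] -- maps onto the full [0, 1] range
```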
27 changes: 14 additions & 13 deletions mmscan/evaluator/gpt_evaluation.py
@@ -53,9 +53,9 @@ def normal_query(self,
The system prompt passed to GPT.
user_content_grounps (list[str]) :
The user content passed to GPT.
max_tokens (int) : max tokens, default 1000.
max_tokens (int) : Max tokens. Defaults to 1000.
Returns:
dict : the json-format result.
dict : The json-format result.
"""

messages = []
@@ -77,13 +77,11 @@ def qa_evaluation(self, QA_sample_dict, thread_index, tmp_path):
"""Employ the GPT evaluator.
Args:
QA_sample_dict (str) :
The system prompt inputted into GPT.
user_content_grounps (list[str]) :
The user content inputted into GPT.
max_tokens (int) : max tokens, default 1000.
Returns:
dict : the json-format result.
QA_sample_dict (dict) : The QA sample dict with
[gt, pred, question] as values.
thread_index (int) : The index of the thread.
tmp_path (str) : The path to store the
temporary json files.
"""

system_prompt, ex_instance = qa_prompt_define()
@@ -137,7 +135,7 @@ def qa_collection(self, num_threads, tmp_path):
tmp_path (str) :
The path to store the temporary json files.
Returns:
dict : the evaluation result.
dict : The evaluation result.
"""

eval_dict = {metric: [] for metric in self.qa_metric}
@@ -174,12 +172,12 @@ def load_and_eval(self, raw_batch_input, num_threads=1, tmp_path='./'):
Args:
raw_batch_input (list[dict]) :
the batch of results wanted to evaluate
The batch of results to evaluate.
num_threads (int) : The number of threads.
Defaults to 1.
tmp_path (str) : The temporary path to store the json files.
Returns:
dict : the evaluation result.
dict : The evaluation result.
"""

# (1) Update the results and store in the dict.
@@ -235,7 +233,7 @@ def __check_format__(self, raw_input):
to be checked, should be a list of dict. Every item with the keys:
["ID","question","pred",""gt"] pred is a list with one one element. gt
is a list with >=1 elements. "ID" should be unique!!!!
is a list with >=1 elements. "ID" should be unique.
Args:
raw_input (list[dict]) : The input to be checked.
"""
assert isinstance(
raw_input,
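As a usage illustration of `load_and_eval`, the evaluator expects the format enforced by `__check_format__`. Note that the class name and constructor arguments below are assumptions, since they fall outside the hunks shown here:

```python
# Hypothetical import: the class name is not visible in this diff.
from mmscan.evaluator.gpt_evaluation import GPTEvaluator

batch = [{
    'ID': 'qa_0001',                              # must be unique
    'question': 'What color is the chair?',
    'pred': ['The chair is red.'],                # exactly one element
    'gt': ['A red chair.', 'The chair is red.'],  # one or more references
}]

evaluator = GPTEvaluator()                        # constructor args assumed
result = evaluator.load_and_eval(batch, num_threads=1, tmp_path='./')
print(result)                                     # json-format evaluation result
```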
6 changes: 4 additions & 2 deletions mmscan/evaluator/metrics/box_metric.py
@@ -13,6 +13,7 @@ def average_precision(recalls, precisions, mode='area'):
mode (str): 'area' or '11points'. 'area' means calculating the area
under the precision-recall curve; '11points' means averaging
the precision at recalls of [0, 0.1, ..., 1].
Defaults to 'area'.
Returns:
float or np.ndarray: Calculated average precision.
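To make the two modes concrete, here is a small self-contained example of the '11points' scheme (the standard interpolated form; the repo's exact implementation is not shown in this hunk):

```python
import numpy as np

recalls = np.array([0.2, 0.5, 1.0])
precisions = np.array([1.0, 0.8, 0.6])

# Interpolated precision at t: the best precision among recalls >= t.
ap_11 = np.mean([precisions[recalls >= t].max() if np.any(recalls >= t) else 0.0
                 for t in np.linspace(0.0, 1.0, 11)])
print(round(ap_11, 4))  # 0.7636
```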
@@ -57,7 +58,8 @@ def get_f1_scores(iou_matrix, iou_threshold):
Args:
iou_matrix (ndarray/tensor):
the iou matrix of the predictions and ground truths (shape n*m)
The iou matrix of the predictions and ground truths with
shape (num_preds, num_gts).
iou_threshold (float): 0.25 or 0.5.
Returns:
@@ -93,7 +95,7 @@ def __get_fp_tp_array__(iou_array, iou_threshold):
Args:
iou_array (ndarray/tensor):
the iou matrix of the predictions and ground truths
(shape len(preds)*len(gts))
(shape (num_preds, num_gts))
iou_threshold (float): 0.25 or 0.5.
Returns:
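And to pin down the shapes used by `__get_fp_tp_array__`, a greedy one-to-one matching sketch; this mirrors the common detection-metric scheme and is not necessarily the repo's exact logic:

```python
import numpy as np

def fp_tp_from_iou(iou_matrix, iou_threshold=0.25):
    """Greedy matching: each ground truth may be claimed by one prediction."""
    num_preds, num_gts = iou_matrix.shape
    tp = np.zeros(num_preds, dtype=bool)
    gt_taken = np.zeros(num_gts, dtype=bool)
    for i in range(num_preds):
        j = int(np.argmax(iou_matrix[i]))                 # best-overlapping GT
        if iou_matrix[i, j] >= iou_threshold and not gt_taken[j]:
            tp[i] = True
            gt_taken[j] = True
    return tp, ~tp                                        # fp is the complement

iou = np.array([[0.6, 0.1],
                [0.2, 0.3]])                              # 2 preds x 2 GTs
tp, fp = fp_tp_from_iou(iou, 0.25)
print(tp, fp)  # [ True  True] [False False]
```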
34 changes: 16 additions & 18 deletions mmscan/evaluator/qa_evaluation.py
@@ -7,16 +7,16 @@


class QA_Evaluator:
"""tradition metrics for QA and Caption evaluation , consists the
"""Tradition metrics for QA and Caption evaluation , consists the
implements of.
[EM, BLEU, METEOR, ROUGE, CIDEr, SPICE, SIMCSE, SBERT]
SIMCSE, SBERT is speacial metrics and needed GPU tools.
SIMCSE, SBERT is speacial metrics and needed GPU.
Attributes:
save_buffer(list[dict]): Save the buffer of Inputs
records(list[dict]): Metric results for each sample
metric_record(dict): Metric results for each category
save_buffer(list[dict]): Save the buffer of Inputs.
records(list[dict]): Metric results for each sample.
metric_record(dict): Metric results for each category.
(average of all samples with the same category)
Args:
model_config(dict): The model config for special metric evaluation.
@@ -67,18 +67,18 @@ def update(self, batch_input):
"""Update a batch of results to the buffer, and then filtering and
truncating. each item is expected to be a dict with keys.
["index", "ID","question","pred",""gt"]
["index", "ID","question","pred","gt"]
1. pred is a list with one element.
2. gt is a list with >=1 elements.
3. "ID" should be unique!!!!
3. "ID" should be unique.
Args:
batch_input (list[dict]):
a batch of the raw original input
Batch of the raw original input.
Returns:
Dict: {"EM":EM metric for this batch,
"refined_EM":refined EM metric for this batch}
"refined_EM":Refined EM metric for this batch}
"""

self.__check_format__(batch_input)
@@ -112,7 +112,7 @@ def start_evaluation(self):
"""Start the evaluation process.
Returns:
dict: the metrics
dict: The results of the evaluation.
"""

# (1) exact match evaluation
@@ -170,18 +170,16 @@

def __check_format__(self, raw_input):
"""Check if the input conform with mmscan evaluation format.
Args:
The input to be checked, should be a list of dict.
Every item with the keys:
["index", "ID","question","pred",""gt"]
pred is a list with one one element.
gt is a list with >=1 elements.
"ID" should be unique!!!!
Every item with the keys ["index", "ID","question","pred","gt"],
'pred' is a list with one element, 'gt' is a list
with >=1 elements. "ID" should be unique.
Args:
raw_input (list[dict]): The input to be checked.
"""
assert isinstance(
raw_input,
list), 'The input of MMScan evaluator should be a list of dict. '
list), 'The input of QA evaluator should be a list of dict. '

for _index in range(len(raw_input)):
if 'index' not in raw_input[_index]:
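Finally, a minimal usage sketch for `QA_Evaluator`. The contents of `model_config` are an assumption; the diff only shows that the constructor takes a config for the GPU-based metrics:

```python
from mmscan.evaluator.qa_evaluation import QA_Evaluator

model_config = {}  # assumed: paths/devices for the SIMCSE and SBERT models

evaluator = QA_Evaluator(model_config)
evaluator.update([{
    'index': 0,
    'ID': 'qa_0001',                           # must be unique
    'question': 'What is on the table?',
    'pred': ['A laptop.'],                     # exactly one element
    'gt': ['A laptop', 'There is a laptop.'],  # one or more references
}])
metrics = evaluator.start_evaluation()
print(metrics)  # per-metric results, averaged per category
```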