
Is there something in the training process that causes the following error when more than N rows of data are supplied for synthesis? #302

Open
limhasic opened this issue Nov 6, 2024 · 0 comments

limhasic commented Nov 6, 2024

Is there something in the training process that causes the following error when more than N rows of data are supplied for synthesis?

```
Epoch:   0%|          | 0/1 [00:00<?, ?it/s]
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [0,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [1,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [3,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [7,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [8,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
Epoch:   0%|
```
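One general debugging note (not specific to synthcity): CUDA device-side asserts are reported asynchronously, so the Python traceback usually points at an unrelated op. Rerunning with `CUDA_LAUNCH_BLOCKING=1` (or on CPU) makes the traceback point at the call that actually triggered the out-of-bounds index:

```python
import os

# Make CUDA kernel launches synchronous so the Python traceback points at
# the op that actually failed. This must be set before torch initializes
# CUDA, i.e. at the very top of the script.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
```

Alternatively, running the same script with `CUDA_VISIBLE_DEVICES=-1` (as in the snippet below) executes the indexing on CPU, where PyTorch raises an ordinary `IndexError` with a readable message instead of a device-side assert.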

```python
# stdlib
import os
import sys
import warnings

# third party
import pandas as pd

# synthcity absolute
import synthcity.logger as log
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import GenericDataLoader

# Disable CUDA
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
log.add(sink=sys.stderr, level="INFO")
warnings.filterwarnings("ignore")

# Note: preprocessing data with OneHotEncoder or StandardScaler is not needed
# or recommended. Synthcity handles feature encoding and standardization
# internally.

data = pd.read_csv("train.csv")  # .drop('Unnamed: 0', axis=1)
data = data.drop("ID", axis=1)
data = data[data["Fraud_Type"] == "m"]  # why?

X = data.drop("Fraud_Type", axis=1)
y = data["Fraud_Type"]
X["target"] = y

loader = GenericDataLoader(X, target_column="target", sensitive_columns=[])

# Define the model hyperparameters
plugin_params = dict(
    is_classification=True,
    n_iter=1,  # epochs
    lr=0.002,
    weight_decay=1e-4,
    batch_size=10,
    model_type="mlp",  # or "resnet"
    model_params=dict(
        n_layers_hidden=3,
        n_units_hidden=256,
        dropout=0.0,
    ),
    num_timesteps=500,  # timesteps in diffusion
    dim_embed=128,
    # performance logging
    log_interval=10,
)

plugin = Plugins().get("ddpm", **plugin_params)

plugin.fit(loader)  # cond = subset_df["Race=Asian or Pacific Islander"]
```
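One thing worth auditing before training (a guess at the cause, not a confirmed diagnosis): after filtering to a single `Fraud_Type`, the `target` column has exactly one category, and any categorical feature whose values thin out in the filtered subset changes its cardinality. A mismatch between the category count an encoder was fitted with and the indices the model later looks up is a classic source of scatter/gather "index out of bounds" asserts. A quick audit sketch, with a hypothetical `audit_categoricals` helper and made-up column values standing in for the competition data:

```python
import pandas as pd

def audit_categoricals(df: pd.DataFrame, max_card: int = 1000) -> pd.DataFrame:
    """Report cardinality and missing-value counts for object/category columns.

    Single-category and very-high-cardinality columns are worth a second look
    before handing the frame to a tabular generative model.
    """
    rows = []
    for col in df.columns:
        if df[col].dtype == object or str(df[col].dtype) == "category":
            nunique = df[col].nunique(dropna=True)
            if nunique <= 1:
                flag = "single-category"
            elif nunique > max_card:
                flag = "high-cardinality"
            else:
                flag = ""
            rows.append({
                "column": col,
                "nunique": nunique,
                "n_missing": int(df[col].isna().sum()),
                "flag": flag,
            })
    return pd.DataFrame(rows)

# Hypothetical frame mimicking the filtered data above
demo = pd.DataFrame({
    "target": ["m"] * 5,                          # one class after filtering
    "Account_type": ["a", "b", "a", "c", None],   # rare values + a NaN
})
print(audit_categoricals(demo))
```

If the report flags a single-category target, an alternative experiment is to fit on the unfiltered frame (all `Fraud_Type` values) and see whether the assert disappears.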

The data is from the competition below:
https://dacon.io/competitions/official/236297/codeshare
