Skip to content

Commit

Permalink
Merge pull request #597 from oceanbase/3.0.0-dev
Browse files Browse the repository at this point in the history
merge 3.0
  • Loading branch information
Teingi authored Dec 6, 2024
2 parents 2b036d7 + 0810806 commit e052cbc
Show file tree
Hide file tree
Showing 60 changed files with 1,240 additions and 274 deletions.
10 changes: 5 additions & 5 deletions common/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def get_observer_version(context):
stdio.verbose("get observer version, by sql")
obcluster = context.cluster_config
# by sql
observer_version = get_observer_version_by_sql(obcluster, stdio)
observer_version = get_observer_version_by_sql(context, obcluster)
except Exception as e:
try:
stdio.verbose("get observer version, by sql fail. by ssh")
Expand Down Expand Up @@ -325,15 +325,15 @@ def get_obproxy_version(context):
# Only applicable to the community version


def get_observer_version_by_sql(ob_cluster, stdio=None):
stdio.verbose("start get_observer_version_by_sql . input: {0}:{1}".format(ob_cluster.get("db_host"), ob_cluster.get("db_port")))
def get_observer_version_by_sql(context, ob_cluster):
context.stdio.verbose("start get_observer_version_by_sql . input: {0}:{1}".format(ob_cluster.get("db_host"), ob_cluster.get("db_port")))
try:
ob_connector = OBConnector(ip=ob_cluster.get("db_host"), port=ob_cluster.get("db_port"), username=ob_cluster.get("tenant_sys").get("user"), password=ob_cluster.get("tenant_sys").get("password"), stdio=stdio, timeout=100)
ob_connector = OBConnector(context=context, ip=ob_cluster.get("db_host"), port=ob_cluster.get("db_port"), username=ob_cluster.get("tenant_sys").get("user"), password=ob_cluster.get("tenant_sys").get("password"), timeout=100)
ob_version_info = ob_connector.execute_sql("select version();")
except Exception as e:
raise Exception("get_observer_version_by_sql Exception. Maybe cluster'info is error: " + e.__str__())
ob_version = ob_version_info[0]
stdio.verbose("get_observer_version_by_sql ob_version_info is {0}".format(ob_version))
context.stdio.verbose("get_observer_version_by_sql ob_version_info is {0}".format(ob_version))
version = re.findall(r'OceanBase(_)?(.CE)?-v(.+)', ob_version[0])
if len(version) > 0:
return version[0][2]
Expand Down
4 changes: 2 additions & 2 deletions common/config_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(self, context):

def get_cluster_name(self):
ob_version = get_observer_version(self.context)
obConnetcor = OBConnector(ip=self.db_host, port=self.db_port, username=self.sys_tenant_user, password=self.sys_tenant_password, stdio=self.stdio, timeout=100)
obConnetcor = OBConnector(context=self.context, ip=self.db_host, port=self.db_port, username=self.sys_tenant_user, password=self.sys_tenant_password, timeout=100)
if ob_version.startswith("3") or ob_version.startswith("2"):
sql = "select cluster_name from oceanbase.v$ob_cluster"
res = obConnetcor.execute_sql(sql)
Expand All @@ -68,7 +68,7 @@ def get_cluster_name(self):

def get_host_info_list_by_cluster(self):
ob_version = get_observer_version(self.context)
obConnetcor = OBConnector(ip=self.db_host, port=self.db_port, username=self.sys_tenant_user, password=self.sys_tenant_password, stdio=self.stdio, timeout=100)
obConnetcor = OBConnector(context=self.context, ip=self.db_host, port=self.db_port, username=self.sys_tenant_user, password=self.sys_tenant_password, timeout=100)
sql = "select SVR_IP, SVR_PORT, ZONE, BUILD_VERSION from oceanbase.DBA_OB_SERVERS"
if ob_version.startswith("3") or ob_version.startswith("2") or ob_version.startswith("1"):
sql = "select SVR_IP, SVR_PORT, ZONE, BUILD_VERSION from oceanbase.__all_server"
Expand Down
5 changes: 3 additions & 2 deletions common/ob_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,22 @@ class OBConnector(object):

def __init__(
self,
context,
ip,
port,
username,
password=None,
database=None,
stdio=None,
timeout=30,
):
self.context = context
self.ip = str(ip)
self.port = int(port)
self.username = str(username)
self.password = str(password)
self.timeout = timeout
self.conn = None
self.stdio = stdio
self.stdio = context.stdio
self.database = database
self.init()

Expand Down
2 changes: 1 addition & 1 deletion common/ssh_client/local_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,4 @@ def get_name(self):
return "local"

def get_ip(self):
return self.client.get_ip()
return "127.0.0.1"
3 changes: 3 additions & 0 deletions common/ssh_client/remote_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,6 @@ def ssh_invoke_shell_switch_user(self, new_user, cmd, time_out):

def get_name(self):
return "remote_{0}".format(self.host_ip)

def get_ip(self):
return self.host_ip
1 change: 1 addition & 0 deletions common/ssh_client/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def exec_cmd(self, cmd):
return self.client.exec_cmd(cmd).strip()

def download(self, remote_path, local_path):
self.stdio.verbose("download file: {} to {}".format(remote_path, local_path))
return self.client.download(remote_path, local_path)

def upload(self, remote_path, local_path):
Expand Down
18 changes: 10 additions & 8 deletions common/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,7 @@ def unzip(source, ztype=None, stdio=None):
stdio and getattr(stdio, 'exception', print)('failed to unzip %s' % source)
return None

@staticmethod
def extract_tar(tar_path, output_path, stdio=None):
if not os.path.exists(output_path):
os.makedirs(output_path)
Expand Down Expand Up @@ -578,6 +579,7 @@ def unlock(obj, stdio=None):
fcntl.flock(obj, fcntl.LOCK_UN)
return obj

@staticmethod
def size_format(num, unit="B", output_str=False, stdio=None):
if num < 0:
raise ValueError("num cannot be negative!")
Expand Down Expand Up @@ -637,6 +639,7 @@ def calculate_sha256(filepath, stdio=None):
except Exception as e:
return ""

@staticmethod
def size(size_str, unit='B', stdio=None):
unit_size_dict = {
"b": 1,
Expand All @@ -658,10 +661,12 @@ def size(size_str, unit='B', stdio=None):
raise ValueError('size cannot be negative!')
return real_size / unit_size_dict[unit]

@staticmethod
def write_append(filename, result, stdio=None):
with io.open(filename, 'a', encoding='utf-8') as fileobj:
fileobj.write(u'{}'.format(result))

@staticmethod
def tar_gz_to_zip(temp_dir, tar_gz_file, output_zip, password, stdio):
extract_dir = os.path.join(temp_dir, 'extracted_files_{0}'.format(str(uuid.uuid4())[:6]))

Expand All @@ -682,20 +687,14 @@ def tar_gz_to_zip(temp_dir, tar_gz_file, output_zip, password, stdio):
base_paths.append(base_path)
stdio.verbose("start pyminizip compress_multiple")
# 3. Compress the extracted files into a (possibly) encrypted zip file
zip_process = None
if password:
# Use pyminizip to create the encrypted zip file
zip_process = mp.Process(target=pyminizip.compress_multiple, args=(files_to_compress, base_paths, output_zip, password, 5))
# pyminizip.compress_multiple(files_to_compress, base_paths, output_zip, password, 5) # 5 is the compression level
pyminizip.compress_multiple(files_to_compress, base_paths, output_zip, password, 5) # 5 is the compression level
stdio.verbose("extracted files compressed into encrypted {0}".format(output_zip))
else:
# Create an unencrypted zip file
zip_process = mp.Process(target=pyminizip.compress_multiple, args=(files_to_compress, base_paths, output_zip, None, 5))
# pyminizip.compress_multiple(files_to_compress, base_paths, output_zip, None, 5)
pyminizip.compress_multiple(files_to_compress, base_paths, output_zip, None, 5)
stdio.verbose("extracted files compressed into unencrypted {0}".format(output_zip))
zip_process.start()
if zip_process is not None:
zip_process.join()

# 4. Remove the extracted directory
shutil.rmtree(extract_dir)
Expand Down Expand Up @@ -1236,6 +1235,8 @@ def parse_env(env_string, stdio=None):
@staticmethod
def parse_env_display(env_list):
env_dict = {}
if not env_list:
return {}
for env_string in env_list:
# 分割键和值
key_value = env_string.split('=', 1)
Expand Down Expand Up @@ -1580,6 +1581,7 @@ def print_title(name, stdio):
def gen_password(length=8, chars=string.ascii_letters + string.digits, stdio=None):
return ''.join([choice(chars) for i in range(length)])

@staticmethod
def retry(retry_count=3, retry_interval=2, stdio=None):
def real_decorator(decor_method):
def wrapper(*args, **kwargs):
Expand Down
32 changes: 32 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,38 @@ def create_ob_proxy_node(node_config, global_config):
'servers': ob_proxy_nodes,
}

@property
def get_oms_config(self):
oms = self.config_data.get('oms', {})
nodes = oms.get('servers', {}).get('nodes', [])

def create_oms_node(node_config, global_config):
return {
'ip': node_config.get('ip'),
'ssh_username': node_config.get('ssh_username', global_config.get('ssh_username', '')),
'ssh_password': node_config.get('ssh_password', global_config.get('ssh_password', '')),
'ssh_port': node_config.get('ssh_port', global_config.get('ssh_port', 22)),
'home_path': node_config.get('home_path', global_config.get('home_path', '/root/obproxy')),
'log_path': node_config.get('log_path', global_config.get('log_path', '/home/admin/logs')),
'run_path': node_config.get('run_path', global_config.get('run_path', '/home/admin/run')),
'store_path': node_config.get('store_path', global_config.get('store_path', '/home/admin/store')),
'ssh_key_file': node_config.get('ssh_key_file', global_config.get('ssh_key_file', '')),
'ssh_type': node_config.get('ssh_type', global_config.get('ssh_type', 'remote')),
'container_name': node_config.get('container_name', global_config.get('container_name')),
'namespace': node_config.get('namespace', global_config.get('namespace', '')),
'pod_name': node_config.get('pod_name', global_config.get('pod_name', '')),
"kubernetes_config_file": node_config.get('kubernetes_config_file', global_config.get('kubernetes_config_file', '')),
'host_type': 'OMS',
}

global_config = oms.get('servers', {}).get('global', {})
oms_nodes = [create_oms_node(node, global_config) for node in nodes]

return {
'oms_cluster_name': oms.get('oms_cluster_name'),
'servers': oms_nodes,
}

@property
def get_node_config(self, type, node_ip, config_item):
if type == 'ob_cluster':
Expand Down
3 changes: 2 additions & 1 deletion context.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,13 @@ def return_false(self, *args, **kwargs):

class HandlerContext(object):

def __init__(self, handler_name=None, namespace=None, namespaces=None, cluster_config=None, obproxy_config=None, ocp_config=None, inner_config=None, cmd=None, options=None, stdio=None):
def __init__(self, handler_name=None, namespace=None, namespaces=None, cluster_config=None, obproxy_config=None, oms_config=None, ocp_config=None, inner_config=None, cmd=None, options=None, stdio=None):
self.namespace = HandlerContextNamespace(namespace)
self.namespaces = namespaces
self.handler_name = handler_name
self.cluster_config = cluster_config
self.obproxy_config = obproxy_config
self.oms_config = oms_config
self.ocp_config = ocp_config
self.inner_config = inner_config
self.cmds = cmd
Expand Down
Loading

0 comments on commit e052cbc

Please sign in to comment.