diff --git a/common/tool.py b/common/tool.py index 797d19d0..283f5679 100644 --- a/common/tool.py +++ b/common/tool.py @@ -1502,6 +1502,28 @@ def get_nodes_list(context, nodes, stdio=None): return new_nodes return None + @staticmethod + def check_none_values(config, stdio): + """ + Check if any values in the given configuration dictionary are None. + If any value is None, print the specific information and return False. + If all values are not None, return True. + + :param config: Dictionary containing configuration items + :return: True if no None values are found, otherwise False + """ + # First, check the top-level key-value pairs + for key, value in config.items(): + if value is None: + stdio.error("The value of '{0}' is None.".format(key)) + return False + + # If the value is a dictionary, recursively check the sub-dictionary + if isinstance(value, dict): + if not Util.check_none_values(value, stdio): + return False + return True + class SQLUtil(object): re_trace = re.compile(r'''\/\*.*trace_id((?!\/\*).)*rpc_id.*\*\/''', re.VERBOSE) diff --git a/config.py b/config.py index 6cbc1a92..60a650f1 100644 --- a/config.py +++ b/config.py @@ -163,6 +163,14 @@ def __init__(self, config_file=None, stdio=None, config_env_list=[]): parser = ConfigOptionsParserUtil() self.config_data = parser.parse_config(config_env_list) + def update_config_data(self, new_config_data, save_to_file=False): + if not isinstance(new_config_data, dict): + raise ValueError("new_config_data must be a dictionary") + self.config_data.update(new_config_data) + if save_to_file: + with open(self.config_file, 'w') as f: + yaml.dump(self.config_data, f, default_flow_style=False) + def _safe_get(self, dictionary, *keys, default=None): """Safe way to retrieve nested values from dictionaries""" current = dictionary diff --git a/core.py b/core.py index f96529da..46c22c11 100644 --- a/core.py +++ b/core.py @@ -58,7 +58,10 @@ from colorama import Fore, Style from common.config_helper import ConfigHelper -from common.tool import TimeUtils +from common.tool import TimeUtils, Util +from common.command import get_observer_version_by_sql +from common.ob_connector import OBConnector +from collections import OrderedDict class ObdiagHome(object): @@ -122,6 +125,7 @@ def _print(msg, *arg, **kwarg): self._stdio_func[func] = getattr(self.stdio, func, _print) def set_context(self, handler_name, namespace, config): + self.update_obcluster_nodes(config) self.context = HandlerContext( handler_name=handler_name, namespace=namespace, @@ -151,6 +155,36 @@ def set_context_skip_cluster_conn(self, handler_name, namespace, config): def set_offline_context(self, handler_name, namespace): self.context = HandlerContext(handler_name=handler_name, namespace=namespace, cmd=self.cmds, options=self.options, stdio=self.stdio, inner_config=self.inner_config_manager.config) + def update_obcluster_nodes(self, config): + config_data = config.config_data + cluster_config = config.config_data["obcluster"] + ob_cluster = {"db_host": cluster_config["db_host"], "db_port": cluster_config["db_port"], "tenant_sys": {"user": cluster_config["tenant_sys"]["user"], "password": cluster_config["tenant_sys"]["password"]}} + if Util.check_none_values(ob_cluster, self.stdio): + ob_version = get_observer_version_by_sql(ob_cluster, self.stdio) + obConnetcor = OBConnector(ip=ob_cluster["db_host"], port=ob_cluster["db_port"], username=ob_cluster["tenant_sys"]["user"], password=ob_cluster["tenant_sys"]["password"], stdio=self.stdio, timeout=100) + sql = "select SVR_IP, SVR_PORT, ZONE, BUILD_VERSION from oceanbase.DBA_OB_SERVERS" + if ob_version.startswith("3") or ob_version.startswith("2") or ob_version.startswith("1"): + sql = "select SVR_IP, SVR_PORT, ZONE, BUILD_VERSION from oceanbase.__all_server" + res = obConnetcor.execute_sql(sql) + if len(res) == 0: + raise Exception("Failed to get the node from sql [{0}], " "please check whether the --config option correct!!!".format(sql)) + host_info_list = [] + for row in res: + host_info = OrderedDict() + host_info["ip"] = row[0] + self.stdio.verbose("get host info: %s", host_info) + host_info_list.append(host_info) + config_data_new = copy(config_data) + if 'servers' in config_data_new['obcluster']: + if not isinstance(config_data_new['obcluster']['servers'], dict): + config_data_new['obcluster']['servers'] = {} + if 'nodes' not in config_data_new['obcluster']['servers'] or not isinstance(config_data_new['obcluster']['servers']['nodes'], list): + config_data_new['obcluster']['servers']['nodes'] = [] + for item in host_info_list: + ip = item['ip'] + config_data_new['obcluster']['servers']['nodes'].append({'ip': ip}) + config.update_config_data(config_data_new) + def get_namespace(self, spacename): if spacename in self.namespaces: namespace = self.namespaces[spacename]