[Bug] A bug in the initialize_ub function. #1170
Comments
@denera Could you take a look at it? |
@wangzihe1996 It's easy to get a list of network interfaces from That's why we set up the That said, it would still be better if TE checks |
@denera I'm glad you made these changes, which give more information to developers like me. I also have another question: which interface should I choose if I have many RDMA network cards? We can assume that they are named |
@wangzihe1996 The correct network interface is the one that returns the same hostname on processes/ranks that map to GPUs on the same physical node. We can't make any general assumptions or guesses about this in TE, but your specific case looks like each node connecting to different groups of nodes via different RDMA network interfaces. If that's true, then any one of the |
@denera Thank you for your reply. I now understand this in more detail. |
I noticed that TransformerEngine now supports tensor parallelism (TP) communication overlap without depending on MPI, so I tried to use TP communication overlap with torchrun. In the process, I found a bug in the initialize_ub function.
I ran my code and got the following error.
The relevant code is as follows.
TransformerEngine/transformer_engine/pytorch/module/base.py
Lines 130 to 149 in bdea56f
If ifname is not None, the function tries to get the IP address of that interface and use it as the hostname. The ifname is obtained from the environment variables NVTE_UB_SOCKET_IFNAME, NCCL_SOCKET_IFNAME, and GLOO_SOCKET_IFNAME.
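The lookup described above can be illustrated with a minimal sketch. It is not the code from base.py: the helper name lookup_ifname is hypothetical, and the precedence (NVTE_UB_SOCKET_IFNAME first, then NCCL_SOCKET_IFNAME, then GLOO_SOCKET_IFNAME) is an assumption based on the order in which the variables are listed.

```python
import os

# Assumed precedence: the order in which the issue lists the variables.
IFNAME_ENV_VARS = (
    "NVTE_UB_SOCKET_IFNAME",
    "NCCL_SOCKET_IFNAME",
    "GLOO_SOCKET_IFNAME",
)


def lookup_ifname(env=None):
    """Return the first non-empty interface setting, or None if none is set.

    `env` defaults to os.environ; passing a plain dict makes this testable.
    """
    env = os.environ if env is None else env
    for var in IFNAME_ENV_VARS:
        value = env.get(var)
        if value:
            return value
    return None
```

Note that the returned value is whatever string the user exported, so it may be a prefix such as eth rather than the name of an actual interface.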
In my environment, NCCL_SOCKET_IFNAME is set to eth. The machine actually has many network cards named eth0, eth1, eth2, and so on, but there is no network card called eth. When I use ifname = eth0, the code above runs successfully.

I checked the NCCL documentation on environment variables. It states that when NCCL_SOCKET_IFNAME is set to eth, NCCL uses all interfaces whose names start with eth, e.g. eth0, eth1.

So I think the code needs to use another method, such as psutil or netifaces, to get the names of the network cards when NCCL_SOCKET_IFNAME is eth.

This is my test code:
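Separately from the test code referenced above, the prefix resolution suggested in this issue could look roughly like the sketch below. It is only an illustration, not TransformerEngine's fix. It swaps psutil and netifaces for the standard library's socket.if_nameindex() to avoid an extra dependency, the function names match_interfaces and list_interfaces are hypothetical, and only the prefix form and the "=" exact-match form of NCCL_SOCKET_IFNAME are handled.

```python
import socket


def match_interfaces(spec, available):
    """Select interface names from `available` using an NCCL-style spec.

    "eth"        -> every name that starts with "eth"
    "=eth0,eth1" -> only the exact names eth0 and eth1
    The "^" exclusion form is deliberately not handled in this sketch.
    """
    if spec.startswith("^"):
        raise NotImplementedError("exclusion specs are not handled here")
    exact = spec.startswith("=")
    body = spec[1:] if exact else spec
    patterns = [p for p in body.split(",") if p]
    if exact:
        return [name for name in available if name in patterns]
    return [
        name for name in available
        if any(name.startswith(p) for p in patterns)
    ]


def list_interfaces():
    """Return the interface names known to the OS via the standard library."""
    # socket.if_nameindex() yields (index, name) pairs.
    return [name for _, name in socket.if_nameindex()]
```

On a machine like the one described, match_interfaces("eth", list_interfaces()) would return the eth-prefixed interfaces instead of failing on the nonexistent name eth. The caller still has to decide which of the matches to use.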