Skip to content

Commit

Permalink
add CODE_TAR for trainer worker (#698)
Browse files Browse the repository at this point in the history
* add code_tar to .sh

* complete test

* fix

Co-authored-by: xiangyuxuan.prs <[email protected]>
  • Loading branch information
Ssskrilex and xiangyuxuan.prs authored Mar 19, 2021
1 parent 8826853 commit f0b0047
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 4 deletions.
4 changes: 3 additions & 1 deletion deploy/scripts/env_to_args.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ pull_code() {
wget $1 -O code.tar.gz
elif [[ $1 == "oss://"* ]]; then
python -c "import tensorflow as tf; import tensorflow_io; open('code.tar.gz', 'wb').write(tf.io.gfile.GFile('$1', 'rb').read())"
elif [[ $1 == "base64://"* ]]; then
python -c "import base64; f = open('code.tar.gz', 'wb'); f.write(base64.b64decode('$1'[9:])); f.close()"
else
cp $1 code.tar.gz
fi
fi
tar -zxvf code.tar.gz
cd $cwd
}
7 changes: 6 additions & 1 deletion deploy/scripts/trainer/run_trainer_worker.sh
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,12 @@ for i in "${WORKER_GROUPS[@]}"; do
done
fi

pull_code ${CODE_KEY} $PWD
if [[ -n "${CODE_KEY}" ]]; then
pull_code ${CODE_KEY} $PWD
else
pull_code ${CODE_TAR} $PWD
fi

cd ${ROLE}

mode=$(normalize_env_to_args "--mode" "$MODE")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

# coding: utf-8
import unittest

from fedlearner_webconsole.job.yaml_formatter import format_yaml
import tarfile
import base64
from io import BytesIO
from fedlearner_webconsole.job.yaml_formatter import format_yaml, code_dict_encode


class YamlFormatterTest(unittest.TestCase):
Expand Down Expand Up @@ -62,6 +64,22 @@ def test_format_yaml_unknown_ph(self):
format_yaml('$x.y is ${i.j}', x=x)
self.assertEqual(str(cm.exception), 'Unknown placeholder: i.j')

def test_encode_code(self):
test_data = {'test/a.py': 'awefawefawefawefwaef',
'test1/b.py': 'asdfasd',
'c.py': '',
'test/d.py': 'asdf'}
code_base64 = code_dict_encode(test_data)
code_dict = {}
if code_base64.startswith('base64://'):
tar_binary = BytesIO(base64.b64decode(code_base64[9:]))
with tarfile.open(fileobj=tar_binary) as tar:
for file in tar.getmembers():
code_dict[file.name] = str(tar.extractfile(file).read(),
encoding='utf-8')
self.assertEqual(code_dict, test_data)



if __name__ == '__main__':
unittest.main()

0 comments on commit f0b0047

Please sign in to comment.