diff --git a/go.mod b/go.mod index e292f1728..ec315f570 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,6 @@ module github.com/skycoin/skywire go 1.12 require ( - github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 // indirect - github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4 // indirect github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 github.com/creack/pty v1.1.7 github.com/go-chi/chi v4.0.2+incompatible @@ -15,18 +13,19 @@ require ( github.com/mitchellh/go-homedir v1.1.0 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.1 // indirect + github.com/pkg/errors v0.8.1 github.com/pkg/profile v1.3.0 github.com/prometheus/client_golang v1.1.0 - github.com/prometheus/common v0.6.0 + github.com/prometheus/common v0.7.0 github.com/sirupsen/logrus v1.4.2 - github.com/skycoin/dmsg v0.0.0-20190904181013-b781e3cbebc6 + github.com/skycoin/dmsg v0.0.0-20190805065636-70f4c32a994f github.com/skycoin/skycoin v0.26.0 github.com/spf13/cobra v0.0.5 github.com/stretchr/testify v1.4.0 + github.com/vektra/mockery v0.0.0-20181123154057-e78b021dcbb5 // indirect go.etcd.io/bbolt v1.3.3 - golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472 - golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297 + golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7 + golang.org/x/net v0.0.0-20190916140828-c8589233b77d ) -// Uncomment for tests with alternate branches of 'dmsg' -//replace github.com/skycoin/dmsg => ../dmsg +replace github.com/skycoin/dmsg => ../dmsg diff --git a/go.sum b/go.sum index 60f365b04..ba051af7c 100644 --- a/go.sum +++ b/go.sum @@ -27,6 +27,7 @@ github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMo github.com/go-chi/chi v4.0.2+incompatible h1:maB6vn6FqCxrpz4FqWdh4+lwpyZIQS7YEAUcHlgXVRs= github.com/go-chi/chi v4.0.2+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= @@ -57,10 +58,13 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxv github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= +github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE= github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= @@ -77,6 +81,8 @@ github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3Rllmb github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/profile v1.3.0 h1:OQIvuDgm00gWVWGTf4m4mCt6W1/0YqU7Ntg0mySWgaI= github.com/pkg/profile v1.3.0/go.mod h1:hJw3o1OdXxsrSjjVksARp5W95eeEaEfptyVZyv6JUPA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -89,22 +95,27 @@ github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1: github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90 h1:S/YWwWx/RA8rT8tKFRuGUZhuA90OyIBpPCXkcbwU8DE= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= -github.com/prometheus/common v0.6.0 h1:kRhiuYSXR3+uv2IbVbZhUxK5zVD/2pp3Gd2PpvPkpEo= github.com/prometheus/common v0.6.0/go.mod h1:eBmuwkDJBwy6iBfxCBob6t6dR6ENT/y+J+Zk0j9GMYc= +github.com/prometheus/common v0.7.0 h1:L+1lyG48J1zAQXA3RBX/nG/B3gjlHq0zTt2tlbJLyCY= +github.com/prometheus/common v0.7.0/go.mod h1:DjGbpBbp5NYNiECxcL/VnbXCCaQpKd3tt26CguLLsqA= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.0.3 h1:CTwfnzjQ+8dS6MhHHu4YswVAD99sL2wjPqP+VkURmKE= github.com/prometheus/procfs v0.0.3/go.mod h1:4A/X28fw3Fc593LaREMrKMqOKvUAntwMDaekg4FpcdQ= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.4.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= -github.com/skycoin/dmsg v0.0.0-20190904181013-b781e3cbebc6 h1:YwSyQXUyG/EFp3xCGMkOldgQNpw8XLfmocQND4/Y3aw= -github.com/skycoin/dmsg v0.0.0-20190904181013-b781e3cbebc6/go.mod h1:obZYZp8eKR7Xqz+KNhJdUE6Gvp6rEXbDO8YTlW2YXgU= +github.com/skycoin/dmsg v0.0.0-20190805065636-70f4c32a994f h1:WWjaxOXoj6oYelm67MNtJbg51HQALjKAyhs2WAHgpZs= +github.com/skycoin/dmsg v0.0.0-20190805065636-70f4c32a994f/go.mod h1:obZYZp8eKR7Xqz+KNhJdUE6Gvp6rEXbDO8YTlW2YXgU= +github.com/skycoin/skycoin v0.25.1/go.mod h1:78nHjQzd8KG0jJJVL/j0xMmrihXi70ti63fh8vXScJw= github.com/skycoin/skycoin v0.26.0 h1:xDxe2r8AclMntZ550Y/vUQgwgLtwrf9Wu5UYiYcN5/o= github.com/skycoin/skycoin v0.26.0/go.mod h1:78nHjQzd8KG0jJJVL/j0xMmrihXi70ti63fh8vXScJw= +github.com/skycoin/skywire v0.1.1/go.mod h1:jDuUgTG20jhiBI6Trpayj0my6xhdS+ejEO9gTSM+C/E= github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= +github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= github.com/spf13/cobra v0.0.5 h1:f0B+LkLX6DtmRH1isoNA9VTtNUK9K8xYd28JNNfOv/s= github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= @@ -112,28 +123,34 @@ github.com/spf13/pflag v1.0.3 h1:zPAT6CGy6wXeQ7NtTnaTerfKOsV6V6F8agHXFiazDkg= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= +github.com/vektra/mockery v0.0.0-20181123154057-e78b021dcbb5 h1:Xim2mBRFdXzXmKRO8DJg/FJtn/8Fj9NOEpO6+WuMPmk= +github.com/vektra/mockery v0.0.0-20181123154057-e78b021dcbb5/go.mod h1:ppEjwdhyy7Y31EnHRDm1JkChoC7LXIJ7Ex0VYLWtZtQ= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= +go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/bbolt v1.3.3 h1:MUGmc65QhB3pIlaQ5bB4LwqSj6GIonVJXpZiaKNyaKk= go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190313024323-a1f597ede03a/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190621222207-cc06ce4a13d4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472 h1:Gv7RPwsi3eZ2Fgewe3CBsuOebPwO27PoXzRpJPsvSSM= -golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7 h1:0hQKqeLdqlt5iIwVOBErRisrHJAN57yOiPRQItI20fU= +golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190313220215-9f648a60d977/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297 h1:k7pJ2yAPLPgbskkFdhRCsA77k2fySZ1zf2zCjvQCiIM= -golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190916140828-c8589233b77d h1:mCMDWKhNO37A7GAhOpHPbIw1cjd0V86kX1/WA9c7FZ8= +golang.org/x/net v0.0.0-20190916140828-c8589233b77d/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -150,6 +167,7 @@ golang.org/x/sys v0.0.0-20190801041406-cbf593c0f2f3/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181112210238-4b1f3b6b1646/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190627182818-9947fec5c3ab/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= gopkg.in/alecthomas/kingpin.v2 v2.2.6 h1:jMFz6MfLP0/4fUyZle81rXUoxOBFi19VUFKVDOQfozc= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= diff --git a/pkg/app2/client.go b/pkg/app2/client.go new file mode 100644 index 000000000..3d8a5f595 --- /dev/null +++ b/pkg/app2/client.go @@ -0,0 +1,144 @@ +package app2 + +import ( + "net" + + "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/skycoin/src/util/logging" + + "github.com/skycoin/skywire/pkg/app2/network" + "github.com/skycoin/skywire/pkg/routing" +) + +// Client is used by skywire apps. +type Client struct { + log *logging.Logger + pk cipher.PubKey + pid ProcID + rpc RPCClient + lm *idManager // contains listeners associated with their IDs + cm *idManager // contains connections associated with their IDs +} + +// NewClient creates a new `Client`. The `Client` needs to be provided with: +// - log: logger instance +// - localPK: The local public key of the parent skywire visor. +// - pid: The procID assigned for the process that Client is being used by. +// - rpc: RPC client to communicate with the server. +func NewClient(log *logging.Logger, localPK cipher.PubKey, pid ProcID, rpc RPCClient) *Client { + return &Client{ + log: log, + pk: localPK, + pid: pid, + rpc: rpc, + lm: newIDManager(), + cm: newIDManager(), + } +} + +// Dial dials the remote node using `remote`. +func (c *Client) Dial(remote network.Addr) (net.Conn, error) { + connID, localPort, err := c.rpc.Dial(remote) + if err != nil { + return nil, err + } + + conn := &Conn{ + id: connID, + rpc: c.rpc, + local: network.Addr{ + Net: remote.Net, + PubKey: c.pk, + Port: localPort, + }, + remote: remote, + } + + free, err := c.cm.add(connID, conn) + if err != nil { + if err := conn.Close(); err != nil { + c.log.WithError(err).Error("error closing conn") + } + + return nil, err + } + + conn.freeConn = free + + return conn, nil +} + +// Listen listens on the specified `port` for the incoming connections. +func (c *Client) Listen(n network.Type, port routing.Port) (net.Listener, error) { + local := network.Addr{ + Net: n, + PubKey: c.pk, + Port: port, + } + + lisID, err := c.rpc.Listen(local) + if err != nil { + return nil, err + } + + listener := &Listener{ + log: c.log, + id: lisID, + rpc: c.rpc, + addr: local, + cm: newIDManager(), + } + + freeLis, err := c.lm.add(lisID, listener) + if err != nil { + if err := listener.Close(); err != nil { + c.log.WithError(err).Error("error closing listener") + } + + return nil, err + } + + listener.freeLis = freeLis + + return listener, nil +} + +// Close closes client/server communication entirely. It closes all open +// listeners and connections. +func (c *Client) Close() { + var listeners []net.Listener + c.lm.doRange(func(_ uint16, v interface{}) bool { + lis, err := assertListener(v) + if err != nil { + c.log.Error(err) + return true + } + + listeners = append(listeners, lis) + return true + }) + + var conns []net.Conn + c.cm.doRange(func(_ uint16, v interface{}) bool { + conn, err := assertConn(v) + if err != nil { + c.log.Error(err) + return true + } + + conns = append(conns, conn) + return true + }) + + for _, lis := range listeners { + if err := lis.Close(); err != nil { + c.log.WithError(err).Error("error closing listener") + } + } + + for _, conn := range conns { + if err := conn.Close(); err != nil { + c.log.WithError(err).Error("error closing conn") + } + } +} diff --git a/pkg/app2/client_test.go b/pkg/app2/client_test.go new file mode 100644 index 000000000..56e571cbf --- /dev/null +++ b/pkg/app2/client_test.go @@ -0,0 +1,280 @@ +package app2 + +import ( + "testing" + + "github.com/pkg/errors" + "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/skycoin/src/util/logging" + "github.com/stretchr/testify/require" + + "github.com/skycoin/skywire/pkg/app2/network" + "github.com/skycoin/skywire/pkg/routing" +) + +func TestClient_Dial(t *testing.T) { + l := logging.MustGetLogger("app2_client") + localPK, _ := cipher.GenerateKeyPair() + pid := ProcID(1) + + remotePK, _ := cipher.GenerateKeyPair() + remotePort := routing.Port(120) + remote := network.Addr{ + Net: network.TypeDMSG, + PubKey: remotePK, + Port: remotePort, + } + + t.Run("ok", func(t *testing.T) { + dialConnID := uint16(1) + dialLocalPort := routing.Port(1) + var dialErr error + + rpc := &MockRPCClient{} + rpc.On("Dial", remote).Return(dialConnID, dialLocalPort, dialErr) + + cl := NewClient(l, localPK, pid, rpc) + + wantConn := &Conn{ + id: dialConnID, + rpc: rpc, + local: network.Addr{ + Net: remote.Net, + PubKey: localPK, + Port: dialLocalPort, + }, + remote: remote, + } + + conn, err := cl.Dial(remote) + require.NoError(t, err) + + appConn, ok := conn.(*Conn) + require.True(t, ok) + + require.Equal(t, wantConn.id, appConn.id) + require.Equal(t, wantConn.rpc, appConn.rpc) + require.Equal(t, wantConn.local, appConn.local) + require.Equal(t, wantConn.remote, appConn.remote) + require.NotNil(t, appConn.freeConn) + + cmConnIfc, ok := cl.cm.values[appConn.id] + require.True(t, ok) + require.NotNil(t, cmConnIfc) + + cmConn, ok := cmConnIfc.(*Conn) + require.True(t, ok) + require.NotNil(t, cmConn.freeConn) + }) + + t.Run("conn already exists", func(t *testing.T) { + dialConnID := uint16(1) + dialLocalPort := routing.Port(1) + var dialErr error + + var closeErr error + + rpc := &MockRPCClient{} + rpc.On("Dial", remote).Return(dialConnID, dialLocalPort, dialErr) + rpc.On("CloseConn", dialConnID).Return(closeErr) + + cl := NewClient(l, localPK, pid, rpc) + + _, err := cl.cm.add(dialConnID, nil) + require.NoError(t, err) + + conn, err := cl.Dial(remote) + require.Equal(t, err, errValueAlreadyExists) + require.Nil(t, conn) + }) + + t.Run("conn already exists, conn closed with error", func(t *testing.T) { + dialConnID := uint16(1) + dialLocalPort := routing.Port(1) + var dialErr error + + closeErr := errors.New("close error") + + rpc := &MockRPCClient{} + rpc.On("Dial", remote).Return(dialConnID, dialLocalPort, dialErr) + rpc.On("CloseConn", dialConnID).Return(closeErr) + + cl := NewClient(l, localPK, pid, rpc) + + _, err := cl.cm.add(dialConnID, nil) + require.NoError(t, err) + + conn, err := cl.Dial(remote) + require.Equal(t, err, errValueAlreadyExists) + require.Nil(t, conn) + }) + + t.Run("dial error", func(t *testing.T) { + dialErr := errors.New("dial error") + + rpc := &MockRPCClient{} + rpc.On("Dial", remote).Return(uint16(0), routing.Port(0), dialErr) + + cl := NewClient(l, localPK, pid, rpc) + + conn, err := cl.Dial(remote) + require.Equal(t, dialErr, err) + require.Nil(t, conn) + }) +} + +func TestClient_Listen(t *testing.T) { + l := logging.MustGetLogger("app2_client") + localPK, _ := cipher.GenerateKeyPair() + pid := ProcID(1) + + port := routing.Port(1) + local := network.Addr{ + Net: network.TypeDMSG, + PubKey: localPK, + Port: port, + } + + t.Run("ok", func(t *testing.T) { + listenLisID := uint16(1) + var listenErr error + + rpc := &MockRPCClient{} + rpc.On("Listen", local).Return(listenLisID, listenErr) + + cl := NewClient(l, localPK, pid, rpc) + + wantListener := &Listener{ + id: listenLisID, + rpc: rpc, + addr: local, + } + + listener, err := cl.Listen(network.TypeDMSG, port) + require.Nil(t, err) + + appListener, ok := listener.(*Listener) + require.True(t, ok) + + require.Equal(t, wantListener.id, appListener.id) + require.Equal(t, wantListener.rpc, appListener.rpc) + require.Equal(t, wantListener.addr, appListener.addr) + require.NotNil(t, appListener.freeLis) + }) + + t.Run("listener already exists", func(t *testing.T) { + listenLisID := uint16(1) + var listenErr error + + var closeErr error + + rpc := &MockRPCClient{} + rpc.On("Listen", local).Return(listenLisID, listenErr) + rpc.On("CloseListener", listenLisID).Return(closeErr) + + cl := NewClient(l, localPK, pid, rpc) + + _, err := cl.lm.add(listenLisID, nil) + require.NoError(t, err) + + listener, err := cl.Listen(network.TypeDMSG, port) + require.Equal(t, err, errValueAlreadyExists) + require.Nil(t, listener) + }) + + t.Run("listener already exists, listener closed with error", func(t *testing.T) { + listenLisID := uint16(1) + var listenErr error + + closeErr := errors.New("close error") + + rpc := &MockRPCClient{} + rpc.On("Listen", local).Return(listenLisID, listenErr) + rpc.On("CloseListener", listenLisID).Return(closeErr) + + cl := NewClient(l, localPK, pid, rpc) + + _, err := cl.lm.add(listenLisID, nil) + require.NoError(t, err) + + listener, err := cl.Listen(network.TypeDMSG, port) + require.Equal(t, err, errValueAlreadyExists) + require.Nil(t, listener) + }) + + t.Run("listen error", func(t *testing.T) { + listenErr := errors.New("listen error") + + rpc := &MockRPCClient{} + rpc.On("Listen", local).Return(uint16(0), listenErr) + + cl := NewClient(l, localPK, pid, rpc) + + listener, err := cl.Listen(network.TypeDMSG, port) + require.Equal(t, listenErr, err) + require.Nil(t, listener) + }) +} + +func TestClient_Close(t *testing.T) { + l := logging.MustGetLogger("app2_client") + localPK, _ := cipher.GenerateKeyPair() + pid := ProcID(1) + + var closeNoErr error + closeErr := errors.New("close error") + + rpc := &MockRPCClient{} + + lisID1 := uint16(1) + lisID2 := uint16(2) + + rpc.On("CloseListener", lisID1).Return(closeNoErr) + rpc.On("CloseListener", lisID2).Return(closeErr) + + lm := newIDManager() + + lis1 := &Listener{id: lisID1, rpc: rpc, cm: newIDManager()} + freeLis1, err := lm.add(lisID1, lis1) + require.NoError(t, err) + lis1.freeLis = freeLis1 + + lis2 := &Listener{id: lisID2, rpc: rpc, cm: newIDManager()} + freeLis2, err := lm.add(lisID2, lis2) + require.NoError(t, err) + lis2.freeLis = freeLis2 + + connID1 := uint16(1) + connID2 := uint16(2) + + rpc.On("CloseConn", connID1).Return(closeNoErr) + rpc.On("CloseConn", connID2).Return(closeErr) + + cm := newIDManager() + + conn1 := &Conn{id: connID1, rpc: rpc} + freeConn1, err := cm.add(connID1, conn1) + require.NoError(t, err) + conn1.freeConn = freeConn1 + + conn2 := &Conn{id: connID2, rpc: rpc} + freeConn2, err := cm.add(connID2, conn2) + require.NoError(t, err) + conn2.freeConn = freeConn2 + + cl := NewClient(l, localPK, pid, rpc) + cl.cm = cm + cl.lm = lm + + cl.Close() + + _, ok := lm.values[lisID1] + require.False(t, ok) + _, ok = lm.values[lisID2] + require.False(t, ok) + + _, ok = cm.values[connID1] + require.False(t, ok) + _, ok = cm.values[connID2] + require.False(t, ok) +} diff --git a/pkg/app2/conn.go b/pkg/app2/conn.go new file mode 100644 index 000000000..e6473a9ea --- /dev/null +++ b/pkg/app2/conn.go @@ -0,0 +1,61 @@ +package app2 + +import ( + "net" + "time" + + "github.com/skycoin/skywire/pkg/app2/network" +) + +// Conn is a connection from app client to the server. +// Implements `net.Conn`. +type Conn struct { + id uint16 + rpc RPCClient + local network.Addr + remote network.Addr + freeConn func() +} + +func (c *Conn) Read(b []byte) (int, error) { + n, err := c.rpc.Read(c.id, b) + if err != nil { + return 0, err + } + + return n, err +} + +func (c *Conn) Write(b []byte) (int, error) { + return c.rpc.Write(c.id, b) +} + +func (c *Conn) Close() error { + defer func() { + if c.freeConn != nil { + c.freeConn() + } + }() + + return c.rpc.CloseConn(c.id) +} + +func (c *Conn) LocalAddr() net.Addr { + return c.local +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.remote +} + +func (c *Conn) SetDeadline(t time.Time) error { + return errMethodNotImplemented +} + +func (c *Conn) SetReadDeadline(t time.Time) error { + return errMethodNotImplemented +} + +func (c *Conn) SetWriteDeadline(t time.Time) error { + return errMethodNotImplemented +} diff --git a/pkg/app2/conn_test.go b/pkg/app2/conn_test.go new file mode 100644 index 000000000..2b185065c --- /dev/null +++ b/pkg/app2/conn_test.go @@ -0,0 +1,116 @@ +package app2 + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestConn_Read(t *testing.T) { + connID := uint16(1) + + tt := []struct { + name string + readBuff []byte + readN int + readErr error + }{ + { + name: "ok", + readBuff: make([]byte, 10), + readN: 2, + }, + { + name: "read error", + readBuff: make([]byte, 10), + readErr: errors.New("read error"), + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + rpc := &MockRPCClient{} + rpc.On("Read", connID, tc.readBuff).Return(tc.readN, tc.readErr) + + conn := &Conn{ + id: connID, + rpc: rpc, + } + + n, err := conn.Read(tc.readBuff) + require.Equal(t, tc.readErr, err) + require.Equal(t, tc.readN, n) + }) + } +} + +func TestConn_Write(t *testing.T) { + connID := uint16(1) + + tt := []struct { + name string + writeBuff []byte + writeN int + writeErr error + }{ + { + name: "ok", + writeBuff: make([]byte, 10), + writeN: 2, + }, + { + name: "write error", + writeBuff: make([]byte, 10), + writeErr: errors.New("write error"), + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + rpc := &MockRPCClient{} + rpc.On("Write", connID, tc.writeBuff).Return(tc.writeN, tc.writeErr) + + conn := &Conn{ + id: connID, + rpc: rpc, + } + + n, err := conn.Write(tc.writeBuff) + require.Equal(t, tc.writeErr, err) + require.Equal(t, tc.writeN, n) + }) + } +} + +func TestConn_Close(t *testing.T) { + connID := uint16(1) + + tt := []struct { + name string + closeErr error + }{ + { + name: "ok", + }, + { + name: "close error", + closeErr: errors.New("close error"), + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + rpc := &MockRPCClient{} + rpc.On("CloseConn", connID).Return(tc.closeErr) + + conn := &Conn{ + id: connID, + rpc: rpc, + } + + err := conn.Close() + require.Equal(t, tc.closeErr, err) + }) + } +} diff --git a/pkg/app2/doc.go b/pkg/app2/doc.go new file mode 100644 index 000000000..ff4dfd1b3 --- /dev/null +++ b/pkg/app2/doc.go @@ -0,0 +1,4 @@ +// Package app2 provides facilities to establish communication +// between a visor node and a skywire application. Intended to +// replace the original `app` module. +package app2 diff --git a/pkg/app2/errors.go b/pkg/app2/errors.go new file mode 100644 index 000000000..88653a613 --- /dev/null +++ b/pkg/app2/errors.go @@ -0,0 +1,14 @@ +package app2 + +import "github.com/pkg/errors" + +var ( + // ErrPortAlreadyBound is being returned when trying to bind to the port + // which is already bound to. + ErrPortAlreadyBound = errors.New("port is already bound") +) + +var ( + // errMethodNotImplemented serves as a return value for non-implemented funcs (stubs). + errMethodNotImplemented = errors.New("method not implemented") +) diff --git a/pkg/app2/id_manager.go b/pkg/app2/id_manager.go new file mode 100644 index 000000000..6f087f1b3 --- /dev/null +++ b/pkg/app2/id_manager.go @@ -0,0 +1,142 @@ +package app2 + +import ( + "fmt" + "sync" + + "github.com/pkg/errors" +) + +var ( + errNoMoreAvailableValues = errors.New("no more available values") + errValueAlreadyExists = errors.New("value already exists") +) + +// idManager manages allows to store and retrieve arbitrary values +// associated with the `uint16` key in a thread-safe manner. +// Provides method to generate key. +type idManager struct { + values map[uint16]interface{} + mx sync.RWMutex + lstID uint16 +} + +// newIDManager constructs new `idManager`. +func newIDManager() *idManager { + return &idManager{ + values: make(map[uint16]interface{}), + } +} + +// `reserveNextID` reserves next free slot for the value and returns the id for it. +func (m *idManager) reserveNextID() (id *uint16, free func(), err error) { + m.mx.Lock() + + nxtID := m.lstID + 1 + for ; nxtID != m.lstID; nxtID++ { + if _, ok := m.values[nxtID]; !ok { + break + } + } + + if nxtID == m.lstID { + m.mx.Unlock() + return nil, nil, errNoMoreAvailableValues + } + + m.values[nxtID] = nil + m.lstID = nxtID + + m.mx.Unlock() + return &nxtID, m.constructFreeFunc(nxtID), nil +} + +// pop removes value specified by `id` from the idManager instance and +// returns it. +func (m *idManager) pop(id uint16) (interface{}, error) { + m.mx.Lock() + v, ok := m.values[id] + if !ok { + m.mx.Unlock() + return nil, fmt.Errorf("no value with id %d", id) + } + + if v == nil { + m.mx.Unlock() + return nil, fmt.Errorf("value with id %d is not set", id) + } + + delete(m.values, id) + + m.mx.Unlock() + return v, nil +} + +// add adds the new value `v` associated with `id`. +func (m *idManager) add(id uint16, v interface{}) (free func(), err error) { + m.mx.Lock() + + if _, ok := m.values[id]; ok { + m.mx.Unlock() + return nil, errValueAlreadyExists + } + + m.values[id] = v + + m.mx.Unlock() + return m.constructFreeFunc(id), nil +} + +// set sets value `v` associated with `id`. +func (m *idManager) set(id uint16, v interface{}) error { + m.mx.Lock() + + l, ok := m.values[id] + if !ok { + m.mx.Unlock() + return errors.New("id is not reserved") + } else { + if l != nil { + m.mx.Unlock() + return errValueAlreadyExists + } + } + + m.values[id] = v + + m.mx.Unlock() + return nil +} + +// get gets the value associated with the `id`. +func (m *idManager) get(id uint16) (interface{}, bool) { + m.mx.RLock() + lis, ok := m.values[id] + m.mx.RUnlock() + if lis == nil { + return nil, false + } + return lis, ok +} + +// doRange performs range over the manager contents. Loop stops when +// `next` returns false. +func (m *idManager) doRange(next func(id uint16, v interface{}) bool) { + m.mx.RLock() + for id, v := range m.values { + if !next(id, v) { + break + } + } + m.mx.RUnlock() +} + +// constructFreeFunc constructs new func responsible for clearing +// a slot with the specified `id`. +func (m *idManager) constructFreeFunc(id uint16) func() { + return func() { + m.mx.Lock() + delete(m.values, id) + m.mx.Unlock() + } +} diff --git a/pkg/app2/id_manager_test.go b/pkg/app2/id_manager_test.go new file mode 100644 index 000000000..20513ea45 --- /dev/null +++ b/pkg/app2/id_manager_test.go @@ -0,0 +1,415 @@ +package app2 + +import ( + "math" + "sort" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIDManager_ReserveNextID(t *testing.T) { + t.Run("simple call", func(t *testing.T) { + m := newIDManager() + + nextID, free, err := m.reserveNextID() + require.NoError(t, err) + require.NotNil(t, free) + v, ok := m.values[*nextID] + require.True(t, ok) + require.Nil(t, v) + require.Equal(t, *nextID, uint16(1)) + require.Equal(t, *nextID, m.lstID) + + nextID, free, err = m.reserveNextID() + require.NoError(t, err) + require.NotNil(t, free) + v, ok = m.values[*nextID] + require.True(t, ok) + require.Nil(t, v) + require.Equal(t, *nextID, uint16(2)) + require.Equal(t, *nextID, m.lstID) + }) + + t.Run("call on full manager", func(t *testing.T) { + m := newIDManager() + for i := uint16(0); i < math.MaxUint16; i++ { + m.values[i] = nil + } + m.values[math.MaxUint16] = nil + + _, _, err := m.reserveNextID() + require.Error(t, err) + }) + + t.Run("concurrent run", func(t *testing.T) { + m := newIDManager() + + valsToReserve := 10000 + + errs := make(chan error) + for i := 0; i < valsToReserve; i++ { + go func() { + _, _, err := m.reserveNextID() + errs <- err + }() + } + + for i := 0; i < valsToReserve; i++ { + require.NoError(t, <-errs) + } + close(errs) + + require.Equal(t, m.lstID, uint16(valsToReserve)) + for i := uint16(1); i < uint16(valsToReserve); i++ { + v, ok := m.values[i] + require.True(t, ok) + require.Nil(t, v) + } + }) +} + +func TestIDManager_Pop(t *testing.T) { + t.Run("simple call", func(t *testing.T) { + m := newIDManager() + + v := "value" + + m.values[1] = v + + gotV, err := m.pop(1) + require.NoError(t, err) + require.NotNil(t, gotV) + require.Equal(t, gotV, v) + + _, ok := m.values[1] + require.False(t, ok) + }) + + t.Run("no value", func(t *testing.T) { + m := newIDManager() + + _, err := m.pop(1) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "no value")) + }) + + t.Run("value not set", func(t *testing.T) { + m := newIDManager() + + m.values[1] = nil + + _, err := m.pop(1) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "is not set")) + }) + + t.Run("concurrent run", func(t *testing.T) { + m := newIDManager() + + m.values[1] = "value" + + concurrency := 1000 + errs := make(chan error, concurrency) + for i := uint16(0); i < uint16(concurrency); i++ { + go func() { + _, err := m.pop(1) + errs <- err + }() + } + + errsCount := 0 + for i := 0; i < concurrency; i++ { + err := <-errs + if err != nil { + errsCount++ + } + } + close(errs) + require.Equal(t, errsCount, concurrency-1) + + _, ok := m.values[1] + require.False(t, ok) + }) +} + +func TestIDManager_Add(t *testing.T) { + t.Run("simple call", func(t *testing.T) { + m := newIDManager() + + id := uint16(1) + v := "value" + + free, err := m.add(id, v) + require.Nil(t, err) + require.NotNil(t, free) + + gotV, ok := m.values[id] + require.True(t, ok) + require.Equal(t, gotV, v) + + v2 := "value2" + + free, err = m.add(id, v2) + require.Equal(t, err, errValueAlreadyExists) + require.Nil(t, free) + + gotV, ok = m.values[id] + require.True(t, ok) + require.Equal(t, gotV, v) + }) + + t.Run("concurrent run", func(t *testing.T) { + m := newIDManager() + + id := uint16(1) + + concurrency := 1000 + + addV := make(chan int) + errs := make(chan error) + for i := 0; i < concurrency; i++ { + go func(v int) { + _, err := m.add(id, v) + errs <- err + if err == nil { + addV <- v + } + }(i) + } + + errsCount := 0 + for i := 0; i < concurrency; i++ { + if err := <-errs; err != nil { + errsCount++ + } + } + close(errs) + + v := <-addV + close(addV) + + require.Equal(t, concurrency-1, errsCount) + + gotV, ok := m.values[id] + require.True(t, ok) + require.Equal(t, gotV, v) + }) +} + +func TestIDManager_Set(t *testing.T) { + t.Run("simple call", func(t *testing.T) { + m := newIDManager() + + nextID, _, err := m.reserveNextID() + require.NoError(t, err) + + v := "value" + + err = m.set(*nextID, v) + require.NoError(t, err) + gotV, ok := m.values[*nextID] + require.True(t, ok) + require.Equal(t, gotV, v) + }) + + t.Run("id is not reserved", func(t *testing.T) { + m := newIDManager() + + err := m.set(1, "value") + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "not reserved")) + + _, ok := m.values[1] + require.False(t, ok) + }) + + t.Run("value already exists", func(t *testing.T) { + m := newIDManager() + + v := "value" + + m.values[1] = v + + err := m.set(1, "value2") + require.Error(t, err) + gotV, ok := m.values[1] + require.True(t, ok) + require.Equal(t, gotV, v) + }) + + t.Run("concurrent run", func(t *testing.T) { + m := newIDManager() + + concurrency := 1000 + + nextIDPtr, _, err := m.reserveNextID() + require.NoError(t, err) + + nextID := *nextIDPtr + + errs := make(chan error) + setV := make(chan int) + for i := 0; i < concurrency; i++ { + go func(v int) { + err := m.set(nextID, v) + errs <- err + if err == nil { + setV <- v + } + }(i) + } + + errsCount := 0 + for i := 0; i < concurrency; i++ { + err := <-errs + if err != nil { + errsCount++ + } + } + close(errs) + + v := <-setV + close(setV) + + require.Equal(t, concurrency-1, errsCount) + + gotV, ok := m.values[nextID] + require.True(t, ok) + require.Equal(t, gotV, v) + }) +} + +func TestIDManager_Get(t *testing.T) { + prepManagerWithVal := func(v interface{}) (*idManager, uint16) { + m := newIDManager() + + nextID, _, err := m.reserveNextID() + require.NoError(t, err) + + err = m.set(*nextID, v) + require.NoError(t, err) + + return m, *nextID + } + + t.Run("simple call", func(t *testing.T) { + v := "value" + + m, id := prepManagerWithVal(v) + + gotV, ok := m.get(id) + require.True(t, ok) + require.Equal(t, gotV, v) + + _, ok = m.get(100) + require.False(t, ok) + + m.values[2] = nil + gotV, ok = m.get(2) + require.False(t, ok) + require.Nil(t, gotV) + }) + + t.Run("concurrent run", func(t *testing.T) { + v := "value" + + m, id := prepManagerWithVal(v) + + concurrency := 1000 + type getRes struct { + v interface{} + ok bool + } + res := make(chan getRes) + for i := 0; i < concurrency; i++ { + go func() { + val, ok := m.get(id) + res <- getRes{ + v: val, + ok: ok, + } + }() + } + + for i := 0; i < concurrency; i++ { + r := <-res + require.True(t, r.ok) + require.Equal(t, r.v, v) + } + close(res) + }) +} + +func TestIDManager_DoRange(t *testing.T) { + m := newIDManager() + + valsCount := 5 + + vals := make([]int, 0, valsCount) + for i := 0; i < valsCount; i++ { + vals = append(vals, i) + } + + for i, v := range vals { + _, err := m.add(uint16(i), v) + require.NoError(t, err) + } + + // run full range + gotVals := make([]int, 0, valsCount) + m.doRange(func(_ uint16, v interface{}) bool { + val, ok := v.(int) + require.True(t, ok) + + gotVals = append(gotVals, val) + + return true + }) + sort.Ints(gotVals) + require.Equal(t, gotVals, vals) + + // run part range + var gotVal int + gotValsCount := 0 + m.doRange(func(_ uint16, v interface{}) bool { + if gotValsCount == 1 { + return false + } + + val, ok := v.(int) + require.True(t, ok) + + gotVal = val + + gotValsCount++ + + return true + }) + + found := false + for _, v := range vals { + if v == gotVal { + found = true + } + } + require.True(t, found) +} + +func TestIDManager_ConstructFreeFunc(t *testing.T) { + m := newIDManager() + + id := uint16(1) + v := "value" + + free, err := m.add(id, v) + require.NoError(t, err) + require.NotNil(t, free) + + free() + + gotV, ok := m.values[id] + require.False(t, ok) + require.Nil(t, gotV) +} diff --git a/pkg/app2/id_manager_util.go b/pkg/app2/id_manager_util.go new file mode 100644 index 000000000..174b29330 --- /dev/null +++ b/pkg/app2/id_manager_util.go @@ -0,0 +1,27 @@ +package app2 + +import ( + "net" + + "github.com/pkg/errors" +) + +// assertListener asserts that `v` is of type `net.Listener`. +func assertListener(v interface{}) (net.Listener, error) { + lis, ok := v.(net.Listener) + if !ok { + return nil, errors.New("wrong type of value stored for listener") + } + + return lis, nil +} + +// assertConn asserts that `v` is of type `net.Conn`. +func assertConn(v interface{}) (net.Conn, error) { + conn, ok := v.(net.Conn) + if !ok { + return nil, errors.New("wrong type of value stored for conn") + } + + return conn, nil +} diff --git a/pkg/app2/listener.go b/pkg/app2/listener.go new file mode 100644 index 000000000..9b20399d4 --- /dev/null +++ b/pkg/app2/listener.go @@ -0,0 +1,79 @@ +package app2 + +import ( + "net" + + "github.com/skycoin/skycoin/src/util/logging" + + "github.com/skycoin/skywire/pkg/app2/network" +) + +// Listener is a listener for app server connections. +// Implements `net.Listener`. +type Listener struct { + log *logging.Logger + id uint16 + rpc RPCClient + addr network.Addr + cm *idManager // contains conns associated with their IDs + freeLis func() +} + +func (l *Listener) Accept() (net.Conn, error) { + connID, remote, err := l.rpc.Accept(l.id) + if err != nil { + return nil, err + } + + conn := &Conn{ + id: connID, + rpc: l.rpc, + local: l.addr, + remote: remote, + } + + free, err := l.cm.add(connID, conn) + if err != nil { + if err := conn.Close(); err != nil { + l.log.WithError(err).Error("error closing listener") + } + + return nil, err + } + + conn.freeConn = free + + return conn, nil +} + +func (l *Listener) Close() error { + defer func() { + if l.freeLis != nil { + l.freeLis() + } + + var conns []net.Conn + l.cm.doRange(func(_ uint16, v interface{}) bool { + conn, err := assertConn(v) + if err != nil { + l.log.Error(err) + return true + } + + conns = append(conns, conn) + return true + }) + + for _, conn := range conns { + if err := conn.Close(); err != nil { + l.log.WithError(err).Error("error closing listener") + } + } + }() + + return l.rpc.CloseListener(l.id) +} + +func (l *Listener) Addr() net.Addr { + return l.addr +} diff --git a/pkg/app2/listener_test.go b/pkg/app2/listener_test.go new file mode 100644 index 000000000..9ff53cf7b --- /dev/null +++ b/pkg/app2/listener_test.go @@ -0,0 +1,236 @@ +package app2 + +import ( + "testing" + + "github.com/pkg/errors" + "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/skycoin/src/util/logging" + "github.com/stretchr/testify/require" + + "github.com/skycoin/skywire/pkg/app2/network" + "github.com/skycoin/skywire/pkg/routing" +) + +func TestListener_Accept(t *testing.T) { + l := logging.MustGetLogger("app2_listener") + + lisID := uint16(1) + localPK, _ := cipher.GenerateKeyPair() + local := network.Addr{ + Net: network.TypeDMSG, + PubKey: localPK, + Port: routing.Port(100), + } + + t.Run("ok", func(t *testing.T) { + acceptConnID := uint16(1) + acceptRemotePK, _ := cipher.GenerateKeyPair() + acceptRemote := network.Addr{ + Net: network.TypeDMSG, + PubKey: acceptRemotePK, + Port: routing.Port(100), + } + var acceptErr error + + rpc := &MockRPCClient{} + rpc.On("Accept", acceptConnID).Return(acceptConnID, acceptRemote, acceptErr) + + lis := &Listener{ + id: lisID, + rpc: rpc, + addr: local, + cm: newIDManager(), + } + + wantConn := &Conn{ + id: acceptConnID, + rpc: rpc, + local: local, + remote: acceptRemote, + } + + conn, err := lis.Accept() + require.NoError(t, err) + + appConn, ok := conn.(*Conn) + require.True(t, ok) + require.Equal(t, wantConn.id, appConn.id) + require.Equal(t, wantConn.rpc, appConn.rpc) + require.Equal(t, wantConn.local, appConn.local) + require.Equal(t, wantConn.remote, appConn.remote) + require.NotNil(t, appConn.freeConn) + + connIfc, ok := lis.cm.values[acceptConnID] + require.True(t, ok) + + appConn, ok = connIfc.(*Conn) + require.True(t, ok) + require.NotNil(t, appConn.freeConn) + }) + + t.Run("conn already exists", func(t *testing.T) { + acceptConnID := uint16(1) + acceptRemotePK, _ := cipher.GenerateKeyPair() + acceptRemote := network.Addr{ + Net: network.TypeDMSG, + PubKey: acceptRemotePK, + Port: routing.Port(100), + } + var acceptErr error + + var closeErr error + + rpc := &MockRPCClient{} + rpc.On("Accept", acceptConnID).Return(acceptConnID, acceptRemote, acceptErr) + rpc.On("CloseConn", acceptConnID).Return(closeErr) + + lis := &Listener{ + id: lisID, + rpc: rpc, + addr: local, + cm: newIDManager(), + } + + lis.cm.values[acceptConnID] = nil + + conn, err := lis.Accept() + require.Equal(t, err, errValueAlreadyExists) + require.Nil(t, conn) + }) + + t.Run("conn already exists, conn closed with error", func(t *testing.T) { + acceptConnID := uint16(1) + acceptRemotePK, _ := cipher.GenerateKeyPair() + acceptRemote := network.Addr{ + Net: network.TypeDMSG, + PubKey: acceptRemotePK, + Port: routing.Port(100), + } + var acceptErr error + + closeErr := errors.New("close error") + + rpc := &MockRPCClient{} + rpc.On("Accept", acceptConnID).Return(acceptConnID, acceptRemote, acceptErr) + rpc.On("CloseConn", acceptConnID).Return(closeErr) + + lis := &Listener{ + log: l, + id: lisID, + rpc: rpc, + addr: local, + cm: newIDManager(), + } + + lis.cm.values[acceptConnID] = nil + + conn, err := lis.Accept() + require.Equal(t, err, errValueAlreadyExists) + require.Nil(t, conn) + }) + + t.Run("accept error", func(t *testing.T) { + acceptConnID := uint16(0) + acceptRemote := network.Addr{} + acceptErr := errors.New("accept error") + + rpc := &MockRPCClient{} + rpc.On("Accept", lisID).Return(acceptConnID, acceptRemote, acceptErr) + + lis := &Listener{ + id: lisID, + rpc: rpc, + addr: local, + cm: newIDManager(), + } + + conn, err := lis.Accept() + require.Equal(t, acceptErr, err) + require.Nil(t, conn) + }) +} + +func TestListener_Close(t *testing.T) { + l := logging.MustGetLogger("app2_listener") + + lisID := uint16(1) + localPK, _ := cipher.GenerateKeyPair() + local := network.Addr{ + Net: network.TypeDMSG, + PubKey: localPK, + Port: routing.Port(100), + } + + t.Run("ok", func(t *testing.T) { + var closeNoErr error + closeErr := errors.New("close error") + + rpc := &MockRPCClient{} + rpc.On("CloseListener", lisID).Return(closeNoErr) + + cm := newIDManager() + + connID1 := uint16(1) + connID2 := uint16(2) + connID3 := uint16(3) + + rpc.On("CloseConn", connID1).Return(closeNoErr) + rpc.On("CloseConn", connID2).Return(closeErr) + rpc.On("CloseConn", connID3).Return(closeNoErr) + + conn1 := &Conn{id: connID1, rpc: rpc} + free1, err := cm.add(connID1, conn1) + require.NoError(t, err) + conn1.freeConn = free1 + + conn2 := &Conn{id: connID2, rpc: rpc} + free2, err := cm.add(connID2, conn2) + require.NoError(t, err) + conn2.freeConn = free2 + + conn3 := &Conn{id: connID3, rpc: rpc} + free3, err := cm.add(connID3, conn3) + require.NoError(t, err) + conn3.freeConn = free3 + + lis := &Listener{ + log: l, + id: lisID, + rpc: rpc, + addr: local, + cm: cm, + freeLis: func() {}, + } + + err = lis.Close() + require.NoError(t, err) + + _, ok := lis.cm.values[connID1] + require.False(t, ok) + + _, ok = lis.cm.values[connID2] + require.False(t, ok) + + _, ok = lis.cm.values[connID3] + require.False(t, ok) + }) + + t.Run("close error", func(t *testing.T) { + lisCloseErr := errors.New("close error") + + rpc := &MockRPCClient{} + rpc.On("CloseListener", lisID).Return(lisCloseErr) + + lis := &Listener{ + log: l, + id: lisID, + rpc: rpc, + addr: local, + cm: newIDManager(), + } + + err := lis.Close() + require.Equal(t, err, lisCloseErr) + }) +} diff --git a/pkg/app2/mock_conn.go b/pkg/app2/mock_conn.go new file mode 100644 index 000000000..981c8c308 --- /dev/null +++ b/pkg/app2/mock_conn.go @@ -0,0 +1,141 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package app2 + +import ( + "net" + "time" + + "github.com/stretchr/testify/mock" +) + +// MockConn is an autogenerated mock type for the Conn type +type MockConn struct { + mock.Mock +} + +// Close provides a mock function with given fields: +func (_m *MockConn) Close() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// LocalAddr provides a mock function with given fields: +func (_m *MockConn) LocalAddr() net.Addr { + ret := _m.Called() + + var r0 net.Addr + if rf, ok := ret.Get(0).(func() net.Addr); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(net.Addr) + } + + return r0 +} + +// Read provides a mock function with given fields: b +func (_m *MockConn) Read(b []byte) (int, error) { + ret := _m.Called(b) + + var r0 int + if rf, ok := ret.Get(0).(func([]byte) int); ok { + r0 = rf(b) + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func([]byte) error); ok { + r1 = rf(b) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RemoteAddr provides a mock function with given fields: +func (_m *MockConn) RemoteAddr() net.Addr { + ret := _m.Called() + + var r0 net.Addr + if rf, ok := ret.Get(0).(func() net.Addr); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(net.Addr) + } + + return r0 +} + +// SetDeadline provides a mock function with given fields: t +func (_m *MockConn) SetDeadline(t time.Time) error { + ret := _m.Called(t) + + var r0 error + if rf, ok := ret.Get(0).(func(time.Time) error); ok { + r0 = rf(t) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetReadDeadline provides a mock function with given fields: t +func (_m *MockConn) SetReadDeadline(t time.Time) error { + ret := _m.Called(t) + + var r0 error + if rf, ok := ret.Get(0).(func(time.Time) error); ok { + r0 = rf(t) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetWriteDeadline provides a mock function with given fields: t +func (_m *MockConn) SetWriteDeadline(t time.Time) error { + ret := _m.Called(t) + + var r0 error + if rf, ok := ret.Get(0).(func(time.Time) error); ok { + r0 = rf(t) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Write provides a mock function with given fields: b +func (_m *MockConn) Write(b []byte) (int, error) { + ret := _m.Called(b) + + var r0 int + if rf, ok := ret.Get(0).(func([]byte) int); ok { + r0 = rf(b) + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func([]byte) error); ok { + r1 = rf(b) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/app2/mock_listener.go b/pkg/app2/mock_listener.go new file mode 100644 index 000000000..44fda81bd --- /dev/null +++ b/pkg/app2/mock_listener.go @@ -0,0 +1,65 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package app2 + +import ( + "net" + + "github.com/stretchr/testify/mock" +) + +// MockListener is an autogenerated mock type for the Listener type +type MockListener struct { + mock.Mock +} + +// Accept provides a mock function with given fields: +func (_m *MockListener) Accept() (net.Conn, error) { + ret := _m.Called() + + var r0 net.Conn + if rf, ok := ret.Get(0).(func() net.Conn); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(net.Conn) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Addr provides a mock function with given fields: +func (_m *MockListener) Addr() net.Addr { + ret := _m.Called() + + var r0 net.Addr + if rf, ok := ret.Get(0).(func() net.Addr); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(net.Addr) + } + + return r0 +} + +// Close provides a mock function with given fields: +func (_m *MockListener) Close() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/pkg/app2/mock_rpc_client.go b/pkg/app2/mock_rpc_client.go new file mode 100644 index 000000000..bbf373f93 --- /dev/null +++ b/pkg/app2/mock_rpc_client.go @@ -0,0 +1,159 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package app2 + +import mock "github.com/stretchr/testify/mock" +import network "github.com/skycoin/skywire/pkg/app2/network" +import routing "github.com/skycoin/skywire/pkg/routing" + +// MockRPCClient is an autogenerated mock type for the RPCClient type +type MockRPCClient struct { + mock.Mock +} + +// Accept provides a mock function with given fields: lisID +func (_m *MockRPCClient) Accept(lisID uint16) (uint16, network.Addr, error) { + ret := _m.Called(lisID) + + var r0 uint16 + if rf, ok := ret.Get(0).(func(uint16) uint16); ok { + r0 = rf(lisID) + } else { + r0 = ret.Get(0).(uint16) + } + + var r1 network.Addr + if rf, ok := ret.Get(1).(func(uint16) network.Addr); ok { + r1 = rf(lisID) + } else { + r1 = ret.Get(1).(network.Addr) + } + + var r2 error + if rf, ok := ret.Get(2).(func(uint16) error); ok { + r2 = rf(lisID) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// CloseConn provides a mock function with given fields: id +func (_m *MockRPCClient) CloseConn(id uint16) error { + ret := _m.Called(id) + + var r0 error + if rf, ok := ret.Get(0).(func(uint16) error); ok { + r0 = rf(id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// CloseListener provides a mock function with given fields: id +func (_m *MockRPCClient) CloseListener(id uint16) error { + ret := _m.Called(id) + + var r0 error + if rf, ok := ret.Get(0).(func(uint16) error); ok { + r0 = rf(id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Dial provides a mock function with given fields: remote +func (_m *MockRPCClient) Dial(remote network.Addr) (uint16, routing.Port, error) { + ret := _m.Called(remote) + + var r0 uint16 + if rf, ok := ret.Get(0).(func(network.Addr) uint16); ok { + r0 = rf(remote) + } else { + r0 = ret.Get(0).(uint16) + } + + var r1 routing.Port + if rf, ok := ret.Get(1).(func(network.Addr) routing.Port); ok { + r1 = rf(remote) + } else { + r1 = ret.Get(1).(routing.Port) + } + + var r2 error + if rf, ok := ret.Get(2).(func(network.Addr) error); ok { + r2 = rf(remote) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// Listen provides a mock function with given fields: local +func (_m *MockRPCClient) Listen(local network.Addr) (uint16, error) { + ret := _m.Called(local) + + var r0 uint16 + if rf, ok := ret.Get(0).(func(network.Addr) uint16); ok { + r0 = rf(local) + } else { + r0 = ret.Get(0).(uint16) + } + + var r1 error + if rf, ok := ret.Get(1).(func(network.Addr) error); ok { + r1 = rf(local) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Read provides a mock function with given fields: connID, b +func (_m *MockRPCClient) Read(connID uint16, b []byte) (int, error) { + ret := _m.Called(connID, b) + + var r0 int + if rf, ok := ret.Get(0).(func(uint16, []byte) int); ok { + r0 = rf(connID, b) + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func(uint16, []byte) error); ok { + r1 = rf(connID, b) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Write provides a mock function with given fields: connID, b +func (_m *MockRPCClient) Write(connID uint16, b []byte) (int, error) { + ret := _m.Called(connID, b) + + var r0 int + if rf, ok := ret.Get(0).(func(uint16, []byte) int); ok { + r0 = rf(connID, b) + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func(uint16, []byte) error); ok { + r1 = rf(connID, b) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/app2/network/addr.go b/pkg/app2/network/addr.go new file mode 100644 index 000000000..d96aabc2a --- /dev/null +++ b/pkg/app2/network/addr.go @@ -0,0 +1,51 @@ +package network + +import ( + "errors" + "fmt" + "net" + + "github.com/skycoin/dmsg" + "github.com/skycoin/dmsg/cipher" + + "github.com/skycoin/skywire/pkg/routing" +) + +var ( + ErrUnknownAddrType = errors.New("addr type is unknown") +) + +// Addr implements net.Addr for network addresses. +type Addr struct { + Net Type + PubKey cipher.PubKey + Port routing.Port +} + +// Network returns network type. +func (a Addr) Network() string { + return string(a.Net) +} + +// String returns public key and port of node split by colon. +func (a Addr) String() string { + if a.Port == 0 { + return fmt.Sprintf("%s:~", a.PubKey) + } + return fmt.Sprintf("%s:%d", a.PubKey, a.Port) +} + +// ConvertAddr asserts type of the passed `net.Addr` and converts it +// to `Addr` if possible. +func ConvertAddr(addr net.Addr) (Addr, error) { + switch a := addr.(type) { + case dmsg.Addr: + return Addr{ + Net: TypeDMSG, + PubKey: a.PK, + Port: routing.Port(a.Port), + }, nil + default: + return Addr{}, ErrUnknownAddrType + } +} diff --git a/pkg/app2/network/dmsg_networker.go b/pkg/app2/network/dmsg_networker.go new file mode 100644 index 000000000..424b0df46 --- /dev/null +++ b/pkg/app2/network/dmsg_networker.go @@ -0,0 +1,40 @@ +package network + +import ( + "context" + "net" + + "github.com/skycoin/dmsg" +) + +// DMSGNetworker implements `Networker` for dmsg network. +type DMSGNetworker struct { + dmsgC *dmsg.Client +} + +// NewDMSGNetworker constructs new `DMSGNetworker`. +func NewDMSGNetworker(dmsgC *dmsg.Client) Networker { + return &DMSGNetworker{ + dmsgC: dmsgC, + } +} + +// Dial dials remote `addr` via dmsg network. +func (n *DMSGNetworker) Dial(addr Addr) (net.Conn, error) { + return n.DialContext(context.Background(), addr) +} + +// DialContext dials remote `addr` via dmsg network with context. +func (n *DMSGNetworker) DialContext(ctx context.Context, addr Addr) (net.Conn, error) { + return n.dmsgC.Dial(ctx, addr.PubKey, uint16(addr.Port)) +} + +// Listen starts listening on local `addr` in the dmsg network. +func (n *DMSGNetworker) Listen(addr Addr) (net.Listener, error) { + return n.ListenContext(context.Background(), addr) +} + +// ListenContext starts listening on local `addr` in the dmsg network with context. +func (n *DMSGNetworker) ListenContext(ctx context.Context, addr Addr) (net.Listener, error) { + return n.dmsgC.Listen(uint16(addr.Port)) +} diff --git a/pkg/app2/network/mock_networker.go b/pkg/app2/network/mock_networker.go new file mode 100644 index 000000000..fd4b304ae --- /dev/null +++ b/pkg/app2/network/mock_networker.go @@ -0,0 +1,104 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package network + +import context "context" +import mock "github.com/stretchr/testify/mock" +import net "net" + +// MockNetworker is an autogenerated mock type for the Networker type +type MockNetworker struct { + mock.Mock +} + +// Dial provides a mock function with given fields: addr +func (_m *MockNetworker) Dial(addr Addr) (net.Conn, error) { + ret := _m.Called(addr) + + var r0 net.Conn + if rf, ok := ret.Get(0).(func(Addr) net.Conn); ok { + r0 = rf(addr) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(net.Conn) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(Addr) error); ok { + r1 = rf(addr) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// DialContext provides a mock function with given fields: ctx, addr +func (_m *MockNetworker) DialContext(ctx context.Context, addr Addr) (net.Conn, error) { + ret := _m.Called(ctx, addr) + + var r0 net.Conn + if rf, ok := ret.Get(0).(func(context.Context, Addr) net.Conn); ok { + r0 = rf(ctx, addr) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(net.Conn) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, Addr) error); ok { + r1 = rf(ctx, addr) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Listen provides a mock function with given fields: addr +func (_m *MockNetworker) Listen(addr Addr) (net.Listener, error) { + ret := _m.Called(addr) + + var r0 net.Listener + if rf, ok := ret.Get(0).(func(Addr) net.Listener); ok { + r0 = rf(addr) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(net.Listener) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(Addr) error); ok { + r1 = rf(addr) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ListenContext provides a mock function with given fields: ctx, addr +func (_m *MockNetworker) ListenContext(ctx context.Context, addr Addr) (net.Listener, error) { + ret := _m.Called(ctx, addr) + + var r0 net.Listener + if rf, ok := ret.Get(0).(func(context.Context, Addr) net.Listener); ok { + r0 = rf(ctx, addr) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(net.Listener) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, Addr) error); ok { + r1 = rf(ctx, addr) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/app2/network/networker.go b/pkg/app2/network/networker.go new file mode 100644 index 000000000..41fb18b6c --- /dev/null +++ b/pkg/app2/network/networker.go @@ -0,0 +1,94 @@ +package network + +import ( + "context" + "errors" + "net" + "sync" +) + +//go:generate mockery -name Networker -case underscore -inpkg + +var ( + // ErrNoSuchNetworker is being returned when there's no suitable networker. + ErrNoSuchNetworker = errors.New("no such networker") + // ErrNetworkerAlreadyExists is being returned when there's already one with such Network type. + ErrNetworkerAlreadyExists = errors.New("networker already exists") +) + +var ( + networkers = make(map[Type]Networker) + networkersMx sync.RWMutex +) + +// AddNetworker associates Networker with the `network`. +func AddNetworker(t Type, n Networker) error { + networkersMx.Lock() + defer networkersMx.Unlock() + + if _, ok := networkers[t]; ok { + return ErrNetworkerAlreadyExists + } + + networkers[t] = n + + return nil +} + +// ResolveNetworker resolves Networker by `network`. +func ResolveNetworker(t Type) (Networker, error) { + networkersMx.RLock() + n, ok := networkers[t] + if !ok { + networkersMx.RUnlock() + return nil, ErrNoSuchNetworker + } + networkersMx.RUnlock() + return n, nil +} + +// ClearNetworkers removes all the stored networkers. +func ClearNetworkers() { + networkersMx.Lock() + defer networkersMx.Unlock() + + networkers = make(map[Type]Networker) +} + +// Networker defines basic network operations, such as Dial/Listen. +type Networker interface { + Dial(addr Addr) (net.Conn, error) + DialContext(ctx context.Context, addr Addr) (net.Conn, error) + Listen(addr Addr) (net.Listener, error) + ListenContext(ctx context.Context, addr Addr) (net.Listener, error) +} + +// Dial dials the remote `addr`. +func Dial(addr Addr) (net.Conn, error) { + return DialContext(context.Background(), addr) +} + +// DialContext dials the remote `addr` with the context. +func DialContext(ctx context.Context, addr Addr) (net.Conn, error) { + n, err := ResolveNetworker(addr.Net) + if err != nil { + return nil, err + } + + return n.DialContext(ctx, addr) +} + +// Listen starts listening on the local `addr`. +func Listen(addr Addr) (net.Listener, error) { + return ListenContext(context.Background(), addr) +} + +// ListenContext starts listening on the local `addr` with the context. +func ListenContext(ctx context.Context, addr Addr) (net.Listener, error) { + networker, err := ResolveNetworker(addr.Net) + if err != nil { + return nil, err + } + + return networker.ListenContext(ctx, addr) +} diff --git a/pkg/app2/network/networker_test.go b/pkg/app2/network/networker_test.go new file mode 100644 index 000000000..f6692b838 --- /dev/null +++ b/pkg/app2/network/networker_test.go @@ -0,0 +1,115 @@ +package network + +import ( + "context" + "net" + "testing" + + "github.com/skycoin/dmsg/cipher" + "github.com/stretchr/testify/require" + + "github.com/skycoin/skywire/pkg/routing" +) + +func TestAddNetworker(t *testing.T) { + ClearNetworkers() + + nType := TypeDMSG + var n Networker + + err := AddNetworker(nType, n) + require.NoError(t, err) + + err = AddNetworker(nType, n) + require.Equal(t, err, ErrNetworkerAlreadyExists) +} + +func TestResolveNetworker(t *testing.T) { + ClearNetworkers() + + nType := TypeDMSG + var n Networker + + n, err := ResolveNetworker(nType) + require.Equal(t, err, ErrNoSuchNetworker) + + err = AddNetworker(nType, n) + require.NoError(t, err) + + gotN, err := ResolveNetworker(nType) + require.NoError(t, err) + require.Equal(t, gotN, n) +} + +func TestDial(t *testing.T) { + addr := prepAddr() + + t.Run("no such networker", func(t *testing.T) { + ClearNetworkers() + + _, err := Dial(addr) + require.Equal(t, err, ErrNoSuchNetworker) + }) + + t.Run("ok", func(t *testing.T) { + ClearNetworkers() + + dialCtx := context.Background() + var ( + dialConn net.Conn + dialErr error + ) + + n := &MockNetworker{} + n.On("DialContext", dialCtx, addr).Return(dialConn, dialErr) + + err := AddNetworker(addr.Net, n) + require.NoError(t, err) + + conn, err := Dial(addr) + require.NoError(t, err) + require.Equal(t, conn, dialConn) + }) +} + +func TestListen(t *testing.T) { + addr := prepAddr() + + t.Run("no such networker", func(t *testing.T) { + ClearNetworkers() + + _, err := Listen(addr) + require.Equal(t, err, ErrNoSuchNetworker) + }) + + t.Run("ok", func(t *testing.T) { + ClearNetworkers() + + listenCtx := context.Background() + var ( + listenLis net.Listener + listenErr error + ) + + n := &MockNetworker{} + n.On("ListenContext", listenCtx, addr).Return(listenLis, listenErr) + + err := AddNetworker(addr.Net, n) + require.NoError(t, err) + + lis, err := Listen(addr) + require.NoError(t, err) + require.Equal(t, lis, listenLis) + }) +} + +func prepAddr() Addr { + addrPK, _ := cipher.GenerateKeyPair() + addrPort := routing.Port(100) + + return Addr{ + Net: TypeDMSG, + PubKey: addrPK, + Port: addrPort, + } +} diff --git a/pkg/app2/network/type.go b/pkg/app2/network/type.go new file mode 100644 index 000000000..c91f9128d --- /dev/null +++ b/pkg/app2/network/type.go @@ -0,0 +1,21 @@ +package network + +// Type represents the network type. +type Type string + +const ( + // TypeDMSG is a network type for DMSG communication. + TypeDMSG Type = "dmsg" +) + +// IsValid checks whether the network contains valid value for the type. +func (n Type) IsValid() bool { + _, ok := validNetworks[n] + return ok +} + +var ( + validNetworks = map[Type]struct{}{ + TypeDMSG: {}, + } +) diff --git a/pkg/app2/network/type_test.go b/pkg/app2/network/type_test.go new file mode 100644 index 000000000..632ade9c7 --- /dev/null +++ b/pkg/app2/network/type_test.go @@ -0,0 +1,32 @@ +package network + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestType_IsValid(t *testing.T) { + tt := []struct { + name string + t Type + want bool + }{ + { + name: "valid", + t: TypeDMSG, + want: true, + }, + { + name: "not valid", + t: "not valid", + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + valid := tc.t.IsValid() + require.Equal(t, tc.want, valid) + }) + } +} diff --git a/pkg/app2/network/wrapped_conn.go b/pkg/app2/network/wrapped_conn.go new file mode 100644 index 000000000..d8e1f4df9 --- /dev/null +++ b/pkg/app2/network/wrapped_conn.go @@ -0,0 +1,42 @@ +package network + +import ( + "net" +) + +// WrappedConn wraps `net.Conn` to support address conversion between +// specific `net.Addr` implementations and `Addr`. +type WrappedConn struct { + net.Conn + local Addr + remote Addr +} + +// WrapConn wraps passed `conn`. Handles `net.Addr` type assertion. +func WrapConn(conn net.Conn) (net.Conn, error) { + l, err := ConvertAddr(conn.LocalAddr()) + if err != nil { + return nil, err + } + + r, err := ConvertAddr(conn.RemoteAddr()) + if err != nil { + return nil, err + } + + return &WrappedConn{ + Conn: conn, + local: l, + remote: r, + }, nil +} + +// LocalAddr returns local address. +func (c *WrappedConn) LocalAddr() net.Addr { + return c.local +} + +// RemoteAddr returns remote address. +func (c *WrappedConn) RemoteAddr() net.Addr { + return c.remote +} diff --git a/pkg/app2/procid.go b/pkg/app2/procid.go new file mode 100644 index 000000000..d49cdec5f --- /dev/null +++ b/pkg/app2/procid.go @@ -0,0 +1,6 @@ +package app2 + +// ProcID identifies the current instance of an app (an app process). +// The visor node is responsible for starting apps, and the started process +// should be provided with a ProcID. +type ProcID uint16 diff --git a/pkg/app2/rpc_client.go b/pkg/app2/rpc_client.go new file mode 100644 index 000000000..a97773f27 --- /dev/null +++ b/pkg/app2/rpc_client.go @@ -0,0 +1,105 @@ +package app2 + +import ( + "net/rpc" + + "github.com/skycoin/skywire/pkg/app2/network" + "github.com/skycoin/skywire/pkg/routing" +) + +//go:generate mockery -name RPCClient -case underscore -inpkg + +// RPCClient describes RPC interface to communicate with the server. +type RPCClient interface { + Dial(remote network.Addr) (connID uint16, localPort routing.Port, err error) + Listen(local network.Addr) (uint16, error) + Accept(lisID uint16) (connID uint16, remote network.Addr, err error) + Write(connID uint16, b []byte) (int, error) + Read(connID uint16, b []byte) (int, error) + CloseConn(id uint16) error + CloseListener(id uint16) error +} + +// rpcClient implements `RPCClient`. +type rpcCLient struct { + rpc *rpc.Client +} + +// NewRPCClient constructs new `rpcClient`. +func NewRPCClient(rpc *rpc.Client) RPCClient { + return &rpcCLient{ + rpc: rpc, + } +} + +// Dial sends `Dial` command to the server. +func (c *rpcCLient) Dial(remote network.Addr) (connID uint16, localPort routing.Port, err error) { + var resp DialResp + if err := c.rpc.Call("RPCGateway.Dial", &remote, &resp); err != nil { + return 0, 0, err + } + + return resp.ConnID, resp.LocalPort, nil +} + +// Listen sends `Listen` command to the server. +func (c *rpcCLient) Listen(local network.Addr) (uint16, error) { + var lisID uint16 + if err := c.rpc.Call("RPCGateway.Listen", &local, &lisID); err != nil { + return 0, err + } + + return lisID, nil +} + +// Accept sends `Accept` command to the server. +func (c *rpcCLient) Accept(lisID uint16) (connID uint16, remote network.Addr, err error) { + var acceptResp AcceptResp + if err := c.rpc.Call("RPCGateway.Accept", &lisID, &acceptResp); err != nil { + return 0, network.Addr{}, err + } + + return acceptResp.ConnID, acceptResp.Remote, nil +} + +// Write sends `Write` command to the server. +func (c *rpcCLient) Write(connID uint16, b []byte) (int, error) { + req := WriteReq{ + ConnID: connID, + B: b, + } + + var n int + if err := c.rpc.Call("RPCGateway.Write", &req, &n); err != nil { + return n, err + } + + return n, nil +} + +// Read sends `Read` command to the server. +func (c *rpcCLient) Read(connID uint16, b []byte) (int, error) { + req := ReadReq{ + ConnID: connID, + BufLen: len(b), + } + + var resp ReadResp + if err := c.rpc.Call("RPCGateway.Read", &req, &resp); err != nil { + return 0, err + } + + copy(b[:resp.N], resp.B[:resp.N]) + + return resp.N, nil +} + +// CloseConn sends `CloseConn` command to the server. +func (c *rpcCLient) CloseConn(id uint16) error { + return c.rpc.Call("RPCGateway.CloseConn", &id, nil) +} + +// CloseListener sends `CloseListener` command to the server. +func (c *rpcCLient) CloseListener(id uint16) error { + return c.rpc.Call("RPCGateway.CloseListener", &id, nil) +} diff --git a/pkg/app2/rpc_client_test.go b/pkg/app2/rpc_client_test.go new file mode 100644 index 000000000..2d3476780 --- /dev/null +++ b/pkg/app2/rpc_client_test.go @@ -0,0 +1,494 @@ +package app2 + +import ( + "context" + "net" + "net/rpc" + "testing" + + "github.com/pkg/errors" + "github.com/skycoin/dmsg" + "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/skycoin/src/util/logging" + "github.com/stretchr/testify/require" + "golang.org/x/net/nettest" + + "github.com/skycoin/skywire/pkg/app2/network" + "github.com/skycoin/skywire/pkg/routing" +) + +func TestRPCClient_Dial(t *testing.T) { + t.Run("ok", func(t *testing.T) { + s := prepRPCServer(t, prepGateway()) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + remoteNet := network.TypeDMSG + remotePK, _ := cipher.GenerateKeyPair() + remotePort := routing.Port(100) + remote := network.Addr{ + Net: remoteNet, + PubKey: remotePK, + Port: remotePort, + } + + localPK, _ := cipher.GenerateKeyPair() + dmsgLocal := dmsg.Addr{ + PK: localPK, + Port: 101, + } + dmsgRemote := dmsg.Addr{ + PK: remotePK, + Port: uint16(remotePort), + } + + dialCtx := context.Background() + dialConn := dmsg.NewTransport(&MockConn{}, logging.MustGetLogger("dmsg_tp"), + dmsgLocal, dmsgRemote, 0, func() {}) + var noErr error + + n := &network.MockNetworker{} + n.On("DialContext", dialCtx, remote).Return(dialConn, noErr) + + network.ClearNetworkers() + err := network.AddNetworker(remoteNet, n) + require.NoError(t, err) + + connID, localPort, err := cl.Dial(remote) + require.NoError(t, err) + require.Equal(t, connID, uint16(1)) + require.Equal(t, localPort, routing.Port(dmsgLocal.Port)) + + }) + + t.Run("dial error", func(t *testing.T) { + s := prepRPCServer(t, prepGateway()) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + remoteNet := network.TypeDMSG + remotePK, _ := cipher.GenerateKeyPair() + remotePort := routing.Port(100) + remote := network.Addr{ + Net: remoteNet, + PubKey: remotePK, + Port: remotePort, + } + + dialCtx := context.Background() + var dialConn net.Conn + dialErr := errors.New("dial error") + + n := &network.MockNetworker{} + n.On("DialContext", dialCtx, remote).Return(dialConn, dialErr) + + network.ClearNetworkers() + err := network.AddNetworker(remoteNet, n) + require.NoError(t, err) + + connID, localPort, err := cl.Dial(remote) + require.Error(t, err) + require.Equal(t, err.Error(), dialErr.Error()) + require.Equal(t, connID, uint16(0)) + require.Equal(t, localPort, routing.Port(0)) + }) +} + +func TestRPCClient_Listen(t *testing.T) { + t.Run("ok", func(t *testing.T) { + s := prepRPCServer(t, prepGateway()) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + localNet := network.TypeDMSG + localPK, _ := cipher.GenerateKeyPair() + localPort := routing.Port(100) + local := network.Addr{ + Net: localNet, + PubKey: localPK, + Port: localPort, + } + + listenCtx := context.Background() + var listenLis net.Listener + var noErr error + + n := &network.MockNetworker{} + n.On("ListenContext", listenCtx, local).Return(listenLis, noErr) + + network.ClearNetworkers() + err := network.AddNetworker(localNet, n) + require.NoError(t, err) + + lisID, err := cl.Listen(local) + require.NoError(t, err) + require.Equal(t, lisID, uint16(1)) + }) + + t.Run("listen error", func(t *testing.T) { + s := prepRPCServer(t, prepGateway()) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + localNet := network.TypeDMSG + localPK, _ := cipher.GenerateKeyPair() + localPort := routing.Port(100) + local := network.Addr{ + Net: localNet, + PubKey: localPK, + Port: localPort, + } + + listenCtx := context.Background() + var listenLis net.Listener + listenErr := errors.New("listen error") + + n := &network.MockNetworker{} + n.On("ListenContext", listenCtx, local).Return(listenLis, listenErr) + + network.ClearNetworkers() + err := network.AddNetworker(localNet, n) + require.NoError(t, err) + + lisID, err := cl.Listen(local) + require.Error(t, err) + require.Equal(t, err.Error(), listenErr.Error()) + require.Equal(t, lisID, uint16(0)) + }) +} + +func TestRPCClient_Accept(t *testing.T) { + t.Run("ok", func(t *testing.T) { + gateway := prepGateway() + + localPK, _ := cipher.GenerateKeyPair() + localPort := uint16(100) + dmsgLocal := dmsg.Addr{ + PK: localPK, + Port: localPort, + } + remotePK, _ := cipher.GenerateKeyPair() + remotePort := uint16(101) + dmsgRemote := dmsg.Addr{ + PK: remotePK, + Port: remotePort, + } + lisConn := dmsg.NewTransport(&MockConn{}, logging.MustGetLogger("dmsg_tp"), + dmsgLocal, dmsgRemote, 0, func() {}) + var noErr error + + lis := &MockListener{} + lis.On("Accept").Return(lisConn, noErr) + + lisID := uint16(1) + + _, err := gateway.lm.add(lisID, lis) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + wantRemote := network.Addr{ + Net: network.TypeDMSG, + PubKey: remotePK, + Port: routing.Port(remotePort), + } + + connID, remote, err := cl.Accept(lisID) + require.NoError(t, err) + require.Equal(t, connID, uint16(1)) + require.Equal(t, remote, wantRemote) + }) + + t.Run("accept error", func(t *testing.T) { + gateway := prepGateway() + + var lisConn net.Conn + listenErr := errors.New("accept error") + + lis := &MockListener{} + lis.On("Accept").Return(lisConn, listenErr) + + lisID := uint16(1) + + _, err := gateway.lm.add(lisID, lis) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + connID, remote, err := cl.Accept(lisID) + require.Error(t, err) + require.Equal(t, err.Error(), listenErr.Error()) + require.Equal(t, connID, uint16(0)) + require.Equal(t, remote, network.Addr{}) + }) +} + +func TestRPCClient_Write(t *testing.T) { + t.Run("ok", func(t *testing.T) { + gateway := prepGateway() + + writeBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + writeN := 10 + var noErr error + + conn := &MockConn{} + conn.On("Write", writeBuf).Return(writeN, noErr) + + connID := uint16(1) + + _, err := gateway.cm.add(connID, conn) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + n, err := cl.Write(connID, writeBuf) + require.NoError(t, err) + require.Equal(t, n, writeN) + }) + + t.Run("write error", func(t *testing.T) { + gateway := prepGateway() + + writeBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + writeN := 0 + writeErr := errors.New("write error") + + conn := &MockConn{} + conn.On("Write", writeBuf).Return(writeN, writeErr) + + connID := uint16(1) + + _, err := gateway.cm.add(connID, conn) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + n, err := cl.Write(connID, writeBuf) + require.Error(t, err) + require.Equal(t, err.Error(), writeErr.Error()) + require.Equal(t, n, 0) + }) +} + +func TestRPCClient_Read(t *testing.T) { + t.Run("ok", func(t *testing.T) { + gateway := prepGateway() + + readBufLen := 10 + readBuf := make([]byte, readBufLen) + readN := 5 + var noErr error + + conn := &MockConn{} + conn.On("Read", readBuf).Return(readN, noErr) + + connID := uint16(1) + + _, err := gateway.cm.add(connID, conn) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + n, err := cl.Read(connID, readBuf) + require.NoError(t, err) + require.Equal(t, n, readN) + }) + + t.Run("read error", func(t *testing.T) { + gateway := prepGateway() + + readBufLen := 10 + readBuf := make([]byte, readBufLen) + readN := 0 + readErr := errors.New("read error") + + conn := &MockConn{} + conn.On("Read", readBuf).Return(readN, readErr) + + connID := uint16(1) + + _, err := gateway.cm.add(connID, conn) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + n, err := cl.Read(connID, readBuf) + require.Error(t, err) + require.Equal(t, err.Error(), readErr.Error()) + require.Equal(t, n, readN) + }) +} + +func TestRPCClient_CloseConn(t *testing.T) { + t.Run("ok", func(t *testing.T) { + gateway := prepGateway() + + var noErr error + + conn := &MockConn{} + conn.On("Close").Return(noErr) + + connID := uint16(1) + + _, err := gateway.cm.add(connID, conn) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + err = cl.CloseConn(connID) + require.NoError(t, err) + }) + + t.Run("close error", func(t *testing.T) { + gateway := prepGateway() + + closeErr := errors.New("close error") + + conn := &MockConn{} + conn.On("Close").Return(closeErr) + + connID := uint16(1) + + _, err := gateway.cm.add(connID, conn) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + err = cl.CloseConn(connID) + require.Error(t, err) + require.Equal(t, err.Error(), closeErr.Error()) + }) +} + +func TestRPCClient_CloseListener(t *testing.T) { + t.Run("ok", func(t *testing.T) { + gateway := prepGateway() + + var noErr error + + lis := &MockListener{} + lis.On("Close").Return(noErr) + + lisID := uint16(1) + + _, err := gateway.lm.add(lisID, lis) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + err = cl.CloseListener(lisID) + require.NoError(t, err) + }) + + t.Run("close error", func(t *testing.T) { + gateway := prepGateway() + + closeErr := errors.New("close error") + + lis := &MockListener{} + lis.On("Close").Return(closeErr) + + lisID := uint16(1) + + _, err := gateway.lm.add(lisID, lis) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + err = cl.CloseListener(lisID) + require.Error(t, err) + require.Equal(t, err.Error(), closeErr.Error()) + }) +} + +func prepGateway() *RPCGateway { + l := logging.MustGetLogger("rpc_gateway") + return newRPCGateway(l) +} + +func prepRPCServer(t *testing.T, gateway *RPCGateway) *rpc.Server { + s := rpc.NewServer() + err := s.Register(gateway) + require.NoError(t, err) + + return s +} + +func prepListener(t *testing.T) (lis net.Listener, cleanup func()) { + lis, err := nettest.NewLocalListener("tcp") + require.NoError(t, err) + + return lis, func() { + err := lis.Close() + require.NoError(t, err) + } +} + +func prepClient(t *testing.T, network, addr string) RPCClient { + rpcCl, err := rpc.Dial(network, addr) + require.NoError(t, err) + + return NewRPCClient(rpcCl) +} diff --git a/pkg/app2/rpc_gateway.go b/pkg/app2/rpc_gateway.go new file mode 100644 index 000000000..dd1387131 --- /dev/null +++ b/pkg/app2/rpc_gateway.go @@ -0,0 +1,258 @@ +package app2 + +import ( + "fmt" + "net" + + "github.com/pkg/errors" + "github.com/skycoin/skycoin/src/util/logging" + + "github.com/skycoin/skywire/pkg/app2/network" + "github.com/skycoin/skywire/pkg/routing" +) + +// RPCGateway is a RPC interface for the app server. +type RPCGateway struct { + lm *idManager // contains listeners associated with their IDs + cm *idManager // contains connections associated with their IDs + log *logging.Logger +} + +// newRPCGateway constructs new server RPC interface. +func newRPCGateway(log *logging.Logger) *RPCGateway { + return &RPCGateway{ + lm: newIDManager(), + cm: newIDManager(), + log: log, + } +} + +// DialResp contains response parameters for `Dial`. +type DialResp struct { + ConnID uint16 + LocalPort routing.Port +} + +// Dial dials to the remote. +func (r *RPCGateway) Dial(remote *network.Addr, resp *DialResp) error { + reservedConnID, free, err := r.cm.reserveNextID() + if err != nil { + return err + } + + conn, err := network.Dial(*remote) + if err != nil { + free() + return err + } + + wrappedConn, err := network.WrapConn(conn) + if err != nil { + free() + return err + } + + if err := r.cm.set(*reservedConnID, wrappedConn); err != nil { + if err := wrappedConn.Close(); err != nil { + r.log.WithError(err).Error("error closing conn") + } + + free() + return err + } + + localAddr := wrappedConn.LocalAddr().(network.Addr) + + resp.ConnID = *reservedConnID + resp.LocalPort = localAddr.Port + + return nil +} + +// Listen starts listening. +func (r *RPCGateway) Listen(local *network.Addr, lisID *uint16) error { + nextLisID, free, err := r.lm.reserveNextID() + if err != nil { + return err + } + + l, err := network.Listen(*local) + if err != nil { + free() + return err + } + + if err := r.lm.set(*nextLisID, l); err != nil { + if err := l.Close(); err != nil { + r.log.WithError(err).Error("error closing listener") + } + + free() + return err + } + + *lisID = *nextLisID + + return nil +} + +// AcceptResp contains response parameters for `Accept`. +type AcceptResp struct { + Remote network.Addr + ConnID uint16 +} + +// Accept accepts connection from the listener specified by `lisID`. +func (r *RPCGateway) Accept(lisID *uint16, resp *AcceptResp) error { + lis, err := r.getListener(*lisID) + if err != nil { + return err + } + + connID, free, err := r.cm.reserveNextID() + if err != nil { + return err + } + + conn, err := lis.Accept() + if err != nil { + free() + return err + } + + wrappedConn, err := network.WrapConn(conn) + if err != nil { + free() + return err + } + + if err := r.cm.set(*connID, wrappedConn); err != nil { + if err := wrappedConn.Close(); err != nil { + r.log.WithError(err).Error("error closing DMSG transport") + } + + free() + return err + } + + remote := wrappedConn.RemoteAddr().(network.Addr) + + resp.Remote = remote + resp.ConnID = *connID + + return nil +} + +// WriteReq contains arguments for `Write`. +type WriteReq struct { + ConnID uint16 + B []byte +} + +// Write writes to the connection. +func (r *RPCGateway) Write(req *WriteReq, n *int) error { + conn, err := r.getConn(req.ConnID) + if err != nil { + return err + } + + *n, err = conn.Write(req.B) + if err != nil { + return err + } + + return nil +} + +// ReadReq contains arguments for `Read`. +type ReadReq struct { + ConnID uint16 + BufLen int +} + +// ReadResp contains response parameters for `Read`. +type ReadResp struct { + B []byte + N int +} + +// Read reads data from connection specified by `connID`. +func (r *RPCGateway) Read(req *ReadReq, resp *ReadResp) error { + conn, err := r.getConn(req.ConnID) + if err != nil { + return err + } + + buf := make([]byte, req.BufLen) + resp.N, err = conn.Read(buf) + if err != nil { + return err + } + + resp.B = make([]byte, resp.N) + copy(resp.B, buf[:resp.N]) + + return nil +} + +// CloseConn closes connection specified by `connID`. +func (r *RPCGateway) CloseConn(connID *uint16, _ *struct{}) error { + conn, err := r.popConn(*connID) + if err != nil { + return err + } + + return conn.Close() +} + +// CloseListener closes listener specified by `lisID`. +func (r *RPCGateway) CloseListener(lisID *uint16, _ *struct{}) error { + lis, err := r.popListener(*lisID) + if err != nil { + return err + } + + return lis.Close() +} + +// popListener gets listener from the manager by `lisID` and removes it. +// Handles type assertion. +func (r *RPCGateway) popListener(lisID uint16) (net.Listener, error) { + lisIfc, err := r.lm.pop(lisID) + if err != nil { + return nil, errors.Wrap(err, "no listener") + } + + return assertListener(lisIfc) +} + +// popConn gets conn from the manager by `connID` and removes it. +// Handles type assertion. +func (r *RPCGateway) popConn(connID uint16) (net.Conn, error) { + connIfc, err := r.cm.pop(connID) + if err != nil { + return nil, errors.Wrap(err, "no conn") + } + + return assertConn(connIfc) +} + +// getListener gets listener from the manager by `lisID`. Handles type assertion. +func (r *RPCGateway) getListener(lisID uint16) (net.Listener, error) { + lisIfc, ok := r.lm.get(lisID) + if !ok { + return nil, fmt.Errorf("no listener with key %d", lisID) + } + + return assertListener(lisIfc) +} + +// getConn gets conn from the manager by `connID`. Handles type assertion. +func (r *RPCGateway) getConn(connID uint16) (net.Conn, error) { + connIfc, ok := r.cm.get(connID) + if !ok { + return nil, fmt.Errorf("no conn with key %d", connID) + } + + return assertConn(connIfc) +} diff --git a/pkg/app2/rpc_gateway_test.go b/pkg/app2/rpc_gateway_test.go new file mode 100644 index 000000000..637100a74 --- /dev/null +++ b/pkg/app2/rpc_gateway_test.go @@ -0,0 +1,564 @@ +package app2 + +import ( + "context" + "math" + "net" + "strings" + "testing" + + "github.com/pkg/errors" + "github.com/skycoin/dmsg" + "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/skycoin/src/util/logging" + "github.com/stretchr/testify/require" + + "github.com/skycoin/skywire/pkg/app2/network" + "github.com/skycoin/skywire/pkg/routing" +) + +func TestRPCGateway_Dial(t *testing.T) { + l := logging.MustGetLogger("rpc_gateway") + nType := network.TypeDMSG + + dialAddr := prepAddr(nType) + + t.Run("ok", func(t *testing.T) { + network.ClearNetworkers() + + localPort := routing.Port(100) + + dialCtx := context.Background() + dialConn := dmsg.NewTransport(nil, nil, dmsg.Addr{Port: uint16(localPort)}, dmsg.Addr{}, 0, func() {}) + var dialErr error + + n := &network.MockNetworker{} + n.On("DialContext", dialCtx, dialAddr).Return(dialConn, dialErr) + + err := network.AddNetworker(nType, n) + require.NoError(t, err) + + rpc := newRPCGateway(l) + + var resp DialResp + err = rpc.Dial(&dialAddr, &resp) + require.NoError(t, err) + require.Equal(t, resp.ConnID, uint16(1)) + require.Equal(t, resp.LocalPort, localPort) + }) + + t.Run("no more slots for a new conn", func(t *testing.T) { + rpc := newRPCGateway(l) + for i := uint16(0); i < math.MaxUint16; i++ { + rpc.cm.values[i] = nil + } + rpc.cm.values[math.MaxUint16] = nil + + var resp DialResp + err := rpc.Dial(&dialAddr, &resp) + require.Equal(t, err, errNoMoreAvailableValues) + }) + + t.Run("dial error", func(t *testing.T) { + network.ClearNetworkers() + + dialCtx := context.Background() + var dialConn net.Conn + dialErr := errors.New("dial error") + + n := &network.MockNetworker{} + n.On("DialContext", dialCtx, dialAddr).Return(dialConn, dialErr) + + err := network.AddNetworker(nType, n) + require.NoError(t, err) + + rpc := newRPCGateway(l) + + var resp DialResp + err = rpc.Dial(&dialAddr, &resp) + require.Equal(t, err, dialErr) + }) + + t.Run("error wrapping conn", func(t *testing.T) { + network.ClearNetworkers() + + dialCtx := context.Background() + dialConn := &MockConn{} + dialConn.On("LocalAddr").Return(routing.Addr{}) + dialConn.On("RemoteAddr").Return(routing.Addr{}) + var dialErr error + + n := &network.MockNetworker{} + n.On("DialContext", dialCtx, dialAddr).Return(dialConn, dialErr) + + err := network.AddNetworker(nType, n) + require.NoError(t, err) + + rpc := newRPCGateway(l) + + var resp DialResp + err = rpc.Dial(&dialAddr, &resp) + require.Equal(t, err, network.ErrUnknownAddrType) + }) +} + +func TestRPCGateway_Listen(t *testing.T) { + l := logging.MustGetLogger("rpc_gateway") + nType := network.TypeDMSG + + listenAddr := prepAddr(nType) + + t.Run("ok", func(t *testing.T) { + network.ClearNetworkers() + + listenCtx := context.Background() + listenLis := &dmsg.Listener{} + var listenErr error + + n := &network.MockNetworker{} + n.On("ListenContext", listenCtx, listenAddr).Return(listenLis, listenErr) + + err := network.AddNetworker(nType, n) + require.Equal(t, err, listenErr) + + rpc := newRPCGateway(l) + + var lisID uint16 + + err = rpc.Listen(&listenAddr, &lisID) + require.NoError(t, err) + require.Equal(t, lisID, uint16(1)) + }) + + t.Run("no more slots for a new listener", func(t *testing.T) { + rpc := newRPCGateway(l) + for i := uint16(0); i < math.MaxUint16; i++ { + rpc.lm.values[i] = nil + } + rpc.lm.values[math.MaxUint16] = nil + + var lisID uint16 + + err := rpc.Listen(&listenAddr, &lisID) + require.Equal(t, err, errNoMoreAvailableValues) + }) + + t.Run("listen error", func(t *testing.T) { + network.ClearNetworkers() + + listenCtx := context.Background() + var listenLis net.Listener + listenErr := errors.New("listen error") + + n := &network.MockNetworker{} + n.On("ListenContext", listenCtx, listenAddr).Return(listenLis, listenErr) + + err := network.AddNetworker(nType, n) + require.NoError(t, err) + + rpc := newRPCGateway(l) + + var lisID uint16 + + err = rpc.Listen(&listenAddr, &lisID) + require.Equal(t, err, listenErr) + }) +} + +func TestRPCGateway_Accept(t *testing.T) { + l := logging.MustGetLogger("rpc_gateway") + + t.Run("ok", func(t *testing.T) { + rpc := newRPCGateway(l) + + acceptConn := &dmsg.Transport{} + var acceptErr error + + lis := &MockListener{} + lis.On("Accept").Return(acceptConn, acceptErr) + + lisID := addListener(t, rpc, lis) + + var resp AcceptResp + err := rpc.Accept(&lisID, &resp) + require.NoError(t, err) + require.Equal(t, resp.Remote, network.Addr{Net: network.TypeDMSG}) + }) + + t.Run("no such listener", func(t *testing.T) { + rpc := newRPCGateway(l) + + lisID := uint16(1) + + var resp AcceptResp + err := rpc.Accept(&lisID, &resp) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "no listener")) + }) + + t.Run("listener is not set", func(t *testing.T) { + rpc := newRPCGateway(l) + + lisID := addListener(t, rpc, nil) + + var resp AcceptResp + err := rpc.Accept(&lisID, &resp) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "no listener")) + }) + + t.Run("no more slots for a new conn", func(t *testing.T) { + rpc := newRPCGateway(l) + for i := uint16(0); i < math.MaxUint16; i++ { + rpc.cm.values[i] = nil + } + rpc.cm.values[math.MaxUint16] = nil + + lisID := addListener(t, rpc, &MockListener{}) + + var resp AcceptResp + err := rpc.Accept(&lisID, &resp) + require.Equal(t, err, errNoMoreAvailableValues) + }) + + t.Run("error wrapping conn", func(t *testing.T) { + rpc := newRPCGateway(l) + + acceptConn := &MockConn{} + acceptConn.On("LocalAddr").Return(routing.Addr{}) + acceptConn.On("RemoteAddr").Return(routing.Addr{}) + var acceptErr error + + lis := &MockListener{} + lis.On("Accept").Return(acceptConn, acceptErr) + + lisID := addListener(t, rpc, lis) + + var resp AcceptResp + err := rpc.Accept(&lisID, &resp) + require.Equal(t, err, network.ErrUnknownAddrType) + }) + + t.Run("accept error", func(t *testing.T) { + rpc := newRPCGateway(l) + + var acceptConn net.Conn + acceptErr := errors.New("accept error") + + lis := &MockListener{} + lis.On("Accept").Return(acceptConn, acceptErr) + + lisID := addListener(t, rpc, lis) + + var resp AcceptResp + err := rpc.Accept(&lisID, &resp) + require.Equal(t, err, acceptErr) + }) +} + +func TestRPCGateway_Write(t *testing.T) { + l := logging.MustGetLogger("rpc_gateway") + + writeBuff := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + writeN := 10 + + t.Run("ok", func(t *testing.T) { + rpc := newRPCGateway(l) + + var writeErr error + + conn := &MockConn{} + conn.On("Write", writeBuff).Return(writeN, writeErr) + + connID := addConn(t, rpc, conn) + + req := WriteReq{ + ConnID: connID, + B: writeBuff, + } + + var n int + err := rpc.Write(&req, &n) + require.NoError(t, err) + require.Equal(t, n, writeN) + }) + + t.Run("no such conn", func(t *testing.T) { + rpc := newRPCGateway(l) + + connID := uint16(1) + + req := WriteReq{ + ConnID: connID, + B: writeBuff, + } + + var n int + err := rpc.Write(&req, &n) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "no conn")) + }) + + t.Run("conn is not set", func(t *testing.T) { + rpc := newRPCGateway(l) + + connID := addConn(t, rpc, nil) + + req := WriteReq{ + ConnID: connID, + B: writeBuff, + } + + var n int + err := rpc.Write(&req, &n) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "no conn")) + }) + + t.Run("write error", func(t *testing.T) { + rpc := newRPCGateway(l) + + writeErr := errors.New("write error") + + conn := &MockConn{} + conn.On("Write", writeBuff).Return(writeN, writeErr) + + connID := addConn(t, rpc, conn) + + req := WriteReq{ + ConnID: connID, + B: writeBuff, + } + + var n int + err := rpc.Write(&req, &n) + require.Error(t, err) + require.Equal(t, err, writeErr) + }) +} + +func TestRPCGateway_Read(t *testing.T) { + l := logging.MustGetLogger("rpc_gateway") + + readBufLen := 10 + readBuf := make([]byte, readBufLen) + + t.Run("ok", func(t *testing.T) { + rpc := newRPCGateway(l) + + readN := 10 + var readErr error + + conn := &MockConn{} + conn.On("Read", readBuf).Return(readN, readErr) + + connID := addConn(t, rpc, conn) + + req := ReadReq{ + ConnID: connID, + BufLen: readBufLen, + } + + wantResp := ReadResp{ + B: readBuf, + N: readN, + } + + var resp ReadResp + err := rpc.Read(&req, &resp) + require.NoError(t, err) + require.Equal(t, resp, wantResp) + }) + + t.Run("no such conn", func(t *testing.T) { + rpc := newRPCGateway(l) + + connID := uint16(1) + + req := ReadReq{ + ConnID: connID, + BufLen: readBufLen, + } + + var resp ReadResp + err := rpc.Read(&req, &resp) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "no conn")) + }) + + t.Run("conn is not set", func(t *testing.T) { + rpc := newRPCGateway(l) + + connID := addConn(t, rpc, nil) + + req := ReadReq{ + ConnID: connID, + BufLen: readBufLen, + } + + var resp ReadResp + err := rpc.Read(&req, &resp) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "no conn")) + }) + + t.Run("read error", func(t *testing.T) { + rpc := newRPCGateway(l) + + readN := 0 + readErr := errors.New("read error") + + conn := &MockConn{} + conn.On("Read", readBuf).Return(readN, readErr) + + connID := addConn(t, rpc, conn) + + req := ReadReq{ + ConnID: connID, + BufLen: readBufLen, + } + + var resp ReadResp + err := rpc.Read(&req, &resp) + require.Equal(t, err, readErr) + }) +} + +func TestRPCGateway_CloseConn(t *testing.T) { + l := logging.MustGetLogger("rpc_gateway") + + t.Run("ok", func(t *testing.T) { + rpc := newRPCGateway(l) + + var closeErr error + + conn := &MockConn{} + conn.On("Close").Return(closeErr) + + connID := addConn(t, rpc, conn) + + err := rpc.CloseConn(&connID, nil) + require.NoError(t, err) + _, ok := rpc.cm.values[connID] + require.False(t, ok) + }) + + t.Run("no such conn", func(t *testing.T) { + rpc := newRPCGateway(l) + + connID := uint16(1) + + err := rpc.CloseConn(&connID, nil) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "no conn")) + }) + + t.Run("conn is not set", func(t *testing.T) { + rpc := newRPCGateway(l) + + connID := addConn(t, rpc, nil) + + err := rpc.CloseConn(&connID, nil) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "no conn")) + }) + + t.Run("close error", func(t *testing.T) { + rpc := newRPCGateway(l) + + closeErr := errors.New("close error") + + conn := &MockConn{} + conn.On("Close").Return(closeErr) + + connID := addConn(t, rpc, conn) + + err := rpc.CloseConn(&connID, nil) + require.Equal(t, err, closeErr) + }) +} + +func TestRPCGateway_CloseListener(t *testing.T) { + l := logging.MustGetLogger("rpc_gateway") + + t.Run("ok", func(t *testing.T) { + rpc := newRPCGateway(l) + + var closeErr error + + lis := &MockListener{} + lis.On("Close").Return(closeErr) + + lisID := addListener(t, rpc, lis) + + err := rpc.CloseListener(&lisID, nil) + require.NoError(t, err) + _, ok := rpc.lm.values[lisID] + require.False(t, ok) + }) + + t.Run("no such listener", func(t *testing.T) { + rpc := newRPCGateway(l) + + lisID := uint16(1) + + err := rpc.CloseListener(&lisID, nil) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "no listener")) + }) + + t.Run("listener is not set", func(t *testing.T) { + rpc := newRPCGateway(l) + + lisID := addListener(t, rpc, nil) + + err := rpc.CloseListener(&lisID, nil) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "no listener")) + }) + + t.Run("close error", func(t *testing.T) { + rpc := newRPCGateway(l) + + closeErr := errors.New("close error") + + lis := &MockListener{} + lis.On("Close").Return(closeErr) + + lisID := addListener(t, rpc, lis) + + err := rpc.CloseListener(&lisID, nil) + require.Equal(t, err, closeErr) + }) +} + +func prepAddr(nType network.Type) network.Addr { + pk, _ := cipher.GenerateKeyPair() + port := routing.Port(100) + + return network.Addr{ + Net: nType, + PubKey: pk, + Port: port, + } +} + +func addConn(t *testing.T, rpc *RPCGateway, conn net.Conn) uint16 { + connID, _, err := rpc.cm.reserveNextID() + require.NoError(t, err) + + err = rpc.cm.set(*connID, conn) + require.NoError(t, err) + + return *connID +} + +func addListener(t *testing.T, rpc *RPCGateway, lis net.Listener) uint16 { + lisID, _, err := rpc.lm.reserveNextID() + require.NoError(t, err) + + err = rpc.lm.set(*lisID, lis) + require.NoError(t, err) + + return *lisID +} diff --git a/vendor/github.com/pkg/errors/.gitignore b/vendor/github.com/pkg/errors/.gitignore new file mode 100644 index 000000000..daf913b1b --- /dev/null +++ b/vendor/github.com/pkg/errors/.gitignore @@ -0,0 +1,24 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test +*.prof diff --git a/vendor/github.com/pkg/errors/.travis.yml b/vendor/github.com/pkg/errors/.travis.yml new file mode 100644 index 000000000..d4b92663b --- /dev/null +++ b/vendor/github.com/pkg/errors/.travis.yml @@ -0,0 +1,15 @@ +language: go +go_import_path: github.com/pkg/errors +go: + - 1.4.x + - 1.5.x + - 1.6.x + - 1.7.x + - 1.8.x + - 1.9.x + - 1.10.x + - 1.11.x + - tip + +script: + - go test -v ./... diff --git a/vendor/github.com/pkg/errors/LICENSE b/vendor/github.com/pkg/errors/LICENSE new file mode 100644 index 000000000..835ba3e75 --- /dev/null +++ b/vendor/github.com/pkg/errors/LICENSE @@ -0,0 +1,23 @@ +Copyright (c) 2015, Dave Cheney +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/pkg/errors/README.md b/vendor/github.com/pkg/errors/README.md new file mode 100644 index 000000000..6483ba2af --- /dev/null +++ b/vendor/github.com/pkg/errors/README.md @@ -0,0 +1,52 @@ +# errors [![Travis-CI](https://travis-ci.org/pkg/errors.svg)](https://travis-ci.org/pkg/errors) [![AppVeyor](https://ci.appveyor.com/api/projects/status/b98mptawhudj53ep/branch/master?svg=true)](https://ci.appveyor.com/project/davecheney/errors/branch/master) [![GoDoc](https://godoc.org/github.com/pkg/errors?status.svg)](http://godoc.org/github.com/pkg/errors) [![Report card](https://goreportcard.com/badge/github.com/pkg/errors)](https://goreportcard.com/report/github.com/pkg/errors) [![Sourcegraph](https://sourcegraph.com/github.com/pkg/errors/-/badge.svg)](https://sourcegraph.com/github.com/pkg/errors?badge) + +Package errors provides simple error handling primitives. + +`go get github.com/pkg/errors` + +The traditional error handling idiom in Go is roughly akin to +```go +if err != nil { + return err +} +``` +which applied recursively up the call stack results in error reports without context or debugging information. The errors package allows programmers to add context to the failure path in their code in a way that does not destroy the original value of the error. + +## Adding context to an error + +The errors.Wrap function returns a new error that adds context to the original error. For example +```go +_, err := ioutil.ReadAll(r) +if err != nil { + return errors.Wrap(err, "read failed") +} +``` +## Retrieving the cause of an error + +Using `errors.Wrap` constructs a stack of errors, adding context to the preceding error. Depending on the nature of the error it may be necessary to reverse the operation of errors.Wrap to retrieve the original error for inspection. Any error value which implements this interface can be inspected by `errors.Cause`. +```go +type causer interface { + Cause() error +} +``` +`errors.Cause` will recursively retrieve the topmost error which does not implement `causer`, which is assumed to be the original cause. For example: +```go +switch err := errors.Cause(err).(type) { +case *MyError: + // handle specifically +default: + // unknown error +} +``` + +[Read the package documentation for more information](https://godoc.org/github.com/pkg/errors). + +## Contributing + +We welcome pull requests, bug fixes and issue reports. With that said, the bar for adding new symbols to this package is intentionally set high. + +Before proposing a change, please discuss your change by raising an issue. + +## License + +BSD-2-Clause diff --git a/vendor/github.com/pkg/errors/appveyor.yml b/vendor/github.com/pkg/errors/appveyor.yml new file mode 100644 index 000000000..a932eade0 --- /dev/null +++ b/vendor/github.com/pkg/errors/appveyor.yml @@ -0,0 +1,32 @@ +version: build-{build}.{branch} + +clone_folder: C:\gopath\src\github.com\pkg\errors +shallow_clone: true # for startup speed + +environment: + GOPATH: C:\gopath + +platform: + - x64 + +# http://www.appveyor.com/docs/installed-software +install: + # some helpful output for debugging builds + - go version + - go env + # pre-installed MinGW at C:\MinGW is 32bit only + # but MSYS2 at C:\msys64 has mingw64 + - set PATH=C:\msys64\mingw64\bin;%PATH% + - gcc --version + - g++ --version + +build_script: + - go install -v ./... + +test_script: + - set PATH=C:\gopath\bin;%PATH% + - go test -v ./... + +#artifacts: +# - path: '%GOPATH%\bin\*.exe' +deploy: off diff --git a/vendor/github.com/pkg/errors/errors.go b/vendor/github.com/pkg/errors/errors.go new file mode 100644 index 000000000..7421f326f --- /dev/null +++ b/vendor/github.com/pkg/errors/errors.go @@ -0,0 +1,282 @@ +// Package errors provides simple error handling primitives. +// +// The traditional error handling idiom in Go is roughly akin to +// +// if err != nil { +// return err +// } +// +// which when applied recursively up the call stack results in error reports +// without context or debugging information. The errors package allows +// programmers to add context to the failure path in their code in a way +// that does not destroy the original value of the error. +// +// Adding context to an error +// +// The errors.Wrap function returns a new error that adds context to the +// original error by recording a stack trace at the point Wrap is called, +// together with the supplied message. For example +// +// _, err := ioutil.ReadAll(r) +// if err != nil { +// return errors.Wrap(err, "read failed") +// } +// +// If additional control is required, the errors.WithStack and +// errors.WithMessage functions destructure errors.Wrap into its component +// operations: annotating an error with a stack trace and with a message, +// respectively. +// +// Retrieving the cause of an error +// +// Using errors.Wrap constructs a stack of errors, adding context to the +// preceding error. Depending on the nature of the error it may be necessary +// to reverse the operation of errors.Wrap to retrieve the original error +// for inspection. Any error value which implements this interface +// +// type causer interface { +// Cause() error +// } +// +// can be inspected by errors.Cause. errors.Cause will recursively retrieve +// the topmost error that does not implement causer, which is assumed to be +// the original cause. For example: +// +// switch err := errors.Cause(err).(type) { +// case *MyError: +// // handle specifically +// default: +// // unknown error +// } +// +// Although the causer interface is not exported by this package, it is +// considered a part of its stable public interface. +// +// Formatted printing of errors +// +// All error values returned from this package implement fmt.Formatter and can +// be formatted by the fmt package. The following verbs are supported: +// +// %s print the error. If the error has a Cause it will be +// printed recursively. +// %v see %s +// %+v extended format. Each Frame of the error's StackTrace will +// be printed in detail. +// +// Retrieving the stack trace of an error or wrapper +// +// New, Errorf, Wrap, and Wrapf record a stack trace at the point they are +// invoked. This information can be retrieved with the following interface: +// +// type stackTracer interface { +// StackTrace() errors.StackTrace +// } +// +// The returned errors.StackTrace type is defined as +// +// type StackTrace []Frame +// +// The Frame type represents a call site in the stack trace. Frame supports +// the fmt.Formatter interface that can be used for printing information about +// the stack trace of this error. For example: +// +// if err, ok := err.(stackTracer); ok { +// for _, f := range err.StackTrace() { +// fmt.Printf("%+s:%d", f) +// } +// } +// +// Although the stackTracer interface is not exported by this package, it is +// considered a part of its stable public interface. +// +// See the documentation for Frame.Format for more details. +package errors + +import ( + "fmt" + "io" +) + +// New returns an error with the supplied message. +// New also records the stack trace at the point it was called. +func New(message string) error { + return &fundamental{ + msg: message, + stack: callers(), + } +} + +// Errorf formats according to a format specifier and returns the string +// as a value that satisfies error. +// Errorf also records the stack trace at the point it was called. +func Errorf(format string, args ...interface{}) error { + return &fundamental{ + msg: fmt.Sprintf(format, args...), + stack: callers(), + } +} + +// fundamental is an error that has a message and a stack, but no caller. +type fundamental struct { + msg string + *stack +} + +func (f *fundamental) Error() string { return f.msg } + +func (f *fundamental) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + io.WriteString(s, f.msg) + f.stack.Format(s, verb) + return + } + fallthrough + case 's': + io.WriteString(s, f.msg) + case 'q': + fmt.Fprintf(s, "%q", f.msg) + } +} + +// WithStack annotates err with a stack trace at the point WithStack was called. +// If err is nil, WithStack returns nil. +func WithStack(err error) error { + if err == nil { + return nil + } + return &withStack{ + err, + callers(), + } +} + +type withStack struct { + error + *stack +} + +func (w *withStack) Cause() error { return w.error } + +func (w *withStack) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + fmt.Fprintf(s, "%+v", w.Cause()) + w.stack.Format(s, verb) + return + } + fallthrough + case 's': + io.WriteString(s, w.Error()) + case 'q': + fmt.Fprintf(s, "%q", w.Error()) + } +} + +// Wrap returns an error annotating err with a stack trace +// at the point Wrap is called, and the supplied message. +// If err is nil, Wrap returns nil. +func Wrap(err error, message string) error { + if err == nil { + return nil + } + err = &withMessage{ + cause: err, + msg: message, + } + return &withStack{ + err, + callers(), + } +} + +// Wrapf returns an error annotating err with a stack trace +// at the point Wrapf is called, and the format specifier. +// If err is nil, Wrapf returns nil. +func Wrapf(err error, format string, args ...interface{}) error { + if err == nil { + return nil + } + err = &withMessage{ + cause: err, + msg: fmt.Sprintf(format, args...), + } + return &withStack{ + err, + callers(), + } +} + +// WithMessage annotates err with a new message. +// If err is nil, WithMessage returns nil. +func WithMessage(err error, message string) error { + if err == nil { + return nil + } + return &withMessage{ + cause: err, + msg: message, + } +} + +// WithMessagef annotates err with the format specifier. +// If err is nil, WithMessagef returns nil. +func WithMessagef(err error, format string, args ...interface{}) error { + if err == nil { + return nil + } + return &withMessage{ + cause: err, + msg: fmt.Sprintf(format, args...), + } +} + +type withMessage struct { + cause error + msg string +} + +func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() } +func (w *withMessage) Cause() error { return w.cause } + +func (w *withMessage) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + fmt.Fprintf(s, "%+v\n", w.Cause()) + io.WriteString(s, w.msg) + return + } + fallthrough + case 's', 'q': + io.WriteString(s, w.Error()) + } +} + +// Cause returns the underlying cause of the error, if possible. +// An error value has a cause if it implements the following +// interface: +// +// type causer interface { +// Cause() error +// } +// +// If the error does not implement Cause, the original error will +// be returned. If the error is nil, nil will be returned without further +// investigation. +func Cause(err error) error { + type causer interface { + Cause() error + } + + for err != nil { + cause, ok := err.(causer) + if !ok { + break + } + err = cause.Cause() + } + return err +} diff --git a/vendor/github.com/pkg/errors/stack.go b/vendor/github.com/pkg/errors/stack.go new file mode 100644 index 000000000..2874a048c --- /dev/null +++ b/vendor/github.com/pkg/errors/stack.go @@ -0,0 +1,147 @@ +package errors + +import ( + "fmt" + "io" + "path" + "runtime" + "strings" +) + +// Frame represents a program counter inside a stack frame. +type Frame uintptr + +// pc returns the program counter for this frame; +// multiple frames may have the same PC value. +func (f Frame) pc() uintptr { return uintptr(f) - 1 } + +// file returns the full path to the file that contains the +// function for this Frame's pc. +func (f Frame) file() string { + fn := runtime.FuncForPC(f.pc()) + if fn == nil { + return "unknown" + } + file, _ := fn.FileLine(f.pc()) + return file +} + +// line returns the line number of source code of the +// function for this Frame's pc. +func (f Frame) line() int { + fn := runtime.FuncForPC(f.pc()) + if fn == nil { + return 0 + } + _, line := fn.FileLine(f.pc()) + return line +} + +// Format formats the frame according to the fmt.Formatter interface. +// +// %s source file +// %d source line +// %n function name +// %v equivalent to %s:%d +// +// Format accepts flags that alter the printing of some verbs, as follows: +// +// %+s function name and path of source file relative to the compile time +// GOPATH separated by \n\t (\n\t) +// %+v equivalent to %+s:%d +func (f Frame) Format(s fmt.State, verb rune) { + switch verb { + case 's': + switch { + case s.Flag('+'): + pc := f.pc() + fn := runtime.FuncForPC(pc) + if fn == nil { + io.WriteString(s, "unknown") + } else { + file, _ := fn.FileLine(pc) + fmt.Fprintf(s, "%s\n\t%s", fn.Name(), file) + } + default: + io.WriteString(s, path.Base(f.file())) + } + case 'd': + fmt.Fprintf(s, "%d", f.line()) + case 'n': + name := runtime.FuncForPC(f.pc()).Name() + io.WriteString(s, funcname(name)) + case 'v': + f.Format(s, 's') + io.WriteString(s, ":") + f.Format(s, 'd') + } +} + +// StackTrace is stack of Frames from innermost (newest) to outermost (oldest). +type StackTrace []Frame + +// Format formats the stack of Frames according to the fmt.Formatter interface. +// +// %s lists source files for each Frame in the stack +// %v lists the source file and line number for each Frame in the stack +// +// Format accepts flags that alter the printing of some verbs, as follows: +// +// %+v Prints filename, function, and line number for each Frame in the stack. +func (st StackTrace) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + switch { + case s.Flag('+'): + for _, f := range st { + fmt.Fprintf(s, "\n%+v", f) + } + case s.Flag('#'): + fmt.Fprintf(s, "%#v", []Frame(st)) + default: + fmt.Fprintf(s, "%v", []Frame(st)) + } + case 's': + fmt.Fprintf(s, "%s", []Frame(st)) + } +} + +// stack represents a stack of program counters. +type stack []uintptr + +func (s *stack) Format(st fmt.State, verb rune) { + switch verb { + case 'v': + switch { + case st.Flag('+'): + for _, pc := range *s { + f := Frame(pc) + fmt.Fprintf(st, "\n%+v", f) + } + } + } +} + +func (s *stack) StackTrace() StackTrace { + f := make([]Frame, len(*s)) + for i := 0; i < len(f); i++ { + f[i] = Frame((*s)[i]) + } + return f +} + +func callers() *stack { + const depth = 32 + var pcs [depth]uintptr + n := runtime.Callers(3, pcs[:]) + var st stack = pcs[0:n] + return &st +} + +// funcname removes the path prefix component of a function's name reported by func.Name(). +func funcname(name string) string { + i := strings.LastIndex(name, "/") + name = name[i+1:] + i = strings.Index(name, ".") + return name[i+1:] +} diff --git a/vendor/github.com/prometheus/common/expfmt/text_create.go b/vendor/github.com/prometheus/common/expfmt/text_create.go index 8e473d0fe..0327865ee 100644 --- a/vendor/github.com/prometheus/common/expfmt/text_create.go +++ b/vendor/github.com/prometheus/common/expfmt/text_create.go @@ -14,9 +14,10 @@ package expfmt import ( - "bytes" + "bufio" "fmt" "io" + "io/ioutil" "math" "strconv" "strings" @@ -27,7 +28,7 @@ import ( dto "github.com/prometheus/client_model/go" ) -// enhancedWriter has all the enhanced write functions needed here. bytes.Buffer +// enhancedWriter has all the enhanced write functions needed here. bufio.Writer // implements it. type enhancedWriter interface { io.Writer @@ -37,14 +38,13 @@ type enhancedWriter interface { } const ( - initialBufSize = 512 initialNumBufSize = 24 ) var ( bufPool = sync.Pool{ New: func() interface{} { - return bytes.NewBuffer(make([]byte, 0, initialBufSize)) + return bufio.NewWriter(ioutil.Discard) }, } numBufPool = sync.Pool{ @@ -75,16 +75,14 @@ func MetricFamilyToText(out io.Writer, in *dto.MetricFamily) (written int, err e } // Try the interface upgrade. If it doesn't work, we'll use a - // bytes.Buffer from the sync.Pool and write out its content to out in a - // single go in the end. + // bufio.Writer from the sync.Pool. w, ok := out.(enhancedWriter) if !ok { - b := bufPool.Get().(*bytes.Buffer) - b.Reset() + b := bufPool.Get().(*bufio.Writer) + b.Reset(out) w = b defer func() { - bWritten, bErr := out.Write(b.Bytes()) - written = bWritten + bErr := b.Flush() if err == nil { err = bErr } diff --git a/vendor/github.com/prometheus/common/expfmt/text_parse.go b/vendor/github.com/prometheus/common/expfmt/text_parse.go index ec3d86ba7..342e5940d 100644 --- a/vendor/github.com/prometheus/common/expfmt/text_parse.go +++ b/vendor/github.com/prometheus/common/expfmt/text_parse.go @@ -325,7 +325,7 @@ func (p *TextParser) startLabelValue() stateFn { // - Other labels have to be added to currentLabels for signature calculation. if p.currentMF.GetType() == dto.MetricType_SUMMARY { if p.currentLabelPair.GetName() == model.QuantileLabel { - if p.currentQuantile, p.err = strconv.ParseFloat(p.currentLabelPair.GetValue(), 64); p.err != nil { + if p.currentQuantile, p.err = parseFloat(p.currentLabelPair.GetValue()); p.err != nil { // Create a more helpful error message. p.parseError(fmt.Sprintf("expected float as value for 'quantile' label, got %q", p.currentLabelPair.GetValue())) return nil @@ -337,7 +337,7 @@ func (p *TextParser) startLabelValue() stateFn { // Similar special treatment of histograms. if p.currentMF.GetType() == dto.MetricType_HISTOGRAM { if p.currentLabelPair.GetName() == model.BucketLabel { - if p.currentBucket, p.err = strconv.ParseFloat(p.currentLabelPair.GetValue(), 64); p.err != nil { + if p.currentBucket, p.err = parseFloat(p.currentLabelPair.GetValue()); p.err != nil { // Create a more helpful error message. p.parseError(fmt.Sprintf("expected float as value for 'le' label, got %q", p.currentLabelPair.GetValue())) return nil @@ -392,7 +392,7 @@ func (p *TextParser) readingValue() stateFn { if p.readTokenUntilWhitespace(); p.err != nil { return nil // Unexpected end of input. } - value, err := strconv.ParseFloat(p.currentToken.String(), 64) + value, err := parseFloat(p.currentToken.String()) if err != nil { // Create a more helpful error message. p.parseError(fmt.Sprintf("expected float as value, got %q", p.currentToken.String())) @@ -755,3 +755,10 @@ func histogramMetricName(name string) string { return name } } + +func parseFloat(s string) (float64, error) { + if strings.ContainsAny(s, "pP_") { + return 0, fmt.Errorf("unsupported character in float") + } + return strconv.ParseFloat(s, 64) +} diff --git a/vendor/github.com/skycoin/dmsg/addr.go b/vendor/github.com/skycoin/dmsg/addr.go deleted file mode 100644 index 2be739b40..000000000 --- a/vendor/github.com/skycoin/dmsg/addr.go +++ /dev/null @@ -1,26 +0,0 @@ -package dmsg - -import ( - "fmt" - - "github.com/skycoin/dmsg/cipher" -) - -// Addr implements net.Addr for skywire addresses. -type Addr struct { - PK cipher.PubKey - Port uint16 -} - -// Network returns "dmsg" -func (Addr) Network() string { - return Type -} - -// String returns public key and port of node split by colon. -func (a Addr) String() string { - if a.Port == 0 { - return fmt.Sprintf("%s:~", a.PK) - } - return fmt.Sprintf("%s:%d", a.PK, a.Port) -} diff --git a/vendor/github.com/skycoin/dmsg/client.go b/vendor/github.com/skycoin/dmsg/client.go index 587d03bcb..f09ffefde 100644 --- a/vendor/github.com/skycoin/dmsg/client.go +++ b/vendor/github.com/skycoin/dmsg/client.go @@ -57,7 +57,6 @@ type Client struct { pm *PortManager - // accept map[uint16]chan *transport done chan struct{} once sync.Once } @@ -70,10 +69,8 @@ func NewClient(pk cipher.PubKey, sk cipher.SecKey, dc disc.APIClient, opts ...Cl sk: sk, dc: dc, conns: make(map[cipher.PubKey]*ClientConn), - pm: newPortManager(), - // accept: make(chan *transport, AcceptBufferSize), - // accept: make(map[uint16]chan *transport), - done: make(chan struct{}), + pm: newPortManager(pk), + done: make(chan struct{}), } for _, opt := range opts { if err := opt(c); err != nil { @@ -103,7 +100,7 @@ func (c *Client) updateDiscEntry(ctx context.Context) error { func (c *Client) setConn(ctx context.Context, conn *ClientConn) { c.mx.Lock() - c.conns[conn.remoteSrv] = conn + c.conns[conn.srvPK] = conn if err := c.updateDiscEntry(ctx); err != nil { c.log.WithError(err).Warn("updateEntry: failed") } @@ -142,7 +139,7 @@ func (c *Client) InitiateServerConnections(ctx context.Context, min int) error { if err != nil { return err } - c.log.Info("found dms_server entries:", entries) + c.log.Info("found dmsg.Server entries:", entries) if err := c.findOrConnectToServers(ctx, entries, min); err != nil { return err } @@ -213,7 +210,7 @@ func (c *Client) findOrConnectToServer(ctx context.Context, srvPK cipher.PubKey) return nil, err } - conn := NewClientConn(c.log, nc, c.pk, srvPK, c.pm) + conn := NewClientConn(c.log, c.pm, nc, c.pk, srvPK) if err := conn.readOK(); err != nil { return nil, err } @@ -244,7 +241,7 @@ func (c *Client) findOrConnectToServer(ctx context.Context, srvPK cipher.PubKey) // Listen creates a listener on a given port, adds it to port manager and returns the listener. func (c *Client) Listen(port uint16) (*Listener, error) { - l, ok := c.pm.NewListener(c.pk, port) + l, ok := c.pm.NewListener(port) if !ok { return nil, errors.New("port is busy") } @@ -288,7 +285,7 @@ func (c *Client) Type() string { // Close closes the dms_client and associated connections. // TODO(evaninjin): proper error handling. -func (c *Client) Close() error { +func (c *Client) Close() (err error) { if c == nil { return nil } @@ -305,13 +302,8 @@ func (c *Client) Close() error { c.conns = make(map[cipher.PubKey]*ClientConn) c.mx.Unlock() - c.pm.mu.Lock() - defer c.pm.mu.Unlock() - - for _, lis := range c.pm.listeners { - lis.close() - } + err = c.pm.Close() }) - return nil + return err } diff --git a/vendor/github.com/skycoin/dmsg/client_conn.go b/vendor/github.com/skycoin/dmsg/client_conn.go index 9ee1895af..be48e6adb 100644 --- a/vendor/github.com/skycoin/dmsg/client_conn.go +++ b/vendor/github.com/skycoin/dmsg/client_conn.go @@ -2,7 +2,6 @@ package dmsg import ( "context" - "encoding/json" "errors" "fmt" "net" @@ -18,9 +17,9 @@ import ( type ClientConn struct { log *logging.Logger - net.Conn // conn to dmsg server - local cipher.PubKey // local client's pk - remoteSrv cipher.PubKey // dmsg server's public key + net.Conn // conn to dmsg server + lPK cipher.PubKey // local client's pk + srvPK cipher.PubKey // dmsg server's public key // nextInitID keeps track of unused tp_ids to assign a future locally-initiated tp. // locally-initiated tps use an even tp_id between local and intermediary dms_server. @@ -38,12 +37,12 @@ type ClientConn struct { } // NewClientConn creates a new ClientConn. -func NewClientConn(log *logging.Logger, conn net.Conn, local, remote cipher.PubKey, pm *PortManager) *ClientConn { +func NewClientConn(log *logging.Logger, pm *PortManager, conn net.Conn, lPK, rPK cipher.PubKey) *ClientConn { cc := &ClientConn{ log: log, Conn: conn, - local: local, - remoteSrv: remote, + lPK: lPK, + srvPK: rPK, nextInitID: randID(true), tps: make(map[uint16]*Transport), pm: pm, @@ -54,7 +53,7 @@ func NewClientConn(log *logging.Logger, conn net.Conn, local, remote cipher.PubK } // RemotePK returns the remote Server's PK that the ClientConn is connected to. -func (c *ClientConn) RemotePK() cipher.PubKey { return c.remoteSrv } +func (c *ClientConn) RemotePK() cipher.PubKey { return c.srvPK } func (c *ClientConn) getNextInitID(ctx context.Context) (uint16, error) { for { @@ -76,7 +75,7 @@ func (c *ClientConn) getNextInitID(ctx context.Context) (uint16, error) { } } -func (c *ClientConn) addTp(ctx context.Context, rPK cipher.PubKey, lPort, rPort uint16) (*Transport, error) { +func (c *ClientConn) addTp(ctx context.Context, rPK cipher.PubKey, lPort, rPort uint16, closeCB func()) (*Transport, error) { c.mx.Lock() defer c.mx.Unlock() @@ -84,7 +83,10 @@ func (c *ClientConn) addTp(ctx context.Context, rPK cipher.PubKey, lPort, rPort if err != nil { return nil, err } - tp := NewTransport(c.Conn, c.log, Addr{c.local, lPort}, Addr{rPK, rPort}, id, c.delTp) + tp := NewTransport(c.Conn, c.log, Addr{c.lPK, lPort}, Addr{rPK, rPort}, id, func() { + c.delTp(id) + closeCB() + }) c.tps[id] = tp return tp, nil } @@ -116,72 +118,71 @@ func (c *ClientConn) setNextInitID(nextInitID uint16) { } func (c *ClientConn) readOK() error { - fr, err := readFrame(c.Conn) + _, df, err := readFrame(c.Conn) if err != nil { return errors.New("failed to get OK from server") } - - ft, _, _ := fr.Disassemble() - if ft != OkType { - return fmt.Errorf("wrong frame from server: %v", ft) + if df.Type != OkType { + return fmt.Errorf("wrong frame from server: %v", df.Type) } - return nil } -func (c *ClientConn) handleRequestFrame(id uint16, p []byte) (cipher.PubKey, error) { - // remotely-initiated tps should: - // - have a payload structured as HandshakePayload marshaled to JSON. - // - resp_pk should be of local client. - // - use an odd tp_id with the intermediary dmsg_server. - payload, err := unmarshalHandshakePayload(p) - if err != nil { - // TODO(nkryuchkov): When implementing reasons, send that payload format is incorrect. +// This handles 'REQUEST' frames which represent remotely-initiated tps. 'REQUEST' frames should: +// - have a HandshakePayload marshaled to JSON as payload. +// - have a resp_pk be of local client. +// - have an odd tp_id. +func (c *ClientConn) handleRequestFrame(log *logrus.Entry, id uint16, p []byte) (cipher.PubKey, error) { + + // The public key of the initiating client (or the client that sent the 'REQUEST' frame). + var initPK cipher.PubKey + + // Attempts to close tp due to given error. + // When we fail to close tp (a.k.a fail to send 'CLOSE' frame) or if the local client is closed, + // the connection to server should be closed. + // TODO(evanlinjin): derive close reason from error. + closeTp := func(origErr error) (cipher.PubKey, error) { if err := writeCloseFrame(c.Conn, id, PlaceholderReason); err != nil { - return cipher.PubKey{}, err + log.WithError(err).Warn("handleRequestFrame: failed to close transport: ending conn to server.") + log.WithError(c.Close()).Warn("handleRequestFrame: closing connection to server.") + return initPK, origErr + } + switch origErr { + case ErrClientClosed: + log.WithError(c.Close()).Warn("handleRequestFrame: closing connection to server.") } - return cipher.PubKey{}, ErrRequestCheckFailed + return initPK, origErr } - if payload.RespPK != c.local || isInitiatorID(id) { - // TODO(nkryuchkov): When implementing reasons, send that payload is malformed. - if err := writeCloseFrame(c.Conn, id, PlaceholderReason); err != nil { - return payload.InitPK, err - } - return payload.InitPK, ErrRequestCheckFailed + pay, err := unmarshalHandshakePayload(p) + if err != nil { + return closeTp(ErrRequestCheckFailed) // TODO(nkryuchkov): reason = payload format is incorrect. } + initPK = pay.InitAddr.PK - lis, ok := c.pm.Listener(payload.Port) + if pay.RespAddr.PK != c.lPK || isInitiatorID(id) { + return closeTp(ErrRequestCheckFailed) // TODO(nkryuchkov): reason = payload is malformed. + } + lis, ok := c.pm.Listener(pay.RespAddr.Port) if !ok { - // TODO(nkryuchkov): When implementing reasons, send that port is not listening - if err := writeCloseFrame(c.Conn, id, PlaceholderReason); err != nil { - return payload.InitPK, err - } - return payload.InitPK, ErrPortNotListening + return closeTp(ErrPortNotListening) // TODO(nkryuchkov): reason = port is not listening. + } + if c.isClosed() { + return closeTp(ErrClientClosed) // TODO(nkryuchkov): reason = client is closed. } - tp := NewTransport(c.Conn, c.log, Addr{c.local, payload.Port}, Addr{payload.InitPK, 0}, id, c.delTp) // TODO: Have proper remote port. - - select { - case <-c.done: - if err := tp.Close(); err != nil { - log.WithError(err).Warn("Failed to close transport") - } - return payload.InitPK, ErrClientClosed - - default: - err := lis.IntroduceTransport(tp) - if err == nil || err == ErrClientAcceptMaxed { - c.setTp(tp) - } - return payload.InitPK, err + tp := NewTransport(c.Conn, c.log, pay.RespAddr, pay.InitAddr, id, func() { c.delTp(id) }) + if err := lis.IntroduceTransport(tp); err != nil { + return initPK, err } + c.setTp(tp) + return initPK, nil } // Serve handles incoming frames. // Remote-initiated tps that are successfully created are pushing into 'accept' and exposed via 'Client.Accept()'. func (c *ClientConn) Serve(ctx context.Context) (err error) { - log := c.log.WithField("remoteServer", c.remoteSrv) + log := c.log.WithField("remoteServer", c.srvPK) log.WithField("connCount", incrementServeCount()).Infoln("ServingConn") defer func() { c.close() @@ -190,50 +191,40 @@ func (c *ClientConn) Serve(ctx context.Context) (err error) { }() for { - f, err := readFrame(c.Conn) + f, df, err := readFrame(c.Conn) if err != nil { return fmt.Errorf("read failed: %s", err) } log = log.WithField("received", f) - ft, id, p := f.Disassemble() - // If tp of tp_id exists, attempt to forward frame to tp. - // delete tp on any failure. - - if tp, ok := c.getTp(id); ok { + // Delete tp on any failure. + if tp, ok := c.getTp(df.TpID); ok { if err := tp.HandleFrame(f); err != nil { - log.WithError(err).Warnf("Rejected [%s]: Transport closed.", ft) + log.WithError(err).Warnf("Rejected [%s]: Transport closed.", df.Type) } continue } + c.delTp(df.TpID) // rm tp in case closed tp is not fully removed. // if tp does not exist, frame should be 'REQUEST'. // otherwise, handle any unexpected frames accordingly. - - c.delTp(id) // rm tp in case closed tp is not fully removed. - - switch ft { + switch df.Type { case RequestType: c.wg.Add(1) go func(log *logrus.Entry) { defer c.wg.Done() - initPK, err := c.handleRequestFrame(id, p) - if err != nil { - log.WithField("remoteClient", initPK).WithError(err).Infoln("Rejected [REQUEST]") - if isWriteError(err) || err == ErrClientClosed { - err := c.Close() - log.WithError(err).Warn("ClosingConnection") - } - return + if initPK, err := c.handleRequestFrame(log, df.TpID, df.Pay); err != nil { + log.WithField("remoteClient", initPK).WithError(err).Warn("Rejected [REQUEST]") + } else { + log.WithField("remoteClient", initPK).Info("Accepted [REQUEST]") } - log.WithField("remoteClient", initPK).Infoln("Accepted [REQUEST]") }(log) default: - log.Debugf("Ignored [%s]: No transport of given ID.", ft) - if ft != CloseType { - if err := writeCloseFrame(c.Conn, id, PlaceholderReason); err != nil { + log.Debugf("Ignored [%s]: No transport of given ID.", df.Type) + if df.Type != CloseType { + if err := writeCloseFrame(c.Conn, df.TpID, PlaceholderReason); err != nil { return err } } @@ -242,12 +233,16 @@ func (c *ClientConn) Serve(ctx context.Context) (err error) { } // DialTransport dials a transport to remote dms_client. -func (c *ClientConn) DialTransport(ctx context.Context, clientPK cipher.PubKey, port uint16) (*Transport, error) { - tp, err := c.addTp(ctx, clientPK, 0, port) // TODO: Have proper local port. +func (c *ClientConn) DialTransport(ctx context.Context, rPK cipher.PubKey, rPort uint16) (*Transport, error) { + lPort, closeCB, err := c.pm.ReserveEphemeral(ctx) if err != nil { return nil, err } - if err := tp.WriteRequest(port); err != nil { + tp, err := c.addTp(ctx, rPK, lPort, rPort, closeCB) // TODO: Have proper local port. + if err != nil { + return nil, err + } + if err := tp.WriteRequest(); err != nil { return nil, err } if err := tp.ReadAccept(ctx); err != nil { @@ -263,7 +258,7 @@ func (c *ClientConn) close() (closed bool) { } c.once.Do(func() { closed = true - c.log.WithField("remoteServer", c.remoteSrv).Infoln("ClosingConnection") + c.log.WithField("remoteServer", c.srvPK).Infoln("ClosingConnection") close(c.done) c.mx.Lock() for _, tp := range c.tps { @@ -290,12 +285,11 @@ func (c *ClientConn) Close() error { return nil } -func marshalHandshakePayload(p HandshakePayload) ([]byte, error) { - return json.Marshal(p) -} - -func unmarshalHandshakePayload(b []byte) (HandshakePayload, error) { - var p HandshakePayload - err := json.Unmarshal(b, &p) - return p, err +func (c *ClientConn) isClosed() bool { + select { + case <-c.done: + return true + default: + return false + } } diff --git a/vendor/github.com/skycoin/dmsg/listener.go b/vendor/github.com/skycoin/dmsg/listener.go index 2c685f8f1..3fc6f48a4 100644 --- a/vendor/github.com/skycoin/dmsg/listener.go +++ b/vendor/github.com/skycoin/dmsg/listener.go @@ -1,36 +1,84 @@ package dmsg import ( + "fmt" "net" "sync" - - "github.com/skycoin/dmsg/cipher" ) // Listener listens for remote-initiated transports. type Listener struct { - pk cipher.PubKey - port uint16 - mx sync.Mutex // protects 'accept' + addr Addr // local listening address + accept chan *Transport - done chan struct{} - once sync.Once + mx sync.Mutex // protects 'accept' + + doneFunc func() // callback when done + done chan struct{} + once sync.Once } -func newListener(pk cipher.PubKey, port uint16) *Listener { +func newListener(addr Addr) *Listener { return &Listener{ - pk: pk, - port: port, + addr: addr, accept: make(chan *Transport, AcceptBufferSize), done: make(chan struct{}), } } +// AddCloseCallback adds a function that triggers when listener is closed. +// This should be called right after the listener is created and is not thread safe. +func (l *Listener) AddCloseCallback(cb func()) { l.doneFunc = cb } + +// IntroduceTransport handles a transport after receiving a REQUEST frame. +func (l *Listener) IntroduceTransport(tp *Transport) error { + if tp.LocalAddr() != l.addr { + return fmt.Errorf("failed to accept transport as local addresses does not match: we expected %s but got %s", + l.addr, tp.LocalAddr()) + } + + l.mx.Lock() + defer l.mx.Unlock() + + if l.isClosed() { + return ErrClientClosed + } + + select { + case <-l.done: + return ErrClientClosed + + case l.accept <- tp: + if err := tp.WriteAccept(); err != nil { + return err + } + go tp.Serve() + return nil + + default: + _ = tp.Close() //nolint:errcheck + return ErrClientAcceptMaxed + } +} + // Accept accepts a connection. func (l *Listener) Accept() (net.Conn, error) { return l.AcceptTransport() } +// AcceptTransport accepts a transport connection. +func (l *Listener) AcceptTransport() (*Transport, error) { + select { + case <-l.done: + return nil, ErrClientClosed + case tp, ok := <-l.accept: + if !ok { + return nil, ErrClientClosed + } + return tp, nil + } +} + // Close closes the listener. func (l *Listener) Close() error { if l.close() { @@ -42,6 +90,7 @@ func (l *Listener) Close() error { func (l *Listener) close() (closed bool) { l.once.Do(func() { closed = true + l.doneFunc() l.mx.Lock() defer l.mx.Unlock() @@ -69,55 +118,7 @@ func (l *Listener) isClosed() bool { } // Addr returns the listener's address. -func (l *Listener) Addr() net.Addr { - return Addr{ - PK: l.pk, - Port: l.port, - } -} - -// AcceptTransport accepts a transport connection. -func (l *Listener) AcceptTransport() (*Transport, error) { - select { - case <-l.done: - return nil, ErrClientClosed - case tp, ok := <-l.accept: - if !ok { - return nil, ErrClientClosed - } - return tp, nil - } -} +func (l *Listener) Addr() net.Addr { return l.addr } // Type returns the transport type. -func (l *Listener) Type() string { - return Type -} - -// IntroduceTransport handles a transport after receiving a REQUEST frame. -func (l *Listener) IntroduceTransport(tp *Transport) error { - l.mx.Lock() - defer l.mx.Unlock() - - if l.isClosed() { - return ErrClientClosed - } - - select { - case <-l.done: - return ErrClientClosed - - case l.accept <- tp: - if err := tp.WriteAccept(); err != nil { - return err - } - go tp.Serve() - return nil - - default: - if err := tp.Close(); err != nil { - log.WithError(err).Warn("Failed to close transport") - } - return ErrClientAcceptMaxed - } -} +func (l *Listener) Type() string { return Type } diff --git a/vendor/github.com/skycoin/dmsg/netutil/porter.go b/vendor/github.com/skycoin/dmsg/netutil/porter.go new file mode 100644 index 000000000..fb0d2c1b2 --- /dev/null +++ b/vendor/github.com/skycoin/dmsg/netutil/porter.go @@ -0,0 +1,102 @@ +package netutil + +import ( + "context" + "sync" +) + +const ( + // PorterMinEphemeral is the default minimum ephemeral port. + PorterMinEphemeral = uint16(49152) +) + +// Porter reserves ports. +type Porter struct { + sync.RWMutex + eph uint16 // current ephemeral value + minEph uint16 // minimal ephemeral port value + ports map[uint16]interface{} +} + +// NewPorter creates a new Porter with a given minimum ephemeral port value. +func NewPorter(minEph uint16) *Porter { + ports := make(map[uint16]interface{}) + ports[0] = struct{}{} // port 0 is invalid + + return &Porter{ + eph: minEph, + minEph: minEph, + ports: ports, + } +} + +// Reserve a given port. +// It returns a boolean informing whether the port is reserved, and a function to clear the reservation. +func (p *Porter) Reserve(port uint16, v interface{}) (bool, func()) { + p.Lock() + defer p.Unlock() + + if _, ok := p.ports[port]; ok { + return false, nil + } + p.ports[port] = v + return true, p.makePortFreer(port) +} + +// ReserveEphemeral reserves a new ephemeral port. +// It returns the reserved ephemeral port, a function to clear the reservation and an error (if any). +func (p *Porter) ReserveEphemeral(ctx context.Context, v interface{}) (uint16, func(), error) { + p.Lock() + defer p.Unlock() + + for { + p.eph++ + if p.eph < p.minEph { + p.eph = p.minEph + } + if _, ok := p.ports[p.eph]; ok { + select { + case <-ctx.Done(): + return 0, nil, ctx.Err() + default: + continue + } + } + p.ports[p.eph] = v + return p.eph, p.makePortFreer(p.eph), nil + } +} + +// PortValue returns the value stored under a given port. +func (p *Porter) PortValue(port uint16) (interface{}, bool) { + p.RLock() + defer p.RUnlock() + + v, ok := p.ports[port] + return v, ok +} + +// RangePortValues ranges all ports that are currently reserved. +func (p *Porter) RangePortValues(fn func(port uint16, v interface{}) (next bool)) { + p.RLock() + defer p.RUnlock() + + for port, v := range p.ports { + if next := fn(port, v); !next { + return + } + } +} + +// This returns a function that frees a given port. +// It is ensured that the function's action is only performed once. +func (p *Porter) makePortFreer(port uint16) func() { + once := new(sync.Once) + return func() { + once.Do(func() { + p.Lock() + delete(p.ports, port) + p.Unlock() + }) + } +} diff --git a/vendor/github.com/skycoin/dmsg/port_manager.go b/vendor/github.com/skycoin/dmsg/port_manager.go index 63540c701..0ab5a18e4 100644 --- a/vendor/github.com/skycoin/dmsg/port_manager.go +++ b/vendor/github.com/skycoin/dmsg/port_manager.go @@ -1,72 +1,66 @@ package dmsg import ( - "math/rand" + "context" "sync" - "time" "github.com/skycoin/dmsg/cipher" -) - -const ( - firstEphemeralPort = 49152 - lastEphemeralPort = 65535 + "github.com/skycoin/dmsg/netutil" ) // PortManager manages ports of nodes. type PortManager struct { - mu sync.RWMutex - rand *rand.Rand - listeners map[uint16]*Listener + lPK cipher.PubKey + p *netutil.Porter } -func newPortManager() *PortManager { +func newPortManager(lPK cipher.PubKey) *PortManager { return &PortManager{ - rand: rand.New(rand.NewSource(time.Now().UnixNano())), - listeners: make(map[uint16]*Listener), + lPK: lPK, + p: netutil.NewPorter(netutil.PorterMinEphemeral), } } // Listener returns a listener assigned to a given port. func (pm *PortManager) Listener(port uint16) (*Listener, bool) { - pm.mu.RLock() - defer pm.mu.RUnlock() - - l, ok := pm.listeners[port] + v, ok := pm.p.PortValue(port) + if !ok { + return nil, false + } + l, ok := v.(*Listener) return l, ok } // NewListener assigns listener to port if port is available. -func (pm *PortManager) NewListener(pk cipher.PubKey, port uint16) (*Listener, bool) { - pm.mu.Lock() - defer pm.mu.Unlock() - if _, ok := pm.listeners[port]; ok { +func (pm *PortManager) NewListener(port uint16) (*Listener, bool) { + l := newListener(Addr{pm.lPK, port}) + ok, clear := pm.p.Reserve(port, l) + if !ok { return nil, false } - l := newListener(pk, port) - pm.listeners[port] = l + l.AddCloseCallback(clear) return l, true } -// RemoveListener removes listener assigned to port. -func (pm *PortManager) RemoveListener(port uint16) { - pm.mu.Lock() - defer pm.mu.Unlock() - - delete(pm.listeners, port) +// ReserveEphemeral reserves an ephemeral port. +func (pm *PortManager) ReserveEphemeral(ctx context.Context) (uint16, func(), error) { + return pm.p.ReserveEphemeral(ctx, nil) } -// NextEmptyEphemeralPort returns next random ephemeral port. -// It has a value between firstEphemeralPort and lastEphemeralPort. -func (pm *PortManager) NextEmptyEphemeralPort() uint16 { - for { - port := pm.randomEphemeralPort() - if _, ok := pm.Listener(port); !ok { - return port +// Close closes all listeners. +func (pm *PortManager) Close() error { + wg := new(sync.WaitGroup) + pm.p.RangePortValues(func(_ uint16, v interface{}) (next bool) { + l, ok := v.(*Listener) + if ok { + wg.Add(1) + go func() { + l.close() + wg.Done() + }() } - } -} - -func (pm *PortManager) randomEphemeralPort() uint16 { - return uint16(firstEphemeralPort + pm.rand.Intn(lastEphemeralPort-firstEphemeralPort)) + return true + }) + wg.Wait() + return nil } diff --git a/vendor/github.com/skycoin/dmsg/server.go b/vendor/github.com/skycoin/dmsg/server.go index ba0ee3dd3..a5bfa304c 100644 --- a/vendor/github.com/skycoin/dmsg/server.go +++ b/vendor/github.com/skycoin/dmsg/server.go @@ -19,239 +19,6 @@ import ( // ErrListenerAlreadyWrappedToNoise occurs when the provided net.Listener is already wrapped with noise.Listener var ErrListenerAlreadyWrappedToNoise = errors.New("listener is already wrapped to *noise.Listener") -// NextConn provides information on the next connection. -type NextConn struct { - conn *ServerConn - id uint16 -} - -func (r *NextConn) writeFrame(ft FrameType, p []byte) error { - if err := writeFrame(r.conn.Conn, MakeFrame(ft, r.id, p)); err != nil { - go func() { - if err := r.conn.Close(); err != nil { - log.WithError(err).Warn("Failed to close connection") - } - }() - return err - } - return nil -} - -// ServerConn is a connection between a dmsg.Server and a dmsg.Client from a server's perspective. -type ServerConn struct { - log *logging.Logger - - net.Conn - remoteClient cipher.PubKey - - nextRespID uint16 - nextConns map[uint16]*NextConn - mx sync.RWMutex -} - -// NewServerConn creates a new connection from the perspective of a dms_server. -func NewServerConn(log *logging.Logger, conn net.Conn, remoteClient cipher.PubKey) *ServerConn { - return &ServerConn{ - log: log, - Conn: conn, - remoteClient: remoteClient, - nextRespID: randID(false), - nextConns: make(map[uint16]*NextConn), - } -} - -func (c *ServerConn) delNext(id uint16) { - c.mx.Lock() - delete(c.nextConns, id) - c.mx.Unlock() -} - -func (c *ServerConn) setNext(id uint16, r *NextConn) { - c.mx.Lock() - c.nextConns[id] = r - c.mx.Unlock() -} - -func (c *ServerConn) getNext(id uint16) (*NextConn, bool) { - c.mx.RLock() - r := c.nextConns[id] - c.mx.RUnlock() - return r, r != nil -} - -func (c *ServerConn) addNext(ctx context.Context, r *NextConn) (uint16, error) { - c.mx.Lock() - defer c.mx.Unlock() - - for { - if r := c.nextConns[c.nextRespID]; r == nil { - break - } - c.nextRespID += 2 - - select { - case <-ctx.Done(): - return 0, ctx.Err() - default: - } - } - - id := c.nextRespID - c.nextRespID = id + 2 - c.nextConns[id] = r - return id, nil -} - -// PK returns the remote dms_client's public key. -func (c *ServerConn) PK() cipher.PubKey { - return c.remoteClient -} - -type getConnFunc func(pk cipher.PubKey) (*ServerConn, bool) - -// Serve handles (and forwards when necessary) incoming frames. -func (c *ServerConn) Serve(ctx context.Context, getConn getConnFunc) (err error) { - log := c.log.WithField("srcClient", c.remoteClient) - - // Only manually close the underlying net.Conn when the done signal is context-initiated. - done := make(chan struct{}) - defer close(done) - go func() { - select { - case <-done: - case <-ctx.Done(): - if err := c.Conn.Close(); err != nil { - log.WithError(err).Warn("failed to close underlying connection") - } - } - }() - - defer func() { - // Send CLOSE frames to all transports which are established with this dmsg.Client - // This ensures that all parties are informed about the transport closing. - c.mx.Lock() - for _, conn := range c.nextConns { - why := byte(0) - if err := conn.writeFrame(CloseType, []byte{why}); err != nil { - log.WithError(err).Warnf("failed to write frame: %s", err) - } - } - c.mx.Unlock() - - log.WithError(err).WithField("connCount", decrementServeCount()).Infoln("ClosingConn") - if err := c.Conn.Close(); err != nil { - log.WithError(err).Warn("Failed to close connection") - } - }() - - log.WithField("connCount", incrementServeCount()).Infoln("ServingConn") - - err = c.writeOK() - if err != nil { - return fmt.Errorf("sending OK failed: %s", err) - } - - for { - f, err := readFrame(c.Conn) - if err != nil { - return fmt.Errorf("read failed: %s", err) - } - log := log.WithField("received", f) - - ft, id, p := f.Disassemble() - - switch ft { - case RequestType: - ctx, cancel := context.WithTimeout(ctx, TransportHandshakeTimeout) - _, why, ok := c.handleRequest(ctx, getConn, id, p) - cancel() - if !ok { - log.Debugln("FrameRejected: Erroneous request or unresponsive dstClient.") - if err := c.delChan(id, why); err != nil { - return err - } - } - log.Debugln("FrameForwarded") - - case AcceptType, FwdType, AckType, CloseType: - next, why, ok := c.forwardFrame(ft, id, p) - if !ok { - log.Debugln("FrameRejected: Failed to forward to dstClient.") - // Delete channel (and associations) on failure. - if err := c.delChan(id, why); err != nil { - return err - } - continue - } - log.Debugln("FrameForwarded") - - // On success, if Close frame, delete the associations. - if ft == CloseType { - c.delNext(id) - next.conn.delNext(next.id) - } - - default: - log.Debugln("FrameRejected: Unknown frame type.") - // Unknown frame type. - return errors.New("unknown frame of type received") - } - } -} - -func (c *ServerConn) delChan(id uint16, why byte) error { - c.delNext(id) - if err := writeCloseFrame(c.Conn, id, why); err != nil { - return fmt.Errorf("failed to write frame: %s", err) - } - return nil -} - -func (c *ServerConn) writeOK() error { - if err := writeFrame(c.Conn, MakeFrame(OkType, 0, nil)); err != nil { - return err - } - return nil -} - -// nolint:unparam -func (c *ServerConn) forwardFrame(ft FrameType, id uint16, p []byte) (*NextConn, byte, bool) { - next, ok := c.getNext(id) - if !ok { - return next, 0, false - } - if err := next.writeFrame(ft, p); err != nil { - return next, 0, false - } - return next, 0, true -} - -// nolint:unparam -func (c *ServerConn) handleRequest(ctx context.Context, getLink getConnFunc, id uint16, p []byte) (*NextConn, byte, bool) { - payload, err := unmarshalHandshakePayload(p) - if err != nil || payload.InitPK != c.PK() { - return nil, 0, false - } - respL, ok := getLink(payload.RespPK) - if !ok { - return nil, 0, false - } - - // set next relations. - respID, err := respL.addNext(ctx, &NextConn{conn: c, id: id}) - if err != nil { - return nil, 0, false - } - next := &NextConn{conn: respL, id: respID} - c.setNext(id, next) - - // forward to responding client. - if err := next.writeFrame(RequestType, p); err != nil { - return next, 0, false - } - return next, 0, true -} - // Server represents a dms_server. type Server struct { log *logging.Logger diff --git a/vendor/github.com/skycoin/dmsg/server_conn.go b/vendor/github.com/skycoin/dmsg/server_conn.go new file mode 100644 index 000000000..a162b5102 --- /dev/null +++ b/vendor/github.com/skycoin/dmsg/server_conn.go @@ -0,0 +1,243 @@ +package dmsg + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + + "github.com/skycoin/skycoin/src/util/logging" + + "github.com/skycoin/dmsg/cipher" +) + +// NextConn provides information on the next connection. +type NextConn struct { + conn *ServerConn + id uint16 +} + +func (r *NextConn) writeFrame(ft FrameType, p []byte) error { + if err := writeFrame(r.conn.Conn, MakeFrame(ft, r.id, p)); err != nil { + go func() { + if err := r.conn.Close(); err != nil { + log.WithError(err).Warn("Failed to close connection") + } + }() + return err + } + return nil +} + +// ServerConn is a connection between a dmsg.Server and a dmsg.Client from a server's perspective. +type ServerConn struct { + log *logging.Logger + + net.Conn + remoteClient cipher.PubKey + + nextRespID uint16 + nextConns map[uint16]*NextConn + mx sync.RWMutex +} + +// NewServerConn creates a new connection from the perspective of a dms_server. +func NewServerConn(log *logging.Logger, conn net.Conn, remoteClient cipher.PubKey) *ServerConn { + return &ServerConn{ + log: log, + Conn: conn, + remoteClient: remoteClient, + nextRespID: randID(false), + nextConns: make(map[uint16]*NextConn), + } +} + +func (c *ServerConn) delNext(id uint16) { + c.mx.Lock() + delete(c.nextConns, id) + c.mx.Unlock() +} + +func (c *ServerConn) setNext(id uint16, r *NextConn) { + c.mx.Lock() + c.nextConns[id] = r + c.mx.Unlock() +} + +func (c *ServerConn) getNext(id uint16) (*NextConn, bool) { + c.mx.RLock() + r := c.nextConns[id] + c.mx.RUnlock() + return r, r != nil +} + +func (c *ServerConn) addNext(ctx context.Context, r *NextConn) (uint16, error) { + c.mx.Lock() + defer c.mx.Unlock() + + for { + if r := c.nextConns[c.nextRespID]; r == nil { + break + } + c.nextRespID += 2 + + select { + case <-ctx.Done(): + return 0, ctx.Err() + default: + } + } + + id := c.nextRespID + c.nextRespID = id + 2 + c.nextConns[id] = r + return id, nil +} + +// PK returns the remote dms_client's public key. +func (c *ServerConn) PK() cipher.PubKey { + return c.remoteClient +} + +type getConnFunc func(pk cipher.PubKey) (*ServerConn, bool) + +// Serve handles (and forwards when necessary) incoming frames. +func (c *ServerConn) Serve(ctx context.Context, getConn getConnFunc) (err error) { + log := c.log.WithField("srcClient", c.remoteClient) + + // Only manually close the underlying net.Conn when the done signal is context-initiated. + done := make(chan struct{}) + defer close(done) + go func() { + select { + case <-done: + case <-ctx.Done(): + if err := c.Conn.Close(); err != nil { + log.WithError(err).Warn("failed to close underlying connection") + } + } + }() + + defer func() { + // Send CLOSE frames to all transports which are established with this dmsg.Client + // This ensures that all parties are informed about the transport closing. + c.mx.Lock() + for _, conn := range c.nextConns { + why := byte(0) + if err := conn.writeFrame(CloseType, []byte{why}); err != nil { + log.WithError(err).Warnf("failed to write frame: %s", err) + } + } + c.mx.Unlock() + + log.WithError(err).WithField("connCount", decrementServeCount()).Infoln("ClosingConn") + if err := c.Conn.Close(); err != nil { + log.WithError(err).Warn("Failed to close connection") + } + }() + + log.WithField("connCount", incrementServeCount()).Infoln("ServingConn") + + err = c.writeOK() + if err != nil { + return fmt.Errorf("sending OK failed: %s", err) + } + + for { + f, df, err := readFrame(c.Conn) + if err != nil { + return fmt.Errorf("read failed: %s", err) + } + log := log.WithField("received", f) + + switch df.Type { + case RequestType: + ctx, cancel := context.WithTimeout(ctx, TransportHandshakeTimeout) + _, why, ok := c.handleRequest(ctx, getConn, df.TpID, df.Pay) + cancel() + if !ok { + log.Debugln("FrameRejected: Erroneous request or unresponsive dstClient.") + if err := c.delChan(df.TpID, why); err != nil { + return err + } + } + log.Debugln("FrameForwarded") + + case AcceptType, FwdType, AckType, CloseType: + next, why, ok := c.forwardFrame(df.Type, df.TpID, df.Pay) + if !ok { + log.Debugln("FrameRejected: Failed to forward to dstClient.") + // Delete channel (and associations) on failure. + if err := c.delChan(df.TpID, why); err != nil { + return err + } + continue + } + log.Debugln("FrameForwarded") + + // On success, if Close frame, delete the associations. + if df.Type == CloseType { + c.delNext(df.TpID) + next.conn.delNext(next.id) + } + + default: + log.Debugln("FrameRejected: Unknown frame type.") + return errors.New("unknown frame of type received") + } + } +} + +func (c *ServerConn) delChan(id uint16, why byte) error { + c.delNext(id) + if err := writeCloseFrame(c.Conn, id, why); err != nil { + return fmt.Errorf("failed to write frame: %s", err) + } + return nil +} + +func (c *ServerConn) writeOK() error { + if err := writeFrame(c.Conn, MakeFrame(OkType, 0, nil)); err != nil { + return err + } + return nil +} + +// nolint:unparam +func (c *ServerConn) forwardFrame(ft FrameType, id uint16, p []byte) (*NextConn, byte, bool) { + next, ok := c.getNext(id) + if !ok { + return next, 0, false + } + if err := next.writeFrame(ft, p); err != nil { + return next, 0, false + } + return next, 0, true +} + +// nolint:unparam +func (c *ServerConn) handleRequest(ctx context.Context, getLink getConnFunc, id uint16, p []byte) (*NextConn, byte, bool) { + payload, err := unmarshalHandshakePayload(p) + if err != nil || payload.InitAddr.PK != c.PK() { + return nil, 0, false + } + respL, ok := getLink(payload.RespAddr.PK) + if !ok { + return nil, 0, false + } + + // set next relations. + respID, err := respL.addNext(ctx, &NextConn{conn: c, id: id}) + if err != nil { + return nil, 0, false + } + next := &NextConn{conn: respL, id: respID} + c.setNext(id, next) + + // forward to responding client. + if err := next.writeFrame(RequestType, p); err != nil { + return next, 0, false + } + return next, 0, true +} diff --git a/vendor/github.com/skycoin/dmsg/transport.go b/vendor/github.com/skycoin/dmsg/transport.go index 2b1da95a7..5a7467172 100644 --- a/vendor/github.com/skycoin/dmsg/transport.go +++ b/vendor/github.com/skycoin/dmsg/transport.go @@ -41,17 +41,16 @@ type Transport struct { bufCh chan struct{} // chan for indicating whether this is a new FWD frame bufSize int // keeps track of the total size of 'buf' bufMx sync.Mutex // protects fields responsible for handling FWD and ACK frames - rMx sync.Mutex // TODO: (WORKAROUND) concurrent reads seem problematic right now. - serving chan struct{} // chan which closes when serving begins - servingOnce sync.Once // ensures 'serving' only closes once - done chan struct{} // chan which closes when transport stops serving - doneOnce sync.Once // ensures 'done' only closes once - doneFunc func(id uint16) // contains a method to remove the transport from dmsg.Client + serving chan struct{} // chan which closes when serving begins + servingOnce sync.Once // ensures 'serving' only closes once + done chan struct{} // chan which closes when transport stops serving + doneOnce sync.Once // ensures 'done' only closes once + doneFunc func() // contains a method that triggers when dmsg.Client closes } // NewTransport creates a new dms_tp. -func NewTransport(conn net.Conn, log *logging.Logger, local, remote Addr, id uint16, doneFunc func(id uint16)) *Transport { +func NewTransport(conn net.Conn, log *logging.Logger, local, remote Addr, id uint16, doneFunc func()) *Transport { tp := &Transport{ Conn: conn, log: log, @@ -96,7 +95,7 @@ func (tp *Transport) close() (closed bool) { closed = true close(tp.done) - tp.doneFunc(tp.id) + tp.doneFunc() tp.bufMx.Lock() close(tp.bufCh) @@ -170,12 +169,11 @@ func (tp *Transport) HandleFrame(f Frame) error { } // WriteRequest writes a REQUEST frame to dmsg_server to be forwarded to associated client. -func (tp *Transport) WriteRequest(port uint16) error { +func (tp *Transport) WriteRequest() error { payload := HandshakePayload{ - Version: HandshakePayloadVersion, - InitPK: tp.local.PK, - RespPK: tp.remote.PK, - Port: port, + Version: HandshakePayloadVersion, + InitAddr: tp.local, + RespAddr: tp.remote, } payloadBytes, err := marshalHandshakePayload(payload) if err != nil { @@ -360,9 +358,6 @@ func (tp *Transport) Serve() { func (tp *Transport) Read(p []byte) (n int, err error) { <-tp.serving - tp.rMx.Lock() - defer tp.rMx.Unlock() - startRead: tp.bufMx.Lock() n, err = tp.buf.Read(p) diff --git a/vendor/github.com/skycoin/dmsg/frame.go b/vendor/github.com/skycoin/dmsg/types.go similarity index 71% rename from vendor/github.com/skycoin/dmsg/frame.go rename to vendor/github.com/skycoin/dmsg/types.go index 33b354ef9..dcaabe6db 100644 --- a/vendor/github.com/skycoin/dmsg/frame.go +++ b/vendor/github.com/skycoin/dmsg/types.go @@ -2,6 +2,7 @@ package dmsg import ( "encoding/binary" + "encoding/json" "fmt" "io" "math" @@ -18,7 +19,7 @@ const ( Type = "dmsg" // HandshakePayloadVersion contains payload version to maintain compatibility with future versions // of HandshakePayload format. - HandshakePayloadVersion = "1" + HandshakePayloadVersion = "2.0" tpBufCap = math.MaxUint16 tpBufFrameCap = math.MaxUint8 @@ -34,15 +35,43 @@ var ( AcceptBufferSize = 20 ) +// Addr implements net.Addr for dmsg addresses. +type Addr struct { + PK cipher.PubKey `json:"public_key"` + Port uint16 `json:"port"` +} + +// Network returns "dmsg" +func (Addr) Network() string { + return Type +} + +// String returns public key and port of node split by colon. +func (a Addr) String() string { + if a.Port == 0 { + return fmt.Sprintf("%s:~", a.PK) + } + return fmt.Sprintf("%s:%d", a.PK, a.Port) +} + // HandshakePayload represents format of payload sent with REQUEST frames. -// TODO(evanlinjin): Use 'dmsg.Addr' for PK:Port pair. type HandshakePayload struct { - Version string `json:"version"` // just in case the struct changes. - InitPK cipher.PubKey `json:"init_pk"` - RespPK cipher.PubKey `json:"resp_pk"` - Port uint16 `json:"port"` + Version string `json:"version"` // just in case the struct changes. + InitAddr Addr `json:"init_address"` + RespAddr Addr `json:"resp_address"` +} + +func marshalHandshakePayload(p HandshakePayload) ([]byte, error) { + return json.Marshal(p) } +func unmarshalHandshakePayload(b []byte) (HandshakePayload, error) { + var p HandshakePayload + err := json.Unmarshal(b, &p) + return p, err +} + +// determines whether the transport ID is of an initiator or responder. func isInitiatorID(tpID uint16) bool { return tpID%2 == 0 } func randID(initiator bool) uint16 { @@ -55,6 +84,7 @@ func randID(initiator bool) uint16 { } } +// serveCount records the number of dmsg.Servers connected var serveCount int64 func incrementServeCount() int64 { return atomic.AddInt64(&serveCount, 1) } @@ -133,24 +163,36 @@ func (f Frame) String() string { return fmt.Sprintf("%s", f.Type(), f.TpID(), f.PayLen(), p) } -func readFrame(r io.Reader) (Frame, error) { - f := make(Frame, headerLen) - if _, err := io.ReadFull(r, f); err != nil { - return nil, err +type disassembledFrame struct { + Type FrameType + TpID uint16 + Pay []byte +} + +// read and disassembles frame from reader +func readFrame(r io.Reader) (f Frame, df disassembledFrame, err error) { + f = make(Frame, headerLen) + if _, err = io.ReadFull(r, f); err != nil { + return } f = append(f, make([]byte, f.PayLen())...) - _, err := io.ReadFull(r, f[headerLen:]) - return f, err + if _, err = io.ReadFull(r, f[headerLen:]); err != nil { + return + } + t, id, p := f.Disassemble() + df = disassembledFrame{Type: t, TpID: id, Pay: p} + return } type writeError struct{ error } func (e *writeError) Error() string { return "write error: " + e.error.Error() } -func isWriteError(err error) bool { - _, ok := err.(*writeError) - return ok -} +// TODO(evanlinjin): Determine if this is still needed, may be useful elsewhere. +//func isWriteError(err error) bool { +// _, ok := err.(*writeError) +// return ok +//} func writeFrame(w io.Writer, f Frame) error { _, err := w.Write(f) diff --git a/vendor/github.com/stretchr/objx/.codeclimate.yml b/vendor/github.com/stretchr/objx/.codeclimate.yml new file mode 100644 index 000000000..010d4ccd5 --- /dev/null +++ b/vendor/github.com/stretchr/objx/.codeclimate.yml @@ -0,0 +1,13 @@ +engines: + gofmt: + enabled: true + golint: + enabled: true + govet: + enabled: true + +exclude_patterns: +- ".github/" +- "vendor/" +- "codegen/" +- "doc.go" diff --git a/vendor/github.com/stretchr/objx/.gitignore b/vendor/github.com/stretchr/objx/.gitignore new file mode 100644 index 000000000..ea58090bd --- /dev/null +++ b/vendor/github.com/stretchr/objx/.gitignore @@ -0,0 +1,11 @@ +# Binaries for programs and plugins +*.exe +*.dll +*.so +*.dylib + +# Test binary, build with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out diff --git a/vendor/github.com/stretchr/objx/.travis.yml b/vendor/github.com/stretchr/objx/.travis.yml new file mode 100644 index 000000000..a63efa59d --- /dev/null +++ b/vendor/github.com/stretchr/objx/.travis.yml @@ -0,0 +1,25 @@ +language: go +go: + - 1.8 + - 1.9 + - tip + +env: + global: + - CC_TEST_REPORTER_ID=68feaa3410049ce73e145287acbcdacc525087a30627f96f04e579e75bd71c00 + +before_script: + - curl -L https://codeclimate.com/downloads/test-reporter/test-reporter-latest-linux-amd64 > ./cc-test-reporter + - chmod +x ./cc-test-reporter + - ./cc-test-reporter before-build + +install: +- go get github.com/go-task/task/cmd/task + +script: +- task dl-deps +- task lint +- task test-coverage + +after_script: + - ./cc-test-reporter after-build --exit-code $TRAVIS_TEST_RESULT diff --git a/vendor/github.com/stretchr/objx/Gopkg.lock b/vendor/github.com/stretchr/objx/Gopkg.lock new file mode 100644 index 000000000..eebe342a9 --- /dev/null +++ b/vendor/github.com/stretchr/objx/Gopkg.lock @@ -0,0 +1,30 @@ +# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. + + +[[projects]] + name = "github.com/davecgh/go-spew" + packages = ["spew"] + revision = "346938d642f2ec3594ed81d874461961cd0faa76" + version = "v1.1.0" + +[[projects]] + name = "github.com/pmezard/go-difflib" + packages = ["difflib"] + revision = "792786c7400a136282c1664665ae0a8db921c6c2" + version = "v1.0.0" + +[[projects]] + name = "github.com/stretchr/testify" + packages = [ + "assert", + "require" + ] + revision = "b91bfb9ebec76498946beb6af7c0230c7cc7ba6c" + version = "v1.2.0" + +[solve-meta] + analyzer-name = "dep" + analyzer-version = 1 + inputs-digest = "2d160a7dea4ffd13c6c31dab40373822f9d78c73beba016d662bef8f7a998876" + solver-name = "gps-cdcl" + solver-version = 1 diff --git a/vendor/github.com/stretchr/objx/Gopkg.toml b/vendor/github.com/stretchr/objx/Gopkg.toml new file mode 100644 index 000000000..d70f1570b --- /dev/null +++ b/vendor/github.com/stretchr/objx/Gopkg.toml @@ -0,0 +1,8 @@ +[prune] + unused-packages = true + non-go = true + go-tests = true + +[[constraint]] + name = "github.com/stretchr/testify" + version = "~1.2.0" diff --git a/vendor/github.com/stretchr/objx/LICENSE b/vendor/github.com/stretchr/objx/LICENSE new file mode 100644 index 000000000..44d4d9d5a --- /dev/null +++ b/vendor/github.com/stretchr/objx/LICENSE @@ -0,0 +1,22 @@ +The MIT License + +Copyright (c) 2014 Stretchr, Inc. +Copyright (c) 2017-2018 objx contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/stretchr/objx/README.md b/vendor/github.com/stretchr/objx/README.md new file mode 100644 index 000000000..be5750c94 --- /dev/null +++ b/vendor/github.com/stretchr/objx/README.md @@ -0,0 +1,80 @@ +# Objx +[![Build Status](https://travis-ci.org/stretchr/objx.svg?branch=master)](https://travis-ci.org/stretchr/objx) +[![Go Report Card](https://goreportcard.com/badge/github.com/stretchr/objx)](https://goreportcard.com/report/github.com/stretchr/objx) +[![Maintainability](https://api.codeclimate.com/v1/badges/1d64bc6c8474c2074f2b/maintainability)](https://codeclimate.com/github/stretchr/objx/maintainability) +[![Test Coverage](https://api.codeclimate.com/v1/badges/1d64bc6c8474c2074f2b/test_coverage)](https://codeclimate.com/github/stretchr/objx/test_coverage) +[![Sourcegraph](https://sourcegraph.com/github.com/stretchr/objx/-/badge.svg)](https://sourcegraph.com/github.com/stretchr/objx) +[![GoDoc](https://godoc.org/github.com/stretchr/objx?status.svg)](https://godoc.org/github.com/stretchr/objx) + +Objx - Go package for dealing with maps, slices, JSON and other data. + +Get started: + +- Install Objx with [one line of code](#installation), or [update it with another](#staying-up-to-date) +- Check out the API Documentation http://godoc.org/github.com/stretchr/objx + +## Overview +Objx provides the `objx.Map` type, which is a `map[string]interface{}` that exposes a powerful `Get` method (among others) that allows you to easily and quickly get access to data within the map, without having to worry too much about type assertions, missing data, default values etc. + +### Pattern +Objx uses a preditable pattern to make access data from within `map[string]interface{}` easy. Call one of the `objx.` functions to create your `objx.Map` to get going: + + m, err := objx.FromJSON(json) + +NOTE: Any methods or functions with the `Must` prefix will panic if something goes wrong, the rest will be optimistic and try to figure things out without panicking. + +Use `Get` to access the value you're interested in. You can use dot and array +notation too: + + m.Get("places[0].latlng") + +Once you have sought the `Value` you're interested in, you can use the `Is*` methods to determine its type. + + if m.Get("code").IsStr() { // Your code... } + +Or you can just assume the type, and use one of the strong type methods to extract the real value: + + m.Get("code").Int() + +If there's no value there (or if it's the wrong type) then a default value will be returned, or you can be explicit about the default value. + + Get("code").Int(-1) + +If you're dealing with a slice of data as a value, Objx provides many useful methods for iterating, manipulating and selecting that data. You can find out more by exploring the index below. + +### Reading data +A simple example of how to use Objx: + + // Use MustFromJSON to make an objx.Map from some JSON + m := objx.MustFromJSON(`{"name": "Mat", "age": 30}`) + + // Get the details + name := m.Get("name").Str() + age := m.Get("age").Int() + + // Get their nickname (or use their name if they don't have one) + nickname := m.Get("nickname").Str(name) + +### Ranging +Since `objx.Map` is a `map[string]interface{}` you can treat it as such. For example, to `range` the data, do what you would expect: + + m := objx.MustFromJSON(json) + for key, value := range m { + // Your code... + } + +## Installation +To install Objx, use go get: + + go get github.com/stretchr/objx + +### Staying up to date +To update Objx to the latest version, run: + + go get -u github.com/stretchr/objx + +### Supported go versions +We support the lastest two major Go versions, which are 1.8 and 1.9 at the moment. + +## Contributing +Please feel free to submit issues, fork the repository and send pull requests! diff --git a/vendor/github.com/stretchr/objx/Taskfile.yml b/vendor/github.com/stretchr/objx/Taskfile.yml new file mode 100644 index 000000000..f8035641f --- /dev/null +++ b/vendor/github.com/stretchr/objx/Taskfile.yml @@ -0,0 +1,32 @@ +default: + deps: [test] + +dl-deps: + desc: Downloads cli dependencies + cmds: + - go get -u github.com/golang/lint/golint + - go get -u github.com/golang/dep/cmd/dep + +update-deps: + desc: Updates dependencies + cmds: + - dep ensure + - dep ensure -update + +lint: + desc: Runs golint + cmds: + - go fmt $(go list ./... | grep -v /vendor/) + - go vet $(go list ./... | grep -v /vendor/) + - golint $(ls *.go | grep -v "doc.go") + silent: true + +test: + desc: Runs go tests + cmds: + - go test -race . + +test-coverage: + desc: Runs go tests and calucates test coverage + cmds: + - go test -coverprofile=c.out . diff --git a/vendor/github.com/stretchr/objx/accessors.go b/vendor/github.com/stretchr/objx/accessors.go new file mode 100644 index 000000000..204356a22 --- /dev/null +++ b/vendor/github.com/stretchr/objx/accessors.go @@ -0,0 +1,148 @@ +package objx + +import ( + "regexp" + "strconv" + "strings" +) + +// arrayAccesRegexString is the regex used to extract the array number +// from the access path +const arrayAccesRegexString = `^(.+)\[([0-9]+)\]$` + +// arrayAccesRegex is the compiled arrayAccesRegexString +var arrayAccesRegex = regexp.MustCompile(arrayAccesRegexString) + +// Get gets the value using the specified selector and +// returns it inside a new Obj object. +// +// If it cannot find the value, Get will return a nil +// value inside an instance of Obj. +// +// Get can only operate directly on map[string]interface{} and []interface. +// +// Example +// +// To access the title of the third chapter of the second book, do: +// +// o.Get("books[1].chapters[2].title") +func (m Map) Get(selector string) *Value { + rawObj := access(m, selector, nil, false) + return &Value{data: rawObj} +} + +// Set sets the value using the specified selector and +// returns the object on which Set was called. +// +// Set can only operate directly on map[string]interface{} and []interface +// +// Example +// +// To set the title of the third chapter of the second book, do: +// +// o.Set("books[1].chapters[2].title","Time to Go") +func (m Map) Set(selector string, value interface{}) Map { + access(m, selector, value, true) + return m +} + +// access accesses the object using the selector and performs the +// appropriate action. +func access(current, selector, value interface{}, isSet bool) interface{} { + switch selector.(type) { + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + if array, ok := current.([]interface{}); ok { + index := intFromInterface(selector) + if index >= len(array) { + return nil + } + return array[index] + } + return nil + + case string: + selStr := selector.(string) + selSegs := strings.SplitN(selStr, PathSeparator, 2) + thisSel := selSegs[0] + index := -1 + var err error + + if strings.Contains(thisSel, "[") { + arrayMatches := arrayAccesRegex.FindStringSubmatch(thisSel) + if len(arrayMatches) > 0 { + // Get the key into the map + thisSel = arrayMatches[1] + + // Get the index into the array at the key + index, err = strconv.Atoi(arrayMatches[2]) + + if err != nil { + // This should never happen. If it does, something has gone + // seriously wrong. Panic. + panic("objx: Array index is not an integer. Must use array[int].") + } + } + } + if curMap, ok := current.(Map); ok { + current = map[string]interface{}(curMap) + } + // get the object in question + switch current.(type) { + case map[string]interface{}: + curMSI := current.(map[string]interface{}) + if len(selSegs) <= 1 && isSet { + curMSI[thisSel] = value + return nil + } + current = curMSI[thisSel] + default: + current = nil + } + // do we need to access the item of an array? + if index > -1 { + if array, ok := current.([]interface{}); ok { + if index < len(array) { + current = array[index] + } else { + current = nil + } + } + } + if len(selSegs) > 1 { + current = access(current, selSegs[1], value, isSet) + } + } + return current +} + +// intFromInterface converts an interface object to the largest +// representation of an unsigned integer using a type switch and +// assertions +func intFromInterface(selector interface{}) int { + var value int + switch selector.(type) { + case int: + value = selector.(int) + case int8: + value = int(selector.(int8)) + case int16: + value = int(selector.(int16)) + case int32: + value = int(selector.(int32)) + case int64: + value = int(selector.(int64)) + case uint: + value = int(selector.(uint)) + case uint8: + value = int(selector.(uint8)) + case uint16: + value = int(selector.(uint16)) + case uint32: + value = int(selector.(uint32)) + case uint64: + value = int(selector.(uint64)) + default: + return 0 + } + return value +} diff --git a/vendor/github.com/stretchr/objx/constants.go b/vendor/github.com/stretchr/objx/constants.go new file mode 100644 index 000000000..f9eb42a25 --- /dev/null +++ b/vendor/github.com/stretchr/objx/constants.go @@ -0,0 +1,13 @@ +package objx + +const ( + // PathSeparator is the character used to separate the elements + // of the keypath. + // + // For example, `location.address.city` + PathSeparator string = "." + + // SignatureSeparator is the character that is used to + // separate the Base64 string from the security signature. + SignatureSeparator = "_" +) diff --git a/vendor/github.com/stretchr/objx/conversions.go b/vendor/github.com/stretchr/objx/conversions.go new file mode 100644 index 000000000..5e020f310 --- /dev/null +++ b/vendor/github.com/stretchr/objx/conversions.go @@ -0,0 +1,108 @@ +package objx + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net/url" +) + +// JSON converts the contained object to a JSON string +// representation +func (m Map) JSON() (string, error) { + result, err := json.Marshal(m) + if err != nil { + err = errors.New("objx: JSON encode failed with: " + err.Error()) + } + return string(result), err +} + +// MustJSON converts the contained object to a JSON string +// representation and panics if there is an error +func (m Map) MustJSON() string { + result, err := m.JSON() + if err != nil { + panic(err.Error()) + } + return result +} + +// Base64 converts the contained object to a Base64 string +// representation of the JSON string representation +func (m Map) Base64() (string, error) { + var buf bytes.Buffer + + jsonData, err := m.JSON() + if err != nil { + return "", err + } + + encoder := base64.NewEncoder(base64.StdEncoding, &buf) + _, err = encoder.Write([]byte(jsonData)) + if err != nil { + return "", err + } + _ = encoder.Close() + + return buf.String(), nil +} + +// MustBase64 converts the contained object to a Base64 string +// representation of the JSON string representation and panics +// if there is an error +func (m Map) MustBase64() string { + result, err := m.Base64() + if err != nil { + panic(err.Error()) + } + return result +} + +// SignedBase64 converts the contained object to a Base64 string +// representation of the JSON string representation and signs it +// using the provided key. +func (m Map) SignedBase64(key string) (string, error) { + base64, err := m.Base64() + if err != nil { + return "", err + } + + sig := HashWithKey(base64, key) + return base64 + SignatureSeparator + sig, nil +} + +// MustSignedBase64 converts the contained object to a Base64 string +// representation of the JSON string representation and signs it +// using the provided key and panics if there is an error +func (m Map) MustSignedBase64(key string) string { + result, err := m.SignedBase64(key) + if err != nil { + panic(err.Error()) + } + return result +} + +/* + URL Query + ------------------------------------------------ +*/ + +// URLValues creates a url.Values object from an Obj. This +// function requires that the wrapped object be a map[string]interface{} +func (m Map) URLValues() url.Values { + vals := make(url.Values) + for k, v := range m { + //TODO: can this be done without sprintf? + vals.Set(k, fmt.Sprintf("%v", v)) + } + return vals +} + +// URLQuery gets an encoded URL query representing the given +// Obj. This function requires that the wrapped object be a +// map[string]interface{} +func (m Map) URLQuery() (string, error) { + return m.URLValues().Encode(), nil +} diff --git a/vendor/github.com/stretchr/objx/doc.go b/vendor/github.com/stretchr/objx/doc.go new file mode 100644 index 000000000..6d6af1a83 --- /dev/null +++ b/vendor/github.com/stretchr/objx/doc.go @@ -0,0 +1,66 @@ +/* +Objx - Go package for dealing with maps, slices, JSON and other data. + +Overview + +Objx provides the `objx.Map` type, which is a `map[string]interface{}` that exposes +a powerful `Get` method (among others) that allows you to easily and quickly get +access to data within the map, without having to worry too much about type assertions, +missing data, default values etc. + +Pattern + +Objx uses a preditable pattern to make access data from within `map[string]interface{}` easy. +Call one of the `objx.` functions to create your `objx.Map` to get going: + + m, err := objx.FromJSON(json) + +NOTE: Any methods or functions with the `Must` prefix will panic if something goes wrong, +the rest will be optimistic and try to figure things out without panicking. + +Use `Get` to access the value you're interested in. You can use dot and array +notation too: + + m.Get("places[0].latlng") + +Once you have sought the `Value` you're interested in, you can use the `Is*` methods to determine its type. + + if m.Get("code").IsStr() { // Your code... } + +Or you can just assume the type, and use one of the strong type methods to extract the real value: + + m.Get("code").Int() + +If there's no value there (or if it's the wrong type) then a default value will be returned, +or you can be explicit about the default value. + + Get("code").Int(-1) + +If you're dealing with a slice of data as a value, Objx provides many useful methods for iterating, +manipulating and selecting that data. You can find out more by exploring the index below. + +Reading data + +A simple example of how to use Objx: + + // Use MustFromJSON to make an objx.Map from some JSON + m := objx.MustFromJSON(`{"name": "Mat", "age": 30}`) + + // Get the details + name := m.Get("name").Str() + age := m.Get("age").Int() + + // Get their nickname (or use their name if they don't have one) + nickname := m.Get("nickname").Str(name) + +Ranging + +Since `objx.Map` is a `map[string]interface{}` you can treat it as such. +For example, to `range` the data, do what you would expect: + + m := objx.MustFromJSON(json) + for key, value := range m { + // Your code... + } +*/ +package objx diff --git a/vendor/github.com/stretchr/objx/map.go b/vendor/github.com/stretchr/objx/map.go new file mode 100644 index 000000000..406bc8926 --- /dev/null +++ b/vendor/github.com/stretchr/objx/map.go @@ -0,0 +1,190 @@ +package objx + +import ( + "encoding/base64" + "encoding/json" + "errors" + "io/ioutil" + "net/url" + "strings" +) + +// MSIConvertable is an interface that defines methods for converting your +// custom types to a map[string]interface{} representation. +type MSIConvertable interface { + // MSI gets a map[string]interface{} (msi) representing the + // object. + MSI() map[string]interface{} +} + +// Map provides extended functionality for working with +// untyped data, in particular map[string]interface (msi). +type Map map[string]interface{} + +// Value returns the internal value instance +func (m Map) Value() *Value { + return &Value{data: m} +} + +// Nil represents a nil Map. +var Nil = New(nil) + +// New creates a new Map containing the map[string]interface{} in the data argument. +// If the data argument is not a map[string]interface, New attempts to call the +// MSI() method on the MSIConvertable interface to create one. +func New(data interface{}) Map { + if _, ok := data.(map[string]interface{}); !ok { + if converter, ok := data.(MSIConvertable); ok { + data = converter.MSI() + } else { + return nil + } + } + return Map(data.(map[string]interface{})) +} + +// MSI creates a map[string]interface{} and puts it inside a new Map. +// +// The arguments follow a key, value pattern. +// +// +// Returns nil if any key argument is non-string or if there are an odd number of arguments. +// +// Example +// +// To easily create Maps: +// +// m := objx.MSI("name", "Mat", "age", 29, "subobj", objx.MSI("active", true)) +// +// // creates an Map equivalent to +// m := objx.Map{"name": "Mat", "age": 29, "subobj": objx.Map{"active": true}} +func MSI(keyAndValuePairs ...interface{}) Map { + newMap := Map{} + keyAndValuePairsLen := len(keyAndValuePairs) + if keyAndValuePairsLen%2 != 0 { + return nil + } + for i := 0; i < keyAndValuePairsLen; i = i + 2 { + key := keyAndValuePairs[i] + value := keyAndValuePairs[i+1] + + // make sure the key is a string + keyString, keyStringOK := key.(string) + if !keyStringOK { + return nil + } + newMap[keyString] = value + } + return newMap +} + +// ****** Conversion Constructors + +// MustFromJSON creates a new Map containing the data specified in the +// jsonString. +// +// Panics if the JSON is invalid. +func MustFromJSON(jsonString string) Map { + o, err := FromJSON(jsonString) + if err != nil { + panic("objx: MustFromJSON failed with error: " + err.Error()) + } + return o +} + +// FromJSON creates a new Map containing the data specified in the +// jsonString. +// +// Returns an error if the JSON is invalid. +func FromJSON(jsonString string) (Map, error) { + var data interface{} + err := json.Unmarshal([]byte(jsonString), &data) + if err != nil { + return Nil, err + } + return New(data), nil +} + +// FromBase64 creates a new Obj containing the data specified +// in the Base64 string. +// +// The string is an encoded JSON string returned by Base64 +func FromBase64(base64String string) (Map, error) { + decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(base64String)) + decoded, err := ioutil.ReadAll(decoder) + if err != nil { + return nil, err + } + return FromJSON(string(decoded)) +} + +// MustFromBase64 creates a new Obj containing the data specified +// in the Base64 string and panics if there is an error. +// +// The string is an encoded JSON string returned by Base64 +func MustFromBase64(base64String string) Map { + result, err := FromBase64(base64String) + if err != nil { + panic("objx: MustFromBase64 failed with error: " + err.Error()) + } + return result +} + +// FromSignedBase64 creates a new Obj containing the data specified +// in the Base64 string. +// +// The string is an encoded JSON string returned by SignedBase64 +func FromSignedBase64(base64String, key string) (Map, error) { + parts := strings.Split(base64String, SignatureSeparator) + if len(parts) != 2 { + return nil, errors.New("objx: Signed base64 string is malformed") + } + + sig := HashWithKey(parts[0], key) + if parts[1] != sig { + return nil, errors.New("objx: Signature for base64 data does not match") + } + return FromBase64(parts[0]) +} + +// MustFromSignedBase64 creates a new Obj containing the data specified +// in the Base64 string and panics if there is an error. +// +// The string is an encoded JSON string returned by Base64 +func MustFromSignedBase64(base64String, key string) Map { + result, err := FromSignedBase64(base64String, key) + if err != nil { + panic("objx: MustFromSignedBase64 failed with error: " + err.Error()) + } + return result +} + +// FromURLQuery generates a new Obj by parsing the specified +// query. +// +// For queries with multiple values, the first value is selected. +func FromURLQuery(query string) (Map, error) { + vals, err := url.ParseQuery(query) + if err != nil { + return nil, err + } + m := Map{} + for k, vals := range vals { + m[k] = vals[0] + } + return m, nil +} + +// MustFromURLQuery generates a new Obj by parsing the specified +// query. +// +// For queries with multiple values, the first value is selected. +// +// Panics if it encounters an error +func MustFromURLQuery(query string) Map { + o, err := FromURLQuery(query) + if err != nil { + panic("objx: MustFromURLQuery failed with error: " + err.Error()) + } + return o +} diff --git a/vendor/github.com/stretchr/objx/mutations.go b/vendor/github.com/stretchr/objx/mutations.go new file mode 100644 index 000000000..c3400a3f7 --- /dev/null +++ b/vendor/github.com/stretchr/objx/mutations.go @@ -0,0 +1,77 @@ +package objx + +// Exclude returns a new Map with the keys in the specified []string +// excluded. +func (m Map) Exclude(exclude []string) Map { + excluded := make(Map) + for k, v := range m { + if !contains(exclude, k) { + excluded[k] = v + } + } + return excluded +} + +// Copy creates a shallow copy of the Obj. +func (m Map) Copy() Map { + copied := Map{} + for k, v := range m { + copied[k] = v + } + return copied +} + +// Merge blends the specified map with a copy of this map and returns the result. +// +// Keys that appear in both will be selected from the specified map. +// This method requires that the wrapped object be a map[string]interface{} +func (m Map) Merge(merge Map) Map { + return m.Copy().MergeHere(merge) +} + +// MergeHere blends the specified map with this map and returns the current map. +// +// Keys that appear in both will be selected from the specified map. The original map +// will be modified. This method requires that +// the wrapped object be a map[string]interface{} +func (m Map) MergeHere(merge Map) Map { + for k, v := range merge { + m[k] = v + } + return m +} + +// Transform builds a new Obj giving the transformer a chance +// to change the keys and values as it goes. This method requires that +// the wrapped object be a map[string]interface{} +func (m Map) Transform(transformer func(key string, value interface{}) (string, interface{})) Map { + newMap := Map{} + for k, v := range m { + modifiedKey, modifiedVal := transformer(k, v) + newMap[modifiedKey] = modifiedVal + } + return newMap +} + +// TransformKeys builds a new map using the specified key mapping. +// +// Unspecified keys will be unaltered. +// This method requires that the wrapped object be a map[string]interface{} +func (m Map) TransformKeys(mapping map[string]string) Map { + return m.Transform(func(key string, value interface{}) (string, interface{}) { + if newKey, ok := mapping[key]; ok { + return newKey, value + } + return key, value + }) +} + +// Checks if a string slice contains a string +func contains(s []string, e string) bool { + for _, a := range s { + if a == e { + return true + } + } + return false +} diff --git a/vendor/github.com/stretchr/objx/security.go b/vendor/github.com/stretchr/objx/security.go new file mode 100644 index 000000000..692be8e2a --- /dev/null +++ b/vendor/github.com/stretchr/objx/security.go @@ -0,0 +1,12 @@ +package objx + +import ( + "crypto/sha1" + "encoding/hex" +) + +// HashWithKey hashes the specified string using the security key +func HashWithKey(data, key string) string { + d := sha1.Sum([]byte(data + ":" + key)) + return hex.EncodeToString(d[:]) +} diff --git a/vendor/github.com/stretchr/objx/tests.go b/vendor/github.com/stretchr/objx/tests.go new file mode 100644 index 000000000..d9e0b479a --- /dev/null +++ b/vendor/github.com/stretchr/objx/tests.go @@ -0,0 +1,17 @@ +package objx + +// Has gets whether there is something at the specified selector +// or not. +// +// If m is nil, Has will always return false. +func (m Map) Has(selector string) bool { + if m == nil { + return false + } + return !m.Get(selector).IsNil() +} + +// IsNil gets whether the data is nil or not. +func (v *Value) IsNil() bool { + return v == nil || v.data == nil +} diff --git a/vendor/github.com/stretchr/objx/type_specific_codegen.go b/vendor/github.com/stretchr/objx/type_specific_codegen.go new file mode 100644 index 000000000..202a91f8c --- /dev/null +++ b/vendor/github.com/stretchr/objx/type_specific_codegen.go @@ -0,0 +1,2501 @@ +package objx + +/* + Inter (interface{} and []interface{}) +*/ + +// Inter gets the value as a interface{}, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Inter(optionalDefault ...interface{}) interface{} { + if s, ok := v.data.(interface{}); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustInter gets the value as a interface{}. +// +// Panics if the object is not a interface{}. +func (v *Value) MustInter() interface{} { + return v.data.(interface{}) +} + +// InterSlice gets the value as a []interface{}, returns the optionalDefault +// value or nil if the value is not a []interface{}. +func (v *Value) InterSlice(optionalDefault ...[]interface{}) []interface{} { + if s, ok := v.data.([]interface{}); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustInterSlice gets the value as a []interface{}. +// +// Panics if the object is not a []interface{}. +func (v *Value) MustInterSlice() []interface{} { + return v.data.([]interface{}) +} + +// IsInter gets whether the object contained is a interface{} or not. +func (v *Value) IsInter() bool { + _, ok := v.data.(interface{}) + return ok +} + +// IsInterSlice gets whether the object contained is a []interface{} or not. +func (v *Value) IsInterSlice() bool { + _, ok := v.data.([]interface{}) + return ok +} + +// EachInter calls the specified callback for each object +// in the []interface{}. +// +// Panics if the object is the wrong type. +func (v *Value) EachInter(callback func(int, interface{}) bool) *Value { + for index, val := range v.MustInterSlice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereInter uses the specified decider function to select items +// from the []interface{}. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereInter(decider func(int, interface{}) bool) *Value { + var selected []interface{} + v.EachInter(func(index int, val interface{}) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupInter uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]interface{}. +func (v *Value) GroupInter(grouper func(int, interface{}) string) *Value { + groups := make(map[string][]interface{}) + v.EachInter(func(index int, val interface{}) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]interface{}, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceInter uses the specified function to replace each interface{}s +// by iterating each item. The data in the returned result will be a +// []interface{} containing the replaced items. +func (v *Value) ReplaceInter(replacer func(int, interface{}) interface{}) *Value { + arr := v.MustInterSlice() + replaced := make([]interface{}, len(arr)) + v.EachInter(func(index int, val interface{}) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectInter uses the specified collector function to collect a value +// for each of the interface{}s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectInter(collector func(int, interface{}) interface{}) *Value { + arr := v.MustInterSlice() + collected := make([]interface{}, len(arr)) + v.EachInter(func(index int, val interface{}) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + MSI (map[string]interface{} and []map[string]interface{}) +*/ + +// MSI gets the value as a map[string]interface{}, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) MSI(optionalDefault ...map[string]interface{}) map[string]interface{} { + if s, ok := v.data.(map[string]interface{}); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustMSI gets the value as a map[string]interface{}. +// +// Panics if the object is not a map[string]interface{}. +func (v *Value) MustMSI() map[string]interface{} { + return v.data.(map[string]interface{}) +} + +// MSISlice gets the value as a []map[string]interface{}, returns the optionalDefault +// value or nil if the value is not a []map[string]interface{}. +func (v *Value) MSISlice(optionalDefault ...[]map[string]interface{}) []map[string]interface{} { + if s, ok := v.data.([]map[string]interface{}); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustMSISlice gets the value as a []map[string]interface{}. +// +// Panics if the object is not a []map[string]interface{}. +func (v *Value) MustMSISlice() []map[string]interface{} { + return v.data.([]map[string]interface{}) +} + +// IsMSI gets whether the object contained is a map[string]interface{} or not. +func (v *Value) IsMSI() bool { + _, ok := v.data.(map[string]interface{}) + return ok +} + +// IsMSISlice gets whether the object contained is a []map[string]interface{} or not. +func (v *Value) IsMSISlice() bool { + _, ok := v.data.([]map[string]interface{}) + return ok +} + +// EachMSI calls the specified callback for each object +// in the []map[string]interface{}. +// +// Panics if the object is the wrong type. +func (v *Value) EachMSI(callback func(int, map[string]interface{}) bool) *Value { + for index, val := range v.MustMSISlice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereMSI uses the specified decider function to select items +// from the []map[string]interface{}. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereMSI(decider func(int, map[string]interface{}) bool) *Value { + var selected []map[string]interface{} + v.EachMSI(func(index int, val map[string]interface{}) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupMSI uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]map[string]interface{}. +func (v *Value) GroupMSI(grouper func(int, map[string]interface{}) string) *Value { + groups := make(map[string][]map[string]interface{}) + v.EachMSI(func(index int, val map[string]interface{}) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]map[string]interface{}, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceMSI uses the specified function to replace each map[string]interface{}s +// by iterating each item. The data in the returned result will be a +// []map[string]interface{} containing the replaced items. +func (v *Value) ReplaceMSI(replacer func(int, map[string]interface{}) map[string]interface{}) *Value { + arr := v.MustMSISlice() + replaced := make([]map[string]interface{}, len(arr)) + v.EachMSI(func(index int, val map[string]interface{}) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectMSI uses the specified collector function to collect a value +// for each of the map[string]interface{}s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectMSI(collector func(int, map[string]interface{}) interface{}) *Value { + arr := v.MustMSISlice() + collected := make([]interface{}, len(arr)) + v.EachMSI(func(index int, val map[string]interface{}) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + ObjxMap ((Map) and [](Map)) +*/ + +// ObjxMap gets the value as a (Map), returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) ObjxMap(optionalDefault ...(Map)) Map { + if s, ok := v.data.((Map)); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return New(nil) +} + +// MustObjxMap gets the value as a (Map). +// +// Panics if the object is not a (Map). +func (v *Value) MustObjxMap() Map { + return v.data.((Map)) +} + +// ObjxMapSlice gets the value as a [](Map), returns the optionalDefault +// value or nil if the value is not a [](Map). +func (v *Value) ObjxMapSlice(optionalDefault ...[](Map)) [](Map) { + if s, ok := v.data.([](Map)); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustObjxMapSlice gets the value as a [](Map). +// +// Panics if the object is not a [](Map). +func (v *Value) MustObjxMapSlice() [](Map) { + return v.data.([](Map)) +} + +// IsObjxMap gets whether the object contained is a (Map) or not. +func (v *Value) IsObjxMap() bool { + _, ok := v.data.((Map)) + return ok +} + +// IsObjxMapSlice gets whether the object contained is a [](Map) or not. +func (v *Value) IsObjxMapSlice() bool { + _, ok := v.data.([](Map)) + return ok +} + +// EachObjxMap calls the specified callback for each object +// in the [](Map). +// +// Panics if the object is the wrong type. +func (v *Value) EachObjxMap(callback func(int, Map) bool) *Value { + for index, val := range v.MustObjxMapSlice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereObjxMap uses the specified decider function to select items +// from the [](Map). The object contained in the result will contain +// only the selected items. +func (v *Value) WhereObjxMap(decider func(int, Map) bool) *Value { + var selected [](Map) + v.EachObjxMap(func(index int, val Map) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupObjxMap uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][](Map). +func (v *Value) GroupObjxMap(grouper func(int, Map) string) *Value { + groups := make(map[string][](Map)) + v.EachObjxMap(func(index int, val Map) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([](Map), 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceObjxMap uses the specified function to replace each (Map)s +// by iterating each item. The data in the returned result will be a +// [](Map) containing the replaced items. +func (v *Value) ReplaceObjxMap(replacer func(int, Map) Map) *Value { + arr := v.MustObjxMapSlice() + replaced := make([](Map), len(arr)) + v.EachObjxMap(func(index int, val Map) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectObjxMap uses the specified collector function to collect a value +// for each of the (Map)s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectObjxMap(collector func(int, Map) interface{}) *Value { + arr := v.MustObjxMapSlice() + collected := make([]interface{}, len(arr)) + v.EachObjxMap(func(index int, val Map) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Bool (bool and []bool) +*/ + +// Bool gets the value as a bool, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Bool(optionalDefault ...bool) bool { + if s, ok := v.data.(bool); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return false +} + +// MustBool gets the value as a bool. +// +// Panics if the object is not a bool. +func (v *Value) MustBool() bool { + return v.data.(bool) +} + +// BoolSlice gets the value as a []bool, returns the optionalDefault +// value or nil if the value is not a []bool. +func (v *Value) BoolSlice(optionalDefault ...[]bool) []bool { + if s, ok := v.data.([]bool); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustBoolSlice gets the value as a []bool. +// +// Panics if the object is not a []bool. +func (v *Value) MustBoolSlice() []bool { + return v.data.([]bool) +} + +// IsBool gets whether the object contained is a bool or not. +func (v *Value) IsBool() bool { + _, ok := v.data.(bool) + return ok +} + +// IsBoolSlice gets whether the object contained is a []bool or not. +func (v *Value) IsBoolSlice() bool { + _, ok := v.data.([]bool) + return ok +} + +// EachBool calls the specified callback for each object +// in the []bool. +// +// Panics if the object is the wrong type. +func (v *Value) EachBool(callback func(int, bool) bool) *Value { + for index, val := range v.MustBoolSlice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereBool uses the specified decider function to select items +// from the []bool. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereBool(decider func(int, bool) bool) *Value { + var selected []bool + v.EachBool(func(index int, val bool) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupBool uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]bool. +func (v *Value) GroupBool(grouper func(int, bool) string) *Value { + groups := make(map[string][]bool) + v.EachBool(func(index int, val bool) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]bool, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceBool uses the specified function to replace each bools +// by iterating each item. The data in the returned result will be a +// []bool containing the replaced items. +func (v *Value) ReplaceBool(replacer func(int, bool) bool) *Value { + arr := v.MustBoolSlice() + replaced := make([]bool, len(arr)) + v.EachBool(func(index int, val bool) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectBool uses the specified collector function to collect a value +// for each of the bools in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectBool(collector func(int, bool) interface{}) *Value { + arr := v.MustBoolSlice() + collected := make([]interface{}, len(arr)) + v.EachBool(func(index int, val bool) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Str (string and []string) +*/ + +// Str gets the value as a string, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Str(optionalDefault ...string) string { + if s, ok := v.data.(string); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return "" +} + +// MustStr gets the value as a string. +// +// Panics if the object is not a string. +func (v *Value) MustStr() string { + return v.data.(string) +} + +// StrSlice gets the value as a []string, returns the optionalDefault +// value or nil if the value is not a []string. +func (v *Value) StrSlice(optionalDefault ...[]string) []string { + if s, ok := v.data.([]string); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustStrSlice gets the value as a []string. +// +// Panics if the object is not a []string. +func (v *Value) MustStrSlice() []string { + return v.data.([]string) +} + +// IsStr gets whether the object contained is a string or not. +func (v *Value) IsStr() bool { + _, ok := v.data.(string) + return ok +} + +// IsStrSlice gets whether the object contained is a []string or not. +func (v *Value) IsStrSlice() bool { + _, ok := v.data.([]string) + return ok +} + +// EachStr calls the specified callback for each object +// in the []string. +// +// Panics if the object is the wrong type. +func (v *Value) EachStr(callback func(int, string) bool) *Value { + for index, val := range v.MustStrSlice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereStr uses the specified decider function to select items +// from the []string. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereStr(decider func(int, string) bool) *Value { + var selected []string + v.EachStr(func(index int, val string) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupStr uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]string. +func (v *Value) GroupStr(grouper func(int, string) string) *Value { + groups := make(map[string][]string) + v.EachStr(func(index int, val string) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]string, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceStr uses the specified function to replace each strings +// by iterating each item. The data in the returned result will be a +// []string containing the replaced items. +func (v *Value) ReplaceStr(replacer func(int, string) string) *Value { + arr := v.MustStrSlice() + replaced := make([]string, len(arr)) + v.EachStr(func(index int, val string) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectStr uses the specified collector function to collect a value +// for each of the strings in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectStr(collector func(int, string) interface{}) *Value { + arr := v.MustStrSlice() + collected := make([]interface{}, len(arr)) + v.EachStr(func(index int, val string) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Int (int and []int) +*/ + +// Int gets the value as a int, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Int(optionalDefault ...int) int { + if s, ok := v.data.(int); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustInt gets the value as a int. +// +// Panics if the object is not a int. +func (v *Value) MustInt() int { + return v.data.(int) +} + +// IntSlice gets the value as a []int, returns the optionalDefault +// value or nil if the value is not a []int. +func (v *Value) IntSlice(optionalDefault ...[]int) []int { + if s, ok := v.data.([]int); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustIntSlice gets the value as a []int. +// +// Panics if the object is not a []int. +func (v *Value) MustIntSlice() []int { + return v.data.([]int) +} + +// IsInt gets whether the object contained is a int or not. +func (v *Value) IsInt() bool { + _, ok := v.data.(int) + return ok +} + +// IsIntSlice gets whether the object contained is a []int or not. +func (v *Value) IsIntSlice() bool { + _, ok := v.data.([]int) + return ok +} + +// EachInt calls the specified callback for each object +// in the []int. +// +// Panics if the object is the wrong type. +func (v *Value) EachInt(callback func(int, int) bool) *Value { + for index, val := range v.MustIntSlice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereInt uses the specified decider function to select items +// from the []int. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereInt(decider func(int, int) bool) *Value { + var selected []int + v.EachInt(func(index int, val int) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupInt uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]int. +func (v *Value) GroupInt(grouper func(int, int) string) *Value { + groups := make(map[string][]int) + v.EachInt(func(index int, val int) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]int, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceInt uses the specified function to replace each ints +// by iterating each item. The data in the returned result will be a +// []int containing the replaced items. +func (v *Value) ReplaceInt(replacer func(int, int) int) *Value { + arr := v.MustIntSlice() + replaced := make([]int, len(arr)) + v.EachInt(func(index int, val int) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectInt uses the specified collector function to collect a value +// for each of the ints in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectInt(collector func(int, int) interface{}) *Value { + arr := v.MustIntSlice() + collected := make([]interface{}, len(arr)) + v.EachInt(func(index int, val int) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Int8 (int8 and []int8) +*/ + +// Int8 gets the value as a int8, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Int8(optionalDefault ...int8) int8 { + if s, ok := v.data.(int8); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustInt8 gets the value as a int8. +// +// Panics if the object is not a int8. +func (v *Value) MustInt8() int8 { + return v.data.(int8) +} + +// Int8Slice gets the value as a []int8, returns the optionalDefault +// value or nil if the value is not a []int8. +func (v *Value) Int8Slice(optionalDefault ...[]int8) []int8 { + if s, ok := v.data.([]int8); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustInt8Slice gets the value as a []int8. +// +// Panics if the object is not a []int8. +func (v *Value) MustInt8Slice() []int8 { + return v.data.([]int8) +} + +// IsInt8 gets whether the object contained is a int8 or not. +func (v *Value) IsInt8() bool { + _, ok := v.data.(int8) + return ok +} + +// IsInt8Slice gets whether the object contained is a []int8 or not. +func (v *Value) IsInt8Slice() bool { + _, ok := v.data.([]int8) + return ok +} + +// EachInt8 calls the specified callback for each object +// in the []int8. +// +// Panics if the object is the wrong type. +func (v *Value) EachInt8(callback func(int, int8) bool) *Value { + for index, val := range v.MustInt8Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereInt8 uses the specified decider function to select items +// from the []int8. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereInt8(decider func(int, int8) bool) *Value { + var selected []int8 + v.EachInt8(func(index int, val int8) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupInt8 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]int8. +func (v *Value) GroupInt8(grouper func(int, int8) string) *Value { + groups := make(map[string][]int8) + v.EachInt8(func(index int, val int8) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]int8, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceInt8 uses the specified function to replace each int8s +// by iterating each item. The data in the returned result will be a +// []int8 containing the replaced items. +func (v *Value) ReplaceInt8(replacer func(int, int8) int8) *Value { + arr := v.MustInt8Slice() + replaced := make([]int8, len(arr)) + v.EachInt8(func(index int, val int8) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectInt8 uses the specified collector function to collect a value +// for each of the int8s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectInt8(collector func(int, int8) interface{}) *Value { + arr := v.MustInt8Slice() + collected := make([]interface{}, len(arr)) + v.EachInt8(func(index int, val int8) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Int16 (int16 and []int16) +*/ + +// Int16 gets the value as a int16, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Int16(optionalDefault ...int16) int16 { + if s, ok := v.data.(int16); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustInt16 gets the value as a int16. +// +// Panics if the object is not a int16. +func (v *Value) MustInt16() int16 { + return v.data.(int16) +} + +// Int16Slice gets the value as a []int16, returns the optionalDefault +// value or nil if the value is not a []int16. +func (v *Value) Int16Slice(optionalDefault ...[]int16) []int16 { + if s, ok := v.data.([]int16); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustInt16Slice gets the value as a []int16. +// +// Panics if the object is not a []int16. +func (v *Value) MustInt16Slice() []int16 { + return v.data.([]int16) +} + +// IsInt16 gets whether the object contained is a int16 or not. +func (v *Value) IsInt16() bool { + _, ok := v.data.(int16) + return ok +} + +// IsInt16Slice gets whether the object contained is a []int16 or not. +func (v *Value) IsInt16Slice() bool { + _, ok := v.data.([]int16) + return ok +} + +// EachInt16 calls the specified callback for each object +// in the []int16. +// +// Panics if the object is the wrong type. +func (v *Value) EachInt16(callback func(int, int16) bool) *Value { + for index, val := range v.MustInt16Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereInt16 uses the specified decider function to select items +// from the []int16. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereInt16(decider func(int, int16) bool) *Value { + var selected []int16 + v.EachInt16(func(index int, val int16) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupInt16 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]int16. +func (v *Value) GroupInt16(grouper func(int, int16) string) *Value { + groups := make(map[string][]int16) + v.EachInt16(func(index int, val int16) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]int16, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceInt16 uses the specified function to replace each int16s +// by iterating each item. The data in the returned result will be a +// []int16 containing the replaced items. +func (v *Value) ReplaceInt16(replacer func(int, int16) int16) *Value { + arr := v.MustInt16Slice() + replaced := make([]int16, len(arr)) + v.EachInt16(func(index int, val int16) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectInt16 uses the specified collector function to collect a value +// for each of the int16s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectInt16(collector func(int, int16) interface{}) *Value { + arr := v.MustInt16Slice() + collected := make([]interface{}, len(arr)) + v.EachInt16(func(index int, val int16) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Int32 (int32 and []int32) +*/ + +// Int32 gets the value as a int32, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Int32(optionalDefault ...int32) int32 { + if s, ok := v.data.(int32); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustInt32 gets the value as a int32. +// +// Panics if the object is not a int32. +func (v *Value) MustInt32() int32 { + return v.data.(int32) +} + +// Int32Slice gets the value as a []int32, returns the optionalDefault +// value or nil if the value is not a []int32. +func (v *Value) Int32Slice(optionalDefault ...[]int32) []int32 { + if s, ok := v.data.([]int32); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustInt32Slice gets the value as a []int32. +// +// Panics if the object is not a []int32. +func (v *Value) MustInt32Slice() []int32 { + return v.data.([]int32) +} + +// IsInt32 gets whether the object contained is a int32 or not. +func (v *Value) IsInt32() bool { + _, ok := v.data.(int32) + return ok +} + +// IsInt32Slice gets whether the object contained is a []int32 or not. +func (v *Value) IsInt32Slice() bool { + _, ok := v.data.([]int32) + return ok +} + +// EachInt32 calls the specified callback for each object +// in the []int32. +// +// Panics if the object is the wrong type. +func (v *Value) EachInt32(callback func(int, int32) bool) *Value { + for index, val := range v.MustInt32Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereInt32 uses the specified decider function to select items +// from the []int32. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereInt32(decider func(int, int32) bool) *Value { + var selected []int32 + v.EachInt32(func(index int, val int32) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupInt32 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]int32. +func (v *Value) GroupInt32(grouper func(int, int32) string) *Value { + groups := make(map[string][]int32) + v.EachInt32(func(index int, val int32) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]int32, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceInt32 uses the specified function to replace each int32s +// by iterating each item. The data in the returned result will be a +// []int32 containing the replaced items. +func (v *Value) ReplaceInt32(replacer func(int, int32) int32) *Value { + arr := v.MustInt32Slice() + replaced := make([]int32, len(arr)) + v.EachInt32(func(index int, val int32) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectInt32 uses the specified collector function to collect a value +// for each of the int32s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectInt32(collector func(int, int32) interface{}) *Value { + arr := v.MustInt32Slice() + collected := make([]interface{}, len(arr)) + v.EachInt32(func(index int, val int32) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Int64 (int64 and []int64) +*/ + +// Int64 gets the value as a int64, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Int64(optionalDefault ...int64) int64 { + if s, ok := v.data.(int64); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustInt64 gets the value as a int64. +// +// Panics if the object is not a int64. +func (v *Value) MustInt64() int64 { + return v.data.(int64) +} + +// Int64Slice gets the value as a []int64, returns the optionalDefault +// value or nil if the value is not a []int64. +func (v *Value) Int64Slice(optionalDefault ...[]int64) []int64 { + if s, ok := v.data.([]int64); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustInt64Slice gets the value as a []int64. +// +// Panics if the object is not a []int64. +func (v *Value) MustInt64Slice() []int64 { + return v.data.([]int64) +} + +// IsInt64 gets whether the object contained is a int64 or not. +func (v *Value) IsInt64() bool { + _, ok := v.data.(int64) + return ok +} + +// IsInt64Slice gets whether the object contained is a []int64 or not. +func (v *Value) IsInt64Slice() bool { + _, ok := v.data.([]int64) + return ok +} + +// EachInt64 calls the specified callback for each object +// in the []int64. +// +// Panics if the object is the wrong type. +func (v *Value) EachInt64(callback func(int, int64) bool) *Value { + for index, val := range v.MustInt64Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereInt64 uses the specified decider function to select items +// from the []int64. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereInt64(decider func(int, int64) bool) *Value { + var selected []int64 + v.EachInt64(func(index int, val int64) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupInt64 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]int64. +func (v *Value) GroupInt64(grouper func(int, int64) string) *Value { + groups := make(map[string][]int64) + v.EachInt64(func(index int, val int64) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]int64, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceInt64 uses the specified function to replace each int64s +// by iterating each item. The data in the returned result will be a +// []int64 containing the replaced items. +func (v *Value) ReplaceInt64(replacer func(int, int64) int64) *Value { + arr := v.MustInt64Slice() + replaced := make([]int64, len(arr)) + v.EachInt64(func(index int, val int64) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectInt64 uses the specified collector function to collect a value +// for each of the int64s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectInt64(collector func(int, int64) interface{}) *Value { + arr := v.MustInt64Slice() + collected := make([]interface{}, len(arr)) + v.EachInt64(func(index int, val int64) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Uint (uint and []uint) +*/ + +// Uint gets the value as a uint, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Uint(optionalDefault ...uint) uint { + if s, ok := v.data.(uint); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustUint gets the value as a uint. +// +// Panics if the object is not a uint. +func (v *Value) MustUint() uint { + return v.data.(uint) +} + +// UintSlice gets the value as a []uint, returns the optionalDefault +// value or nil if the value is not a []uint. +func (v *Value) UintSlice(optionalDefault ...[]uint) []uint { + if s, ok := v.data.([]uint); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustUintSlice gets the value as a []uint. +// +// Panics if the object is not a []uint. +func (v *Value) MustUintSlice() []uint { + return v.data.([]uint) +} + +// IsUint gets whether the object contained is a uint or not. +func (v *Value) IsUint() bool { + _, ok := v.data.(uint) + return ok +} + +// IsUintSlice gets whether the object contained is a []uint or not. +func (v *Value) IsUintSlice() bool { + _, ok := v.data.([]uint) + return ok +} + +// EachUint calls the specified callback for each object +// in the []uint. +// +// Panics if the object is the wrong type. +func (v *Value) EachUint(callback func(int, uint) bool) *Value { + for index, val := range v.MustUintSlice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereUint uses the specified decider function to select items +// from the []uint. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereUint(decider func(int, uint) bool) *Value { + var selected []uint + v.EachUint(func(index int, val uint) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupUint uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]uint. +func (v *Value) GroupUint(grouper func(int, uint) string) *Value { + groups := make(map[string][]uint) + v.EachUint(func(index int, val uint) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]uint, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceUint uses the specified function to replace each uints +// by iterating each item. The data in the returned result will be a +// []uint containing the replaced items. +func (v *Value) ReplaceUint(replacer func(int, uint) uint) *Value { + arr := v.MustUintSlice() + replaced := make([]uint, len(arr)) + v.EachUint(func(index int, val uint) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectUint uses the specified collector function to collect a value +// for each of the uints in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectUint(collector func(int, uint) interface{}) *Value { + arr := v.MustUintSlice() + collected := make([]interface{}, len(arr)) + v.EachUint(func(index int, val uint) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Uint8 (uint8 and []uint8) +*/ + +// Uint8 gets the value as a uint8, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Uint8(optionalDefault ...uint8) uint8 { + if s, ok := v.data.(uint8); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustUint8 gets the value as a uint8. +// +// Panics if the object is not a uint8. +func (v *Value) MustUint8() uint8 { + return v.data.(uint8) +} + +// Uint8Slice gets the value as a []uint8, returns the optionalDefault +// value or nil if the value is not a []uint8. +func (v *Value) Uint8Slice(optionalDefault ...[]uint8) []uint8 { + if s, ok := v.data.([]uint8); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustUint8Slice gets the value as a []uint8. +// +// Panics if the object is not a []uint8. +func (v *Value) MustUint8Slice() []uint8 { + return v.data.([]uint8) +} + +// IsUint8 gets whether the object contained is a uint8 or not. +func (v *Value) IsUint8() bool { + _, ok := v.data.(uint8) + return ok +} + +// IsUint8Slice gets whether the object contained is a []uint8 or not. +func (v *Value) IsUint8Slice() bool { + _, ok := v.data.([]uint8) + return ok +} + +// EachUint8 calls the specified callback for each object +// in the []uint8. +// +// Panics if the object is the wrong type. +func (v *Value) EachUint8(callback func(int, uint8) bool) *Value { + for index, val := range v.MustUint8Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereUint8 uses the specified decider function to select items +// from the []uint8. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereUint8(decider func(int, uint8) bool) *Value { + var selected []uint8 + v.EachUint8(func(index int, val uint8) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupUint8 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]uint8. +func (v *Value) GroupUint8(grouper func(int, uint8) string) *Value { + groups := make(map[string][]uint8) + v.EachUint8(func(index int, val uint8) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]uint8, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceUint8 uses the specified function to replace each uint8s +// by iterating each item. The data in the returned result will be a +// []uint8 containing the replaced items. +func (v *Value) ReplaceUint8(replacer func(int, uint8) uint8) *Value { + arr := v.MustUint8Slice() + replaced := make([]uint8, len(arr)) + v.EachUint8(func(index int, val uint8) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectUint8 uses the specified collector function to collect a value +// for each of the uint8s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectUint8(collector func(int, uint8) interface{}) *Value { + arr := v.MustUint8Slice() + collected := make([]interface{}, len(arr)) + v.EachUint8(func(index int, val uint8) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Uint16 (uint16 and []uint16) +*/ + +// Uint16 gets the value as a uint16, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Uint16(optionalDefault ...uint16) uint16 { + if s, ok := v.data.(uint16); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustUint16 gets the value as a uint16. +// +// Panics if the object is not a uint16. +func (v *Value) MustUint16() uint16 { + return v.data.(uint16) +} + +// Uint16Slice gets the value as a []uint16, returns the optionalDefault +// value or nil if the value is not a []uint16. +func (v *Value) Uint16Slice(optionalDefault ...[]uint16) []uint16 { + if s, ok := v.data.([]uint16); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustUint16Slice gets the value as a []uint16. +// +// Panics if the object is not a []uint16. +func (v *Value) MustUint16Slice() []uint16 { + return v.data.([]uint16) +} + +// IsUint16 gets whether the object contained is a uint16 or not. +func (v *Value) IsUint16() bool { + _, ok := v.data.(uint16) + return ok +} + +// IsUint16Slice gets whether the object contained is a []uint16 or not. +func (v *Value) IsUint16Slice() bool { + _, ok := v.data.([]uint16) + return ok +} + +// EachUint16 calls the specified callback for each object +// in the []uint16. +// +// Panics if the object is the wrong type. +func (v *Value) EachUint16(callback func(int, uint16) bool) *Value { + for index, val := range v.MustUint16Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereUint16 uses the specified decider function to select items +// from the []uint16. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereUint16(decider func(int, uint16) bool) *Value { + var selected []uint16 + v.EachUint16(func(index int, val uint16) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupUint16 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]uint16. +func (v *Value) GroupUint16(grouper func(int, uint16) string) *Value { + groups := make(map[string][]uint16) + v.EachUint16(func(index int, val uint16) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]uint16, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceUint16 uses the specified function to replace each uint16s +// by iterating each item. The data in the returned result will be a +// []uint16 containing the replaced items. +func (v *Value) ReplaceUint16(replacer func(int, uint16) uint16) *Value { + arr := v.MustUint16Slice() + replaced := make([]uint16, len(arr)) + v.EachUint16(func(index int, val uint16) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectUint16 uses the specified collector function to collect a value +// for each of the uint16s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectUint16(collector func(int, uint16) interface{}) *Value { + arr := v.MustUint16Slice() + collected := make([]interface{}, len(arr)) + v.EachUint16(func(index int, val uint16) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Uint32 (uint32 and []uint32) +*/ + +// Uint32 gets the value as a uint32, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Uint32(optionalDefault ...uint32) uint32 { + if s, ok := v.data.(uint32); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustUint32 gets the value as a uint32. +// +// Panics if the object is not a uint32. +func (v *Value) MustUint32() uint32 { + return v.data.(uint32) +} + +// Uint32Slice gets the value as a []uint32, returns the optionalDefault +// value or nil if the value is not a []uint32. +func (v *Value) Uint32Slice(optionalDefault ...[]uint32) []uint32 { + if s, ok := v.data.([]uint32); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustUint32Slice gets the value as a []uint32. +// +// Panics if the object is not a []uint32. +func (v *Value) MustUint32Slice() []uint32 { + return v.data.([]uint32) +} + +// IsUint32 gets whether the object contained is a uint32 or not. +func (v *Value) IsUint32() bool { + _, ok := v.data.(uint32) + return ok +} + +// IsUint32Slice gets whether the object contained is a []uint32 or not. +func (v *Value) IsUint32Slice() bool { + _, ok := v.data.([]uint32) + return ok +} + +// EachUint32 calls the specified callback for each object +// in the []uint32. +// +// Panics if the object is the wrong type. +func (v *Value) EachUint32(callback func(int, uint32) bool) *Value { + for index, val := range v.MustUint32Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereUint32 uses the specified decider function to select items +// from the []uint32. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereUint32(decider func(int, uint32) bool) *Value { + var selected []uint32 + v.EachUint32(func(index int, val uint32) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupUint32 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]uint32. +func (v *Value) GroupUint32(grouper func(int, uint32) string) *Value { + groups := make(map[string][]uint32) + v.EachUint32(func(index int, val uint32) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]uint32, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceUint32 uses the specified function to replace each uint32s +// by iterating each item. The data in the returned result will be a +// []uint32 containing the replaced items. +func (v *Value) ReplaceUint32(replacer func(int, uint32) uint32) *Value { + arr := v.MustUint32Slice() + replaced := make([]uint32, len(arr)) + v.EachUint32(func(index int, val uint32) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectUint32 uses the specified collector function to collect a value +// for each of the uint32s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectUint32(collector func(int, uint32) interface{}) *Value { + arr := v.MustUint32Slice() + collected := make([]interface{}, len(arr)) + v.EachUint32(func(index int, val uint32) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Uint64 (uint64 and []uint64) +*/ + +// Uint64 gets the value as a uint64, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Uint64(optionalDefault ...uint64) uint64 { + if s, ok := v.data.(uint64); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustUint64 gets the value as a uint64. +// +// Panics if the object is not a uint64. +func (v *Value) MustUint64() uint64 { + return v.data.(uint64) +} + +// Uint64Slice gets the value as a []uint64, returns the optionalDefault +// value or nil if the value is not a []uint64. +func (v *Value) Uint64Slice(optionalDefault ...[]uint64) []uint64 { + if s, ok := v.data.([]uint64); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustUint64Slice gets the value as a []uint64. +// +// Panics if the object is not a []uint64. +func (v *Value) MustUint64Slice() []uint64 { + return v.data.([]uint64) +} + +// IsUint64 gets whether the object contained is a uint64 or not. +func (v *Value) IsUint64() bool { + _, ok := v.data.(uint64) + return ok +} + +// IsUint64Slice gets whether the object contained is a []uint64 or not. +func (v *Value) IsUint64Slice() bool { + _, ok := v.data.([]uint64) + return ok +} + +// EachUint64 calls the specified callback for each object +// in the []uint64. +// +// Panics if the object is the wrong type. +func (v *Value) EachUint64(callback func(int, uint64) bool) *Value { + for index, val := range v.MustUint64Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereUint64 uses the specified decider function to select items +// from the []uint64. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereUint64(decider func(int, uint64) bool) *Value { + var selected []uint64 + v.EachUint64(func(index int, val uint64) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupUint64 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]uint64. +func (v *Value) GroupUint64(grouper func(int, uint64) string) *Value { + groups := make(map[string][]uint64) + v.EachUint64(func(index int, val uint64) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]uint64, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceUint64 uses the specified function to replace each uint64s +// by iterating each item. The data in the returned result will be a +// []uint64 containing the replaced items. +func (v *Value) ReplaceUint64(replacer func(int, uint64) uint64) *Value { + arr := v.MustUint64Slice() + replaced := make([]uint64, len(arr)) + v.EachUint64(func(index int, val uint64) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectUint64 uses the specified collector function to collect a value +// for each of the uint64s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectUint64(collector func(int, uint64) interface{}) *Value { + arr := v.MustUint64Slice() + collected := make([]interface{}, len(arr)) + v.EachUint64(func(index int, val uint64) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Uintptr (uintptr and []uintptr) +*/ + +// Uintptr gets the value as a uintptr, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Uintptr(optionalDefault ...uintptr) uintptr { + if s, ok := v.data.(uintptr); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustUintptr gets the value as a uintptr. +// +// Panics if the object is not a uintptr. +func (v *Value) MustUintptr() uintptr { + return v.data.(uintptr) +} + +// UintptrSlice gets the value as a []uintptr, returns the optionalDefault +// value or nil if the value is not a []uintptr. +func (v *Value) UintptrSlice(optionalDefault ...[]uintptr) []uintptr { + if s, ok := v.data.([]uintptr); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustUintptrSlice gets the value as a []uintptr. +// +// Panics if the object is not a []uintptr. +func (v *Value) MustUintptrSlice() []uintptr { + return v.data.([]uintptr) +} + +// IsUintptr gets whether the object contained is a uintptr or not. +func (v *Value) IsUintptr() bool { + _, ok := v.data.(uintptr) + return ok +} + +// IsUintptrSlice gets whether the object contained is a []uintptr or not. +func (v *Value) IsUintptrSlice() bool { + _, ok := v.data.([]uintptr) + return ok +} + +// EachUintptr calls the specified callback for each object +// in the []uintptr. +// +// Panics if the object is the wrong type. +func (v *Value) EachUintptr(callback func(int, uintptr) bool) *Value { + for index, val := range v.MustUintptrSlice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereUintptr uses the specified decider function to select items +// from the []uintptr. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereUintptr(decider func(int, uintptr) bool) *Value { + var selected []uintptr + v.EachUintptr(func(index int, val uintptr) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupUintptr uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]uintptr. +func (v *Value) GroupUintptr(grouper func(int, uintptr) string) *Value { + groups := make(map[string][]uintptr) + v.EachUintptr(func(index int, val uintptr) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]uintptr, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceUintptr uses the specified function to replace each uintptrs +// by iterating each item. The data in the returned result will be a +// []uintptr containing the replaced items. +func (v *Value) ReplaceUintptr(replacer func(int, uintptr) uintptr) *Value { + arr := v.MustUintptrSlice() + replaced := make([]uintptr, len(arr)) + v.EachUintptr(func(index int, val uintptr) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectUintptr uses the specified collector function to collect a value +// for each of the uintptrs in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectUintptr(collector func(int, uintptr) interface{}) *Value { + arr := v.MustUintptrSlice() + collected := make([]interface{}, len(arr)) + v.EachUintptr(func(index int, val uintptr) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Float32 (float32 and []float32) +*/ + +// Float32 gets the value as a float32, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Float32(optionalDefault ...float32) float32 { + if s, ok := v.data.(float32); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustFloat32 gets the value as a float32. +// +// Panics if the object is not a float32. +func (v *Value) MustFloat32() float32 { + return v.data.(float32) +} + +// Float32Slice gets the value as a []float32, returns the optionalDefault +// value or nil if the value is not a []float32. +func (v *Value) Float32Slice(optionalDefault ...[]float32) []float32 { + if s, ok := v.data.([]float32); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustFloat32Slice gets the value as a []float32. +// +// Panics if the object is not a []float32. +func (v *Value) MustFloat32Slice() []float32 { + return v.data.([]float32) +} + +// IsFloat32 gets whether the object contained is a float32 or not. +func (v *Value) IsFloat32() bool { + _, ok := v.data.(float32) + return ok +} + +// IsFloat32Slice gets whether the object contained is a []float32 or not. +func (v *Value) IsFloat32Slice() bool { + _, ok := v.data.([]float32) + return ok +} + +// EachFloat32 calls the specified callback for each object +// in the []float32. +// +// Panics if the object is the wrong type. +func (v *Value) EachFloat32(callback func(int, float32) bool) *Value { + for index, val := range v.MustFloat32Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereFloat32 uses the specified decider function to select items +// from the []float32. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereFloat32(decider func(int, float32) bool) *Value { + var selected []float32 + v.EachFloat32(func(index int, val float32) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupFloat32 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]float32. +func (v *Value) GroupFloat32(grouper func(int, float32) string) *Value { + groups := make(map[string][]float32) + v.EachFloat32(func(index int, val float32) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]float32, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceFloat32 uses the specified function to replace each float32s +// by iterating each item. The data in the returned result will be a +// []float32 containing the replaced items. +func (v *Value) ReplaceFloat32(replacer func(int, float32) float32) *Value { + arr := v.MustFloat32Slice() + replaced := make([]float32, len(arr)) + v.EachFloat32(func(index int, val float32) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectFloat32 uses the specified collector function to collect a value +// for each of the float32s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectFloat32(collector func(int, float32) interface{}) *Value { + arr := v.MustFloat32Slice() + collected := make([]interface{}, len(arr)) + v.EachFloat32(func(index int, val float32) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Float64 (float64 and []float64) +*/ + +// Float64 gets the value as a float64, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Float64(optionalDefault ...float64) float64 { + if s, ok := v.data.(float64); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustFloat64 gets the value as a float64. +// +// Panics if the object is not a float64. +func (v *Value) MustFloat64() float64 { + return v.data.(float64) +} + +// Float64Slice gets the value as a []float64, returns the optionalDefault +// value or nil if the value is not a []float64. +func (v *Value) Float64Slice(optionalDefault ...[]float64) []float64 { + if s, ok := v.data.([]float64); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustFloat64Slice gets the value as a []float64. +// +// Panics if the object is not a []float64. +func (v *Value) MustFloat64Slice() []float64 { + return v.data.([]float64) +} + +// IsFloat64 gets whether the object contained is a float64 or not. +func (v *Value) IsFloat64() bool { + _, ok := v.data.(float64) + return ok +} + +// IsFloat64Slice gets whether the object contained is a []float64 or not. +func (v *Value) IsFloat64Slice() bool { + _, ok := v.data.([]float64) + return ok +} + +// EachFloat64 calls the specified callback for each object +// in the []float64. +// +// Panics if the object is the wrong type. +func (v *Value) EachFloat64(callback func(int, float64) bool) *Value { + for index, val := range v.MustFloat64Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereFloat64 uses the specified decider function to select items +// from the []float64. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereFloat64(decider func(int, float64) bool) *Value { + var selected []float64 + v.EachFloat64(func(index int, val float64) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupFloat64 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]float64. +func (v *Value) GroupFloat64(grouper func(int, float64) string) *Value { + groups := make(map[string][]float64) + v.EachFloat64(func(index int, val float64) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]float64, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceFloat64 uses the specified function to replace each float64s +// by iterating each item. The data in the returned result will be a +// []float64 containing the replaced items. +func (v *Value) ReplaceFloat64(replacer func(int, float64) float64) *Value { + arr := v.MustFloat64Slice() + replaced := make([]float64, len(arr)) + v.EachFloat64(func(index int, val float64) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectFloat64 uses the specified collector function to collect a value +// for each of the float64s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectFloat64(collector func(int, float64) interface{}) *Value { + arr := v.MustFloat64Slice() + collected := make([]interface{}, len(arr)) + v.EachFloat64(func(index int, val float64) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Complex64 (complex64 and []complex64) +*/ + +// Complex64 gets the value as a complex64, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Complex64(optionalDefault ...complex64) complex64 { + if s, ok := v.data.(complex64); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustComplex64 gets the value as a complex64. +// +// Panics if the object is not a complex64. +func (v *Value) MustComplex64() complex64 { + return v.data.(complex64) +} + +// Complex64Slice gets the value as a []complex64, returns the optionalDefault +// value or nil if the value is not a []complex64. +func (v *Value) Complex64Slice(optionalDefault ...[]complex64) []complex64 { + if s, ok := v.data.([]complex64); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustComplex64Slice gets the value as a []complex64. +// +// Panics if the object is not a []complex64. +func (v *Value) MustComplex64Slice() []complex64 { + return v.data.([]complex64) +} + +// IsComplex64 gets whether the object contained is a complex64 or not. +func (v *Value) IsComplex64() bool { + _, ok := v.data.(complex64) + return ok +} + +// IsComplex64Slice gets whether the object contained is a []complex64 or not. +func (v *Value) IsComplex64Slice() bool { + _, ok := v.data.([]complex64) + return ok +} + +// EachComplex64 calls the specified callback for each object +// in the []complex64. +// +// Panics if the object is the wrong type. +func (v *Value) EachComplex64(callback func(int, complex64) bool) *Value { + for index, val := range v.MustComplex64Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereComplex64 uses the specified decider function to select items +// from the []complex64. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereComplex64(decider func(int, complex64) bool) *Value { + var selected []complex64 + v.EachComplex64(func(index int, val complex64) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupComplex64 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]complex64. +func (v *Value) GroupComplex64(grouper func(int, complex64) string) *Value { + groups := make(map[string][]complex64) + v.EachComplex64(func(index int, val complex64) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]complex64, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceComplex64 uses the specified function to replace each complex64s +// by iterating each item. The data in the returned result will be a +// []complex64 containing the replaced items. +func (v *Value) ReplaceComplex64(replacer func(int, complex64) complex64) *Value { + arr := v.MustComplex64Slice() + replaced := make([]complex64, len(arr)) + v.EachComplex64(func(index int, val complex64) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectComplex64 uses the specified collector function to collect a value +// for each of the complex64s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectComplex64(collector func(int, complex64) interface{}) *Value { + arr := v.MustComplex64Slice() + collected := make([]interface{}, len(arr)) + v.EachComplex64(func(index int, val complex64) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} + +/* + Complex128 (complex128 and []complex128) +*/ + +// Complex128 gets the value as a complex128, returns the optionalDefault +// value or a system default object if the value is the wrong type. +func (v *Value) Complex128(optionalDefault ...complex128) complex128 { + if s, ok := v.data.(complex128); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return 0 +} + +// MustComplex128 gets the value as a complex128. +// +// Panics if the object is not a complex128. +func (v *Value) MustComplex128() complex128 { + return v.data.(complex128) +} + +// Complex128Slice gets the value as a []complex128, returns the optionalDefault +// value or nil if the value is not a []complex128. +func (v *Value) Complex128Slice(optionalDefault ...[]complex128) []complex128 { + if s, ok := v.data.([]complex128); ok { + return s + } + if len(optionalDefault) == 1 { + return optionalDefault[0] + } + return nil +} + +// MustComplex128Slice gets the value as a []complex128. +// +// Panics if the object is not a []complex128. +func (v *Value) MustComplex128Slice() []complex128 { + return v.data.([]complex128) +} + +// IsComplex128 gets whether the object contained is a complex128 or not. +func (v *Value) IsComplex128() bool { + _, ok := v.data.(complex128) + return ok +} + +// IsComplex128Slice gets whether the object contained is a []complex128 or not. +func (v *Value) IsComplex128Slice() bool { + _, ok := v.data.([]complex128) + return ok +} + +// EachComplex128 calls the specified callback for each object +// in the []complex128. +// +// Panics if the object is the wrong type. +func (v *Value) EachComplex128(callback func(int, complex128) bool) *Value { + for index, val := range v.MustComplex128Slice() { + carryon := callback(index, val) + if !carryon { + break + } + } + return v +} + +// WhereComplex128 uses the specified decider function to select items +// from the []complex128. The object contained in the result will contain +// only the selected items. +func (v *Value) WhereComplex128(decider func(int, complex128) bool) *Value { + var selected []complex128 + v.EachComplex128(func(index int, val complex128) bool { + shouldSelect := decider(index, val) + if !shouldSelect { + selected = append(selected, val) + } + return true + }) + return &Value{data: selected} +} + +// GroupComplex128 uses the specified grouper function to group the items +// keyed by the return of the grouper. The object contained in the +// result will contain a map[string][]complex128. +func (v *Value) GroupComplex128(grouper func(int, complex128) string) *Value { + groups := make(map[string][]complex128) + v.EachComplex128(func(index int, val complex128) bool { + group := grouper(index, val) + if _, ok := groups[group]; !ok { + groups[group] = make([]complex128, 0) + } + groups[group] = append(groups[group], val) + return true + }) + return &Value{data: groups} +} + +// ReplaceComplex128 uses the specified function to replace each complex128s +// by iterating each item. The data in the returned result will be a +// []complex128 containing the replaced items. +func (v *Value) ReplaceComplex128(replacer func(int, complex128) complex128) *Value { + arr := v.MustComplex128Slice() + replaced := make([]complex128, len(arr)) + v.EachComplex128(func(index int, val complex128) bool { + replaced[index] = replacer(index, val) + return true + }) + return &Value{data: replaced} +} + +// CollectComplex128 uses the specified collector function to collect a value +// for each of the complex128s in the slice. The data returned will be a +// []interface{}. +func (v *Value) CollectComplex128(collector func(int, complex128) interface{}) *Value { + arr := v.MustComplex128Slice() + collected := make([]interface{}, len(arr)) + v.EachComplex128(func(index int, val complex128) bool { + collected[index] = collector(index, val) + return true + }) + return &Value{data: collected} +} diff --git a/vendor/github.com/stretchr/objx/value.go b/vendor/github.com/stretchr/objx/value.go new file mode 100644 index 000000000..e4b4a1433 --- /dev/null +++ b/vendor/github.com/stretchr/objx/value.go @@ -0,0 +1,53 @@ +package objx + +import ( + "fmt" + "strconv" +) + +// Value provides methods for extracting interface{} data in various +// types. +type Value struct { + // data contains the raw data being managed by this Value + data interface{} +} + +// Data returns the raw data contained by this Value +func (v *Value) Data() interface{} { + return v.data +} + +// String returns the value always as a string +func (v *Value) String() string { + switch { + case v.IsStr(): + return v.Str() + case v.IsBool(): + return strconv.FormatBool(v.Bool()) + case v.IsFloat32(): + return strconv.FormatFloat(float64(v.Float32()), 'f', -1, 32) + case v.IsFloat64(): + return strconv.FormatFloat(v.Float64(), 'f', -1, 64) + case v.IsInt(): + return strconv.FormatInt(int64(v.Int()), 10) + case v.IsInt8(): + return strconv.FormatInt(int64(v.Int8()), 10) + case v.IsInt16(): + return strconv.FormatInt(int64(v.Int16()), 10) + case v.IsInt32(): + return strconv.FormatInt(int64(v.Int32()), 10) + case v.IsInt64(): + return strconv.FormatInt(v.Int64(), 10) + case v.IsUint(): + return strconv.FormatUint(uint64(v.Uint()), 10) + case v.IsUint8(): + return strconv.FormatUint(uint64(v.Uint8()), 10) + case v.IsUint16(): + return strconv.FormatUint(uint64(v.Uint16()), 10) + case v.IsUint32(): + return strconv.FormatUint(uint64(v.Uint32()), 10) + case v.IsUint64(): + return strconv.FormatUint(v.Uint64(), 10) + } + return fmt.Sprintf("%#v", v.Data()) +} diff --git a/vendor/github.com/stretchr/testify/mock/doc.go b/vendor/github.com/stretchr/testify/mock/doc.go new file mode 100644 index 000000000..7324128ef --- /dev/null +++ b/vendor/github.com/stretchr/testify/mock/doc.go @@ -0,0 +1,44 @@ +// Package mock provides a system by which it is possible to mock your objects +// and verify calls are happening as expected. +// +// Example Usage +// +// The mock package provides an object, Mock, that tracks activity on another object. It is usually +// embedded into a test object as shown below: +// +// type MyTestObject struct { +// // add a Mock object instance +// mock.Mock +// +// // other fields go here as normal +// } +// +// When implementing the methods of an interface, you wire your functions up +// to call the Mock.Called(args...) method, and return the appropriate values. +// +// For example, to mock a method that saves the name and age of a person and returns +// the year of their birth or an error, you might write this: +// +// func (o *MyTestObject) SavePersonDetails(firstname, lastname string, age int) (int, error) { +// args := o.Called(firstname, lastname, age) +// return args.Int(0), args.Error(1) +// } +// +// The Int, Error and Bool methods are examples of strongly typed getters that take the argument +// index position. Given this argument list: +// +// (12, true, "Something") +// +// You could read them out strongly typed like this: +// +// args.Int(0) +// args.Bool(1) +// args.String(2) +// +// For objects of your own type, use the generic Arguments.Get(index) method and make a type assertion: +// +// return args.Get(0).(*MyObject), args.Get(1).(*AnotherObjectOfMine) +// +// This may cause a panic if the object you are getting is nil (the type assertion will fail), in those +// cases you should check for nil first. +package mock diff --git a/vendor/github.com/stretchr/testify/mock/mock.go b/vendor/github.com/stretchr/testify/mock/mock.go new file mode 100644 index 000000000..b5288af5b --- /dev/null +++ b/vendor/github.com/stretchr/testify/mock/mock.go @@ -0,0 +1,894 @@ +package mock + +import ( + "errors" + "fmt" + "reflect" + "regexp" + "runtime" + "strings" + "sync" + "time" + + "github.com/davecgh/go-spew/spew" + "github.com/pmezard/go-difflib/difflib" + "github.com/stretchr/objx" + "github.com/stretchr/testify/assert" +) + +// TestingT is an interface wrapper around *testing.T +type TestingT interface { + Logf(format string, args ...interface{}) + Errorf(format string, args ...interface{}) + FailNow() +} + +/* + Call +*/ + +// Call represents a method call and is used for setting expectations, +// as well as recording activity. +type Call struct { + Parent *Mock + + // The name of the method that was or will be called. + Method string + + // Holds the arguments of the method. + Arguments Arguments + + // Holds the arguments that should be returned when + // this method is called. + ReturnArguments Arguments + + // Holds the caller info for the On() call + callerInfo []string + + // The number of times to return the return arguments when setting + // expectations. 0 means to always return the value. + Repeatability int + + // Amount of times this call has been called + totalCalls int + + // Call to this method can be optional + optional bool + + // Holds a channel that will be used to block the Return until it either + // receives a message or is closed. nil means it returns immediately. + WaitFor <-chan time.Time + + waitTime time.Duration + + // Holds a handler used to manipulate arguments content that are passed by + // reference. It's useful when mocking methods such as unmarshalers or + // decoders. + RunFn func(Arguments) +} + +func newCall(parent *Mock, methodName string, callerInfo []string, methodArguments ...interface{}) *Call { + return &Call{ + Parent: parent, + Method: methodName, + Arguments: methodArguments, + ReturnArguments: make([]interface{}, 0), + callerInfo: callerInfo, + Repeatability: 0, + WaitFor: nil, + RunFn: nil, + } +} + +func (c *Call) lock() { + c.Parent.mutex.Lock() +} + +func (c *Call) unlock() { + c.Parent.mutex.Unlock() +} + +// Return specifies the return arguments for the expectation. +// +// Mock.On("DoSomething").Return(errors.New("failed")) +func (c *Call) Return(returnArguments ...interface{}) *Call { + c.lock() + defer c.unlock() + + c.ReturnArguments = returnArguments + + return c +} + +// Once indicates that that the mock should only return the value once. +// +// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Once() +func (c *Call) Once() *Call { + return c.Times(1) +} + +// Twice indicates that that the mock should only return the value twice. +// +// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Twice() +func (c *Call) Twice() *Call { + return c.Times(2) +} + +// Times indicates that that the mock should only return the indicated number +// of times. +// +// Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Times(5) +func (c *Call) Times(i int) *Call { + c.lock() + defer c.unlock() + c.Repeatability = i + return c +} + +// WaitUntil sets the channel that will block the mock's return until its closed +// or a message is received. +// +// Mock.On("MyMethod", arg1, arg2).WaitUntil(time.After(time.Second)) +func (c *Call) WaitUntil(w <-chan time.Time) *Call { + c.lock() + defer c.unlock() + c.WaitFor = w + return c +} + +// After sets how long to block until the call returns +// +// Mock.On("MyMethod", arg1, arg2).After(time.Second) +func (c *Call) After(d time.Duration) *Call { + c.lock() + defer c.unlock() + c.waitTime = d + return c +} + +// Run sets a handler to be called before returning. It can be used when +// mocking a method such as unmarshalers that takes a pointer to a struct and +// sets properties in such struct +// +// Mock.On("Unmarshal", AnythingOfType("*map[string]interface{}").Return().Run(func(args Arguments) { +// arg := args.Get(0).(*map[string]interface{}) +// arg["foo"] = "bar" +// }) +func (c *Call) Run(fn func(args Arguments)) *Call { + c.lock() + defer c.unlock() + c.RunFn = fn + return c +} + +// Maybe allows the method call to be optional. Not calling an optional method +// will not cause an error while asserting expectations +func (c *Call) Maybe() *Call { + c.lock() + defer c.unlock() + c.optional = true + return c +} + +// On chains a new expectation description onto the mocked interface. This +// allows syntax like. +// +// Mock. +// On("MyMethod", 1).Return(nil). +// On("MyOtherMethod", 'a', 'b', 'c').Return(errors.New("Some Error")) +//go:noinline +func (c *Call) On(methodName string, arguments ...interface{}) *Call { + return c.Parent.On(methodName, arguments...) +} + +// Mock is the workhorse used to track activity on another object. +// For an example of its usage, refer to the "Example Usage" section at the top +// of this document. +type Mock struct { + // Represents the calls that are expected of + // an object. + ExpectedCalls []*Call + + // Holds the calls that were made to this mocked object. + Calls []Call + + // test is An optional variable that holds the test struct, to be used when an + // invalid mock call was made. + test TestingT + + // TestData holds any data that might be useful for testing. Testify ignores + // this data completely allowing you to do whatever you like with it. + testData objx.Map + + mutex sync.Mutex +} + +// TestData holds any data that might be useful for testing. Testify ignores +// this data completely allowing you to do whatever you like with it. +func (m *Mock) TestData() objx.Map { + + if m.testData == nil { + m.testData = make(objx.Map) + } + + return m.testData +} + +/* + Setting expectations +*/ + +// Test sets the test struct variable of the mock object +func (m *Mock) Test(t TestingT) { + m.mutex.Lock() + defer m.mutex.Unlock() + m.test = t +} + +// fail fails the current test with the given formatted format and args. +// In case that a test was defined, it uses the test APIs for failing a test, +// otherwise it uses panic. +func (m *Mock) fail(format string, args ...interface{}) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.test == nil { + panic(fmt.Sprintf(format, args...)) + } + m.test.Errorf(format, args...) + m.test.FailNow() +} + +// On starts a description of an expectation of the specified method +// being called. +// +// Mock.On("MyMethod", arg1, arg2) +func (m *Mock) On(methodName string, arguments ...interface{}) *Call { + for _, arg := range arguments { + if v := reflect.ValueOf(arg); v.Kind() == reflect.Func { + panic(fmt.Sprintf("cannot use Func in expectations. Use mock.AnythingOfType(\"%T\")", arg)) + } + } + + m.mutex.Lock() + defer m.mutex.Unlock() + c := newCall(m, methodName, assert.CallerInfo(), arguments...) + m.ExpectedCalls = append(m.ExpectedCalls, c) + return c +} + +// /* +// Recording and responding to activity +// */ + +func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (int, *Call) { + var expectedCall *Call + + for i, call := range m.ExpectedCalls { + if call.Method == method { + _, diffCount := call.Arguments.Diff(arguments) + if diffCount == 0 { + expectedCall = call + if call.Repeatability > -1 { + return i, call + } + } + } + } + + return -1, expectedCall +} + +func (m *Mock) findClosestCall(method string, arguments ...interface{}) (*Call, string) { + var diffCount int + var closestCall *Call + var err string + + for _, call := range m.expectedCalls() { + if call.Method == method { + + errInfo, tempDiffCount := call.Arguments.Diff(arguments) + if tempDiffCount < diffCount || diffCount == 0 { + diffCount = tempDiffCount + closestCall = call + err = errInfo + } + + } + } + + return closestCall, err +} + +func callString(method string, arguments Arguments, includeArgumentValues bool) string { + + var argValsString string + if includeArgumentValues { + var argVals []string + for argIndex, arg := range arguments { + argVals = append(argVals, fmt.Sprintf("%d: %#v", argIndex, arg)) + } + argValsString = fmt.Sprintf("\n\t\t%s", strings.Join(argVals, "\n\t\t")) + } + + return fmt.Sprintf("%s(%s)%s", method, arguments.String(), argValsString) +} + +// Called tells the mock object that a method has been called, and gets an array +// of arguments to return. Panics if the call is unexpected (i.e. not preceded by +// appropriate .On .Return() calls) +// If Call.WaitFor is set, blocks until the channel is closed or receives a message. +func (m *Mock) Called(arguments ...interface{}) Arguments { + // get the calling function's name + pc, _, _, ok := runtime.Caller(1) + if !ok { + panic("Couldn't get the caller information") + } + functionPath := runtime.FuncForPC(pc).Name() + //Next four lines are required to use GCCGO function naming conventions. + //For Ex: github_com_docker_libkv_store_mock.WatchTree.pN39_github_com_docker_libkv_store_mock.Mock + //uses interface information unlike golang github.com/docker/libkv/store/mock.(*Mock).WatchTree + //With GCCGO we need to remove interface information starting from pN
. + re := regexp.MustCompile("\\.pN\\d+_") + if re.MatchString(functionPath) { + functionPath = re.Split(functionPath, -1)[0] + } + parts := strings.Split(functionPath, ".") + functionName := parts[len(parts)-1] + return m.MethodCalled(functionName, arguments...) +} + +// MethodCalled tells the mock object that the given method has been called, and gets +// an array of arguments to return. Panics if the call is unexpected (i.e. not preceded +// by appropriate .On .Return() calls) +// If Call.WaitFor is set, blocks until the channel is closed or receives a message. +func (m *Mock) MethodCalled(methodName string, arguments ...interface{}) Arguments { + m.mutex.Lock() + //TODO: could combine expected and closes in single loop + found, call := m.findExpectedCall(methodName, arguments...) + + if found < 0 { + // expected call found but it has already been called with repeatable times + if call != nil { + m.mutex.Unlock() + m.fail("\nassert: mock: The method has been called over %d times.\n\tEither do one more Mock.On(\"%s\").Return(...), or remove extra call.\n\tThis call was unexpected:\n\t\t%s\n\tat: %s", call.totalCalls, methodName, callString(methodName, arguments, true), assert.CallerInfo()) + } + // we have to fail here - because we don't know what to do + // as the return arguments. This is because: + // + // a) this is a totally unexpected call to this method, + // b) the arguments are not what was expected, or + // c) the developer has forgotten to add an accompanying On...Return pair. + closestCall, mismatch := m.findClosestCall(methodName, arguments...) + m.mutex.Unlock() + + if closestCall != nil { + m.fail("\n\nmock: Unexpected Method Call\n-----------------------------\n\n%s\n\nThe closest call I have is: \n\n%s\n\n%s\nDiff: %s", + callString(methodName, arguments, true), + callString(methodName, closestCall.Arguments, true), + diffArguments(closestCall.Arguments, arguments), + strings.TrimSpace(mismatch), + ) + } else { + m.fail("\nassert: mock: I don't know what to return because the method call was unexpected.\n\tEither do Mock.On(\"%s\").Return(...) first, or remove the %s() call.\n\tThis method was unexpected:\n\t\t%s\n\tat: %s", methodName, methodName, callString(methodName, arguments, true), assert.CallerInfo()) + } + } + + if call.Repeatability == 1 { + call.Repeatability = -1 + } else if call.Repeatability > 1 { + call.Repeatability-- + } + call.totalCalls++ + + // add the call + m.Calls = append(m.Calls, *newCall(m, methodName, assert.CallerInfo(), arguments...)) + m.mutex.Unlock() + + // block if specified + if call.WaitFor != nil { + <-call.WaitFor + } else { + time.Sleep(call.waitTime) + } + + m.mutex.Lock() + runFn := call.RunFn + m.mutex.Unlock() + + if runFn != nil { + runFn(arguments) + } + + m.mutex.Lock() + returnArgs := call.ReturnArguments + m.mutex.Unlock() + + return returnArgs +} + +/* + Assertions +*/ + +type assertExpectationser interface { + AssertExpectations(TestingT) bool +} + +// AssertExpectationsForObjects asserts that everything specified with On and Return +// of the specified objects was in fact called as expected. +// +// Calls may have occurred in any order. +func AssertExpectationsForObjects(t TestingT, testObjects ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + for _, obj := range testObjects { + if m, ok := obj.(Mock); ok { + t.Logf("Deprecated mock.AssertExpectationsForObjects(myMock.Mock) use mock.AssertExpectationsForObjects(myMock)") + obj = &m + } + m := obj.(assertExpectationser) + if !m.AssertExpectations(t) { + t.Logf("Expectations didn't match for Mock: %+v", reflect.TypeOf(m)) + return false + } + } + return true +} + +// AssertExpectations asserts that everything specified with On and Return was +// in fact called as expected. Calls may have occurred in any order. +func (m *Mock) AssertExpectations(t TestingT) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + m.mutex.Lock() + defer m.mutex.Unlock() + var somethingMissing bool + var failedExpectations int + + // iterate through each expectation + expectedCalls := m.expectedCalls() + for _, expectedCall := range expectedCalls { + if !expectedCall.optional && !m.methodWasCalled(expectedCall.Method, expectedCall.Arguments) && expectedCall.totalCalls == 0 { + somethingMissing = true + failedExpectations++ + t.Logf("FAIL:\t%s(%s)\n\t\tat: %s", expectedCall.Method, expectedCall.Arguments.String(), expectedCall.callerInfo) + } else { + if expectedCall.Repeatability > 0 { + somethingMissing = true + failedExpectations++ + t.Logf("FAIL:\t%s(%s)\n\t\tat: %s", expectedCall.Method, expectedCall.Arguments.String(), expectedCall.callerInfo) + } else { + t.Logf("PASS:\t%s(%s)", expectedCall.Method, expectedCall.Arguments.String()) + } + } + } + + if somethingMissing { + t.Errorf("FAIL: %d out of %d expectation(s) were met.\n\tThe code you are testing needs to make %d more call(s).\n\tat: %s", len(expectedCalls)-failedExpectations, len(expectedCalls), failedExpectations, assert.CallerInfo()) + } + + return !somethingMissing +} + +// AssertNumberOfCalls asserts that the method was called expectedCalls times. +func (m *Mock) AssertNumberOfCalls(t TestingT, methodName string, expectedCalls int) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + m.mutex.Lock() + defer m.mutex.Unlock() + var actualCalls int + for _, call := range m.calls() { + if call.Method == methodName { + actualCalls++ + } + } + return assert.Equal(t, expectedCalls, actualCalls, fmt.Sprintf("Expected number of calls (%d) does not match the actual number of calls (%d).", expectedCalls, actualCalls)) +} + +// AssertCalled asserts that the method was called. +// It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method. +func (m *Mock) AssertCalled(t TestingT, methodName string, arguments ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + m.mutex.Lock() + defer m.mutex.Unlock() + if !m.methodWasCalled(methodName, arguments) { + var calledWithArgs []string + for _, call := range m.calls() { + calledWithArgs = append(calledWithArgs, fmt.Sprintf("%v", call.Arguments)) + } + if len(calledWithArgs) == 0 { + return assert.Fail(t, "Should have called with given arguments", + fmt.Sprintf("Expected %q to have been called with:\n%v\nbut no actual calls happened", methodName, arguments)) + } + return assert.Fail(t, "Should have called with given arguments", + fmt.Sprintf("Expected %q to have been called with:\n%v\nbut actual calls were:\n %v", methodName, arguments, strings.Join(calledWithArgs, "\n"))) + } + return true +} + +// AssertNotCalled asserts that the method was not called. +// It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method. +func (m *Mock) AssertNotCalled(t TestingT, methodName string, arguments ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + m.mutex.Lock() + defer m.mutex.Unlock() + if m.methodWasCalled(methodName, arguments) { + return assert.Fail(t, "Should not have called with given arguments", + fmt.Sprintf("Expected %q to not have been called with:\n%v\nbut actually it was.", methodName, arguments)) + } + return true +} + +func (m *Mock) methodWasCalled(methodName string, expected []interface{}) bool { + for _, call := range m.calls() { + if call.Method == methodName { + + _, differences := Arguments(expected).Diff(call.Arguments) + + if differences == 0 { + // found the expected call + return true + } + + } + } + // we didn't find the expected call + return false +} + +func (m *Mock) expectedCalls() []*Call { + return append([]*Call{}, m.ExpectedCalls...) +} + +func (m *Mock) calls() []Call { + return append([]Call{}, m.Calls...) +} + +/* + Arguments +*/ + +// Arguments holds an array of method arguments or return values. +type Arguments []interface{} + +const ( + // Anything is used in Diff and Assert when the argument being tested + // shouldn't be taken into consideration. + Anything = "mock.Anything" +) + +// AnythingOfTypeArgument is a string that contains the type of an argument +// for use when type checking. Used in Diff and Assert. +type AnythingOfTypeArgument string + +// AnythingOfType returns an AnythingOfTypeArgument object containing the +// name of the type to check for. Used in Diff and Assert. +// +// For example: +// Assert(t, AnythingOfType("string"), AnythingOfType("int")) +func AnythingOfType(t string) AnythingOfTypeArgument { + return AnythingOfTypeArgument(t) +} + +// argumentMatcher performs custom argument matching, returning whether or +// not the argument is matched by the expectation fixture function. +type argumentMatcher struct { + // fn is a function which accepts one argument, and returns a bool. + fn reflect.Value +} + +func (f argumentMatcher) Matches(argument interface{}) bool { + expectType := f.fn.Type().In(0) + expectTypeNilSupported := false + switch expectType.Kind() { + case reflect.Interface, reflect.Chan, reflect.Func, reflect.Map, reflect.Slice, reflect.Ptr: + expectTypeNilSupported = true + } + + argType := reflect.TypeOf(argument) + var arg reflect.Value + if argType == nil { + arg = reflect.New(expectType).Elem() + } else { + arg = reflect.ValueOf(argument) + } + + if argType == nil && !expectTypeNilSupported { + panic(errors.New("attempting to call matcher with nil for non-nil expected type")) + } + if argType == nil || argType.AssignableTo(expectType) { + result := f.fn.Call([]reflect.Value{arg}) + return result[0].Bool() + } + return false +} + +func (f argumentMatcher) String() string { + return fmt.Sprintf("func(%s) bool", f.fn.Type().In(0).Name()) +} + +// MatchedBy can be used to match a mock call based on only certain properties +// from a complex struct or some calculation. It takes a function that will be +// evaluated with the called argument and will return true when there's a match +// and false otherwise. +// +// Example: +// m.On("Do", MatchedBy(func(req *http.Request) bool { return req.Host == "example.com" })) +// +// |fn|, must be a function accepting a single argument (of the expected type) +// which returns a bool. If |fn| doesn't match the required signature, +// MatchedBy() panics. +func MatchedBy(fn interface{}) argumentMatcher { + fnType := reflect.TypeOf(fn) + + if fnType.Kind() != reflect.Func { + panic(fmt.Sprintf("assert: arguments: %s is not a func", fn)) + } + if fnType.NumIn() != 1 { + panic(fmt.Sprintf("assert: arguments: %s does not take exactly one argument", fn)) + } + if fnType.NumOut() != 1 || fnType.Out(0).Kind() != reflect.Bool { + panic(fmt.Sprintf("assert: arguments: %s does not return a bool", fn)) + } + + return argumentMatcher{fn: reflect.ValueOf(fn)} +} + +// Get Returns the argument at the specified index. +func (args Arguments) Get(index int) interface{} { + if index+1 > len(args) { + panic(fmt.Sprintf("assert: arguments: Cannot call Get(%d) because there are %d argument(s).", index, len(args))) + } + return args[index] +} + +// Is gets whether the objects match the arguments specified. +func (args Arguments) Is(objects ...interface{}) bool { + for i, obj := range args { + if obj != objects[i] { + return false + } + } + return true +} + +// Diff gets a string describing the differences between the arguments +// and the specified objects. +// +// Returns the diff string and number of differences found. +func (args Arguments) Diff(objects []interface{}) (string, int) { + //TODO: could return string as error and nil for No difference + + var output = "\n" + var differences int + + var maxArgCount = len(args) + if len(objects) > maxArgCount { + maxArgCount = len(objects) + } + + for i := 0; i < maxArgCount; i++ { + var actual, expected interface{} + var actualFmt, expectedFmt string + + if len(objects) <= i { + actual = "(Missing)" + actualFmt = "(Missing)" + } else { + actual = objects[i] + actualFmt = fmt.Sprintf("(%[1]T=%[1]v)", actual) + } + + if len(args) <= i { + expected = "(Missing)" + expectedFmt = "(Missing)" + } else { + expected = args[i] + expectedFmt = fmt.Sprintf("(%[1]T=%[1]v)", expected) + } + + if matcher, ok := expected.(argumentMatcher); ok { + if matcher.Matches(actual) { + output = fmt.Sprintf("%s\t%d: PASS: %s matched by %s\n", output, i, actualFmt, matcher) + } else { + differences++ + output = fmt.Sprintf("%s\t%d: FAIL: %s not matched by %s\n", output, i, actualFmt, matcher) + } + } else if reflect.TypeOf(expected) == reflect.TypeOf((*AnythingOfTypeArgument)(nil)).Elem() { + + // type checking + if reflect.TypeOf(actual).Name() != string(expected.(AnythingOfTypeArgument)) && reflect.TypeOf(actual).String() != string(expected.(AnythingOfTypeArgument)) { + // not match + differences++ + output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected, reflect.TypeOf(actual).Name(), actualFmt) + } + + } else { + + // normal checking + + if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) { + // match + output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, actualFmt, expectedFmt) + } else { + // not match + differences++ + output = fmt.Sprintf("%s\t%d: FAIL: %s != %s\n", output, i, actualFmt, expectedFmt) + } + } + + } + + if differences == 0 { + return "No differences.", differences + } + + return output, differences + +} + +// Assert compares the arguments with the specified objects and fails if +// they do not exactly match. +func (args Arguments) Assert(t TestingT, objects ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + // get the differences + diff, diffCount := args.Diff(objects) + + if diffCount == 0 { + return true + } + + // there are differences... report them... + t.Logf(diff) + t.Errorf("%sArguments do not match.", assert.CallerInfo()) + + return false + +} + +// String gets the argument at the specified index. Panics if there is no argument, or +// if the argument is of the wrong type. +// +// If no index is provided, String() returns a complete string representation +// of the arguments. +func (args Arguments) String(indexOrNil ...int) string { + + if len(indexOrNil) == 0 { + // normal String() method - return a string representation of the args + var argsStr []string + for _, arg := range args { + argsStr = append(argsStr, fmt.Sprintf("%s", reflect.TypeOf(arg))) + } + return strings.Join(argsStr, ",") + } else if len(indexOrNil) == 1 { + // Index has been specified - get the argument at that index + var index = indexOrNil[0] + var s string + var ok bool + if s, ok = args.Get(index).(string); !ok { + panic(fmt.Sprintf("assert: arguments: String(%d) failed because object wasn't correct type: %s", index, args.Get(index))) + } + return s + } + + panic(fmt.Sprintf("assert: arguments: Wrong number of arguments passed to String. Must be 0 or 1, not %d", len(indexOrNil))) + +} + +// Int gets the argument at the specified index. Panics if there is no argument, or +// if the argument is of the wrong type. +func (args Arguments) Int(index int) int { + var s int + var ok bool + if s, ok = args.Get(index).(int); !ok { + panic(fmt.Sprintf("assert: arguments: Int(%d) failed because object wasn't correct type: %v", index, args.Get(index))) + } + return s +} + +// Error gets the argument at the specified index. Panics if there is no argument, or +// if the argument is of the wrong type. +func (args Arguments) Error(index int) error { + obj := args.Get(index) + var s error + var ok bool + if obj == nil { + return nil + } + if s, ok = obj.(error); !ok { + panic(fmt.Sprintf("assert: arguments: Error(%d) failed because object wasn't correct type: %v", index, args.Get(index))) + } + return s +} + +// Bool gets the argument at the specified index. Panics if there is no argument, or +// if the argument is of the wrong type. +func (args Arguments) Bool(index int) bool { + var s bool + var ok bool + if s, ok = args.Get(index).(bool); !ok { + panic(fmt.Sprintf("assert: arguments: Bool(%d) failed because object wasn't correct type: %v", index, args.Get(index))) + } + return s +} + +func typeAndKind(v interface{}) (reflect.Type, reflect.Kind) { + t := reflect.TypeOf(v) + k := t.Kind() + + if k == reflect.Ptr { + t = t.Elem() + k = t.Kind() + } + return t, k +} + +func diffArguments(expected Arguments, actual Arguments) string { + if len(expected) != len(actual) { + return fmt.Sprintf("Provided %v arguments, mocked for %v arguments", len(expected), len(actual)) + } + + for x := range expected { + if diffString := diff(expected[x], actual[x]); diffString != "" { + return fmt.Sprintf("Difference found in argument %v:\n\n%s", x, diffString) + } + } + + return "" +} + +// diff returns a diff of both values as long as both are of the same type and +// are a struct, map, slice or array. Otherwise it returns an empty string. +func diff(expected interface{}, actual interface{}) string { + if expected == nil || actual == nil { + return "" + } + + et, ek := typeAndKind(expected) + at, _ := typeAndKind(actual) + + if et != at { + return "" + } + + if ek != reflect.Struct && ek != reflect.Map && ek != reflect.Slice && ek != reflect.Array { + return "" + } + + e := spewConfig.Sdump(expected) + a := spewConfig.Sdump(actual) + + diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{ + A: difflib.SplitLines(e), + B: difflib.SplitLines(a), + FromFile: "Expected", + FromDate: "", + ToFile: "Actual", + ToDate: "", + Context: 1, + }) + + return diff +} + +var spewConfig = spew.ConfigState{ + Indent: " ", + DisablePointerAddresses: true, + DisableCapacities: true, + SortKeys: true, +} + +type tHelper interface { + Helper() +} diff --git a/vendor/golang.org/x/net/internal/socks/socks.go b/vendor/golang.org/x/net/internal/socks/socks.go index 6929a9fd5..97db2340e 100644 --- a/vendor/golang.org/x/net/internal/socks/socks.go +++ b/vendor/golang.org/x/net/internal/socks/socks.go @@ -127,7 +127,7 @@ type Dialer struct { // establishing the transport connection. ProxyDial func(context.Context, string, string) (net.Conn, error) - // AuthMethods specifies the list of request authention + // AuthMethods specifies the list of request authentication // methods. // If empty, SOCKS client requests only AuthMethodNotRequired. AuthMethods []AuthMethod diff --git a/vendor/modules.txt b/vendor/modules.txt index ecea517de..4ad29bf67 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -40,6 +40,8 @@ github.com/matttproud/golang_protobuf_extensions/pbutil github.com/mgutz/ansi # github.com/mitchellh/go-homedir v1.1.0 github.com/mitchellh/go-homedir +# github.com/pkg/errors v0.8.1 +github.com/pkg/errors # github.com/pkg/profile v1.3.0 github.com/pkg/profile # github.com/pmezard/go-difflib v1.0.0 @@ -51,7 +53,7 @@ github.com/prometheus/client_golang/prometheus/promauto github.com/prometheus/client_golang/prometheus/internal # github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90 github.com/prometheus/client_model/go -# github.com/prometheus/common v0.6.0 +# github.com/prometheus/common v0.7.0 github.com/prometheus/common/log github.com/prometheus/common/expfmt github.com/prometheus/common/model @@ -62,10 +64,11 @@ github.com/prometheus/procfs/internal/fs # github.com/sirupsen/logrus v1.4.2 github.com/sirupsen/logrus github.com/sirupsen/logrus/hooks/syslog -# github.com/skycoin/dmsg v0.0.0-20190904181013-b781e3cbebc6 +# github.com/skycoin/dmsg v0.0.0-20190805065636-70f4c32a994f => ../dmsg github.com/skycoin/dmsg/cipher github.com/skycoin/dmsg github.com/skycoin/dmsg/disc +github.com/skycoin/dmsg/netutil github.com/skycoin/dmsg/noise github.com/skycoin/dmsg/ioutil # github.com/skycoin/skycoin v0.26.0 @@ -79,12 +82,15 @@ github.com/skycoin/skycoin/src/cipher/secp256k1-go/secp256k1-go2 github.com/spf13/cobra # github.com/spf13/pflag v1.0.3 github.com/spf13/pflag +# github.com/stretchr/objx v0.1.1 +github.com/stretchr/objx # github.com/stretchr/testify v1.4.0 github.com/stretchr/testify/require +github.com/stretchr/testify/mock github.com/stretchr/testify/assert # go.etcd.io/bbolt v1.3.3 go.etcd.io/bbolt -# golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472 +# golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7 golang.org/x/crypto/ssh/terminal golang.org/x/crypto/blake2b golang.org/x/crypto/blake2s @@ -93,7 +99,7 @@ golang.org/x/crypto/curve25519 golang.org/x/crypto/internal/chacha20 golang.org/x/crypto/internal/subtle golang.org/x/crypto/poly1305 -# golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297 +# golang.org/x/net v0.0.0-20190916140828-c8589233b77d golang.org/x/net/nettest golang.org/x/net/context golang.org/x/net/proxy