From 8a96082f1f0522bd831d4b59db6b17944976758a Mon Sep 17 00:00:00 2001 From: Sayan Mallick Date: Tue, 23 Apr 2024 14:10:38 +0530 Subject: [PATCH 1/2] encryption setup, server broken --- crypto.go | 75 ++++++++++++++++++++++++++++++++++++++++++++++++++ crypto_test.go | 31 +++++++++++++++++++++ main.go | 41 +++++++++++++++------------ server.go | 40 ++++++++------------------- store.go | 23 ++++++++++++++-- store_test.go | 2 +- 6 files changed, 162 insertions(+), 50 deletions(-) create mode 100644 crypto.go create mode 100644 crypto_test.go diff --git a/crypto.go b/crypto.go new file mode 100644 index 0000000..e00b41f --- /dev/null +++ b/crypto.go @@ -0,0 +1,75 @@ +package main + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "io" +) + +func newEncryptionKey() []byte { + keyBuf := make([]byte, 32) + io.ReadFull(rand.Reader, keyBuf) + return keyBuf +} + +func copyStream(stream cipher.Stream, blockSize int, src io.Reader, dst io.Writer) (int, error) { + var ( + buf = make([]byte, 32*1024) + nw = blockSize + ) + for { + n, err := src.Read(buf) + if n > 0 { + stream.XORKeyStream(buf, buf[:n]) + nn, err := dst.Write(buf[:n]) + if err != nil { + return 0, err + } + nw += nn + } + if err == io.EOF { + break + } + if err != nil { + return 0, err + } + } + return nw, nil +} + +func copyDecrypt(key []byte, src io.Reader, dst io.Writer) (int, error) { + block, err := aes.NewCipher(key) + if err != nil { + return 0, err + } + + // read the iv + iv := make([]byte, block.BlockSize()) + if _, err := src.Read(iv); err != nil { + return 0, err + } + + stream := cipher.NewCTR(block, iv) + return copyStream(stream, block.BlockSize(), src, dst) +} + +func copyEncrypt(key []byte, src io.Reader, dst io.Writer) (int, error) { + block, err := aes.NewCipher(key) + if err != nil { + return 0, err + } + + iv := make([]byte, block.BlockSize()) + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + return 0, err + } + + // pprepent iv to the file + if _, err := dst.Write(iv); err != nil { + return 0, nil + } + + stream := cipher.NewCTR(block, iv) + return copyStream(stream, block.BlockSize(), src, dst) +} diff --git a/crypto_test.go b/crypto_test.go new file mode 100644 index 0000000..30a6d28 --- /dev/null +++ b/crypto_test.go @@ -0,0 +1,31 @@ +package main + +import ( + "bytes" + "testing" +) + +func TestCopyEncryptDecrypt(t *testing.T) { + payload := "lemon not lime" + src := bytes.NewReader([]byte(payload)) + dst := new(bytes.Buffer) + key := newEncryptionKey() + _, err := copyEncrypt(key, src, dst) + if err != nil { + t.Error(err) + } + + out := new(bytes.Buffer) + nw, err := copyDecrypt(key, dst, out) + if err != nil { + t.Error(err) + } + + if nw != 16+len(payload) { + t.Fail() + } + + if out.String() != payload { + t.Error("encryption/decryption failed") + } +} diff --git a/main.go b/main.go index fb22342..28da099 100644 --- a/main.go +++ b/main.go @@ -1,11 +1,10 @@ package main import ( - // "bytes" "bytes" - // "fmt" - // "io" - // "log" + "fmt" + "io" + "log" "time" "github.com/lmnzx/lemonfs/p2p" @@ -21,7 +20,8 @@ func makeServer(listenAddr string, nodes ...string) *FileServer { t := p2p.NewTCPTransport(tcpTransportOpts) fileServerOpts := FileServerOpts{ - StorageRoot: listenAddr + "_network", + EncKey: newEncryptionKey(), + StorageRoot: listenAddr + "_lemonfs", PathTransformFunc: CASPathTransformFunc, Transport: t, BootstrapNodes: nodes, @@ -45,20 +45,25 @@ func main() { time.Sleep(1 * time.Second) + key := "privatedata_ttt" + data := bytes.NewReader([]byte("big data")) - s2.Store("privatedata_ttt", data) + s1.Store(key, data) - select {} + if err := s1.store.Delete(key); err != nil { + log.Fatal(err) + } + + r, err := s1.Get(key) + if err != nil { + log.Fatal(err) + } + + b, err := io.ReadAll(r) + if err != nil { + log.Fatal(err) + } - // r, err := s1.Get("privatedata") - // if err != nil { - // log.Fatal(err) - // } - // - // b, err := io.ReadAll(r) - // if err != nil { - // log.Fatal(err) - // } - // - // fmt.Println(string(b)) + fmt.Println("------>", string(b)) + select {} } diff --git a/server.go b/server.go index 2294714..84e6e2e 100644 --- a/server.go +++ b/server.go @@ -10,12 +10,11 @@ import ( "sync" "time" - // "time" - "github.com/lmnzx/lemonfs/p2p" ) type FileServerOpts struct { + EncKey []byte StorageRoot string PathTransformFunc PathTransformFunc Transport p2p.Transport @@ -85,7 +84,11 @@ func (s *FileServer) Get(key string) (io.Reader, error) { binary.Read(peer, binary.LittleEndian, &fileSize) - n, err := s.store.Write(key, io.LimitReader(peer, fileSize)) + // n, err := s.store.Write(key, io.LimitReader(peer, fileSize)) + // if err != nil { + // return nil, err + // } + n, err := s.store.WriteDecypt(s.EncKey, key, io.LimitReader(peer, fileSize)) if err != nil { return nil, err } @@ -130,15 +133,6 @@ func (s *FileServer) Store(key string, r io.Reader) error { time.Sleep(1 * time.Millisecond) - // for _, peer := range s.peers { - // peer.Send([]byte{p2p.IncomingStream}) - // n, err := io.Copy(peer, fileBuffer) - // if err != nil { - // return err - // } - // - // fmt.Printf("recv and written %d bytes to the disk\n", n) - // } peers := []io.Writer{} for _, peer := range s.peers { peers = append(peers, peer) @@ -146,26 +140,16 @@ func (s *FileServer) Store(key string, r io.Reader) error { mw := io.MultiWriter(peers...) mw.Write([]byte{p2p.IncomingStream}) - mw.Write(fileBuffer.Bytes()) - // if err := gob.NewEncoder(mw).Encode(fileBuffer); err != nil { - // return err - // } - // FIXME + n, err := copyEncrypt(s.EncKey, fileBuffer, mw) + if err != nil { + return err + } + + fmt.Printf("[%s] received and written (%d) bytes to disk\n", s.Transport.Addr(), n) - fmt.Printf("[%s] received and written bytes to disk\n", s.Transport.Addr()) return nil } -// func (s *FileServer) stream(fileBuffer *bytes.Buffer) error { -// peers := []io.Writer{} -// for _, peer := range s.peers { -// peers = append(peers, peer) -// } -// -// mw := io.MultiWriter(peers...) -// return gob.NewEncoder(mw).Encode(fileBuffer) -// } - func (s *FileServer) broadcast(msg *Message) error { buf := new(bytes.Buffer) if err := gob.NewEncoder(buf).Encode(msg); err != nil { diff --git a/store.go b/store.go index 67a5d7e..0b359a1 100644 --- a/store.go +++ b/store.go @@ -134,7 +134,7 @@ func (s *Store) readStream(key string) (int64, io.ReadCloser, error) { return fi.Size(), file, nil } -func (s *Store) writeStream(key string, r io.Reader) (int64, error) { +func (s *Store) WriteDecypt(encKey []byte, key string, r io.Reader) (int64, error) { pathKey := s.PathTransformFunc(key) pathNameWithRoot := fmt.Sprintf("%s/%s", s.Root, pathKey.Pathname) if err := os.MkdirAll(pathNameWithRoot, os.ModePerm); err != nil { @@ -148,10 +148,27 @@ func (s *Store) writeStream(key string, r io.Reader) (int64, error) { return 0, err } - n, err := io.Copy(f, r) + n, err := copyDecrypt(encKey, r, f) + if err != nil { + return 0, err + } + + return int64(n), nil +} + +func (s *Store) writeStream(key string, r io.Reader) (int64, error) { + pathKey := s.PathTransformFunc(key) + pathNameWithRoot := fmt.Sprintf("%s/%s", s.Root, pathKey.Pathname) + if err := os.MkdirAll(pathNameWithRoot, os.ModePerm); err != nil { + return 0, err + } + + fullPathWithRoot := fmt.Sprintf("%s/%s", s.Root, pathKey.FullPath()) + + f, err := os.Create(fullPathWithRoot) if err != nil { return 0, err } - return n, nil + return io.Copy(f, r) } diff --git a/store_test.go b/store_test.go index 0195b27..784c76e 100644 --- a/store_test.go +++ b/store_test.go @@ -25,7 +25,7 @@ func TestStore(t *testing.T) { defer teardown(t, s) - for i := 0; i < 50; i++ { + for i := 0; i < 1; i++ { fmt.Printf("running test %d 🚀\n", i) key := fmt.Sprintf("test_%d", i) From 450beabdb2287a96e2c17d6e81eb5fa7e5f758d6 Mon Sep 17 00:00:00 2001 From: Sayan Mallick Date: Tue, 23 Apr 2024 14:36:38 +0530 Subject: [PATCH 2/2] fixed server --- main.go | 5 ++++- p2p/tcp_transport.go | 2 +- p2p/transport.go | 2 +- server.go | 42 ++++++++++++++++++++++++------------------ store.go | 20 ++++++++------------ 5 files changed, 38 insertions(+), 33 deletions(-) diff --git a/main.go b/main.go index 28da099..6299dc9 100644 --- a/main.go +++ b/main.go @@ -47,12 +47,15 @@ func main() { key := "privatedata_ttt" - data := bytes.NewReader([]byte("big data")) + data := bytes.NewReader([]byte("big data askdjflkasjdfl askdfjlaksjdfklajsdflkajdf")) s1.Store(key, data) if err := s1.store.Delete(key); err != nil { log.Fatal(err) } + fmt.Println("File deleted from s1") + + time.Sleep(10 * time.Second) r, err := s1.Get(key) if err != nil { diff --git a/p2p/tcp_transport.go b/p2p/tcp_transport.go index de06546..0e49c84 100644 --- a/p2p/tcp_transport.go +++ b/p2p/tcp_transport.go @@ -77,7 +77,7 @@ func (t *TCPTransport) Addr() string { } // Dail implements the Transport interface -func (t *TCPTransport) Dail(addr string) error { +func (t *TCPTransport) Dial(addr string) error { conn, err := net.Dial("tcp", addr) if err != nil { return err diff --git a/p2p/transport.go b/p2p/transport.go index 2a6fed3..d36d28e 100644 --- a/p2p/transport.go +++ b/p2p/transport.go @@ -13,7 +13,7 @@ type Peer interface { // Can be TCP, UDP, websockets, ... type Transport interface { Addr() string - Dail(string) error + Dial(string) error ListenAndAccept() error Consume() <-chan RPC Close() error diff --git a/server.go b/server.go index 84e6e2e..a579f42 100644 --- a/server.go +++ b/server.go @@ -77,17 +77,13 @@ func (s *FileServer) Get(key string) (io.Reader, error) { return nil, err } - time.Sleep(1 * time.Millisecond) + time.Sleep(100 * time.Millisecond) for _, peer := range s.peers { var fileSize int64 binary.Read(peer, binary.LittleEndian, &fileSize) - // n, err := s.store.Write(key, io.LimitReader(peer, fileSize)) - // if err != nil { - // return nil, err - // } n, err := s.store.WriteDecypt(s.EncKey, key, io.LimitReader(peer, fileSize)) if err != nil { return nil, err @@ -106,7 +102,8 @@ func (s *FileServer) Get(key string) (io.Reader, error) { return r, err } -func (s *FileServer) Remove() {} +// TODO +// func (s *FileServer) Remove() {} // Store the file to the disk and broadcast to the peers func (s *FileServer) Store(key string, r io.Reader) error { @@ -123,7 +120,7 @@ func (s *FileServer) Store(key string, r io.Reader) error { msg := Message{ Payload: MessageStoreFile{ Key: key, - Size: size, + Size: size + 16, }, } @@ -131,7 +128,7 @@ func (s *FileServer) Store(key string, r io.Reader) error { return err } - time.Sleep(1 * time.Millisecond) + time.Sleep(100 * time.Millisecond) peers := []io.Writer{} for _, peer := range s.peers { @@ -183,7 +180,7 @@ func (s *FileServer) OnPeer(p p2p.Peer) error { func (s *FileServer) loop() { defer func() { - log.Println("file server stopped due to user quit action") + log.Println("file server stopped due to error or user quit action") s.Transport.Close() }() @@ -194,7 +191,6 @@ func (s *FileServer) loop() { if err := gob.NewDecoder(bytes.NewReader(rpc.Payload)).Decode(&msg); err != nil { log.Println("decoding error: ", err) } - if err := s.handleMessage(rpc.From.String(), &msg); err != nil { log.Println("handle message error: ", err) } @@ -217,34 +213,40 @@ func (s *FileServer) handleMessage(from string, msg *Message) error { func (s *FileServer) handleMessageGetFile(from string, msg MessageGetFile) error { if !s.store.Has(msg.Key) { - return fmt.Errorf("file (%s) does not exist on disk\n", msg.Key) + return fmt.Errorf( + "[%s] need to serve file (%s) but it does not exist on disk", + s.Transport.Addr(), + msg.Key, + ) } - fmt.Printf("serving file (%s) over the network\n", msg.Key) + fmt.Printf("[%s] serving file (%s) over the network\n", s.Transport.Addr(), msg.Key) + fileSize, r, err := s.store.Read(msg.Key) if err != nil { return err } if rc, ok := r.(io.ReadCloser); ok { + fmt.Println("closing readCloser") defer rc.Close() } peer, ok := s.peers[from] if !ok { - return fmt.Errorf("peer (%s) could not be found in the peer list", from) + return fmt.Errorf("peer %s not in map", from) } + // First send the "incomingStream" byte to the peer and then we can send + // the file size as an int64. peer.Send([]byte{p2p.IncomingStream}) - binary.Write(peer, binary.LittleEndian, fileSize) - n, err := io.Copy(peer, r) if err != nil { return err } - fmt.Printf("read %d bytes over the network to %s\n", n, from) + fmt.Printf("[%s] written (%d) bytes over the network to %s\n", s.Transport.Addr(), n, from) return nil } @@ -260,7 +262,7 @@ func (s *FileServer) handleMessageStoreFile(from string, msg MessageStoreFile) e return err } - log.Printf("[%s] written (%d) bytes to disk\n", s.Transport.Addr(), n) + fmt.Printf("[%s] written %d bytes to disk\n", s.Transport.Addr(), n) peer.CloseStream() @@ -272,17 +274,21 @@ func (s *FileServer) bootstrapNetwork() error { if len(addr) == 0 { continue } + go func(addr string) { fmt.Printf("[%s] attemping to connect with remote %s\n", s.Transport.Addr(), addr) - if err := s.Transport.Dail(addr); err != nil { + if err := s.Transport.Dial(addr); err != nil { log.Println("dial error: ", err) } }(addr) } + return nil } func (s *FileServer) Start() error { + fmt.Printf("[%s] starting fileserver...\n", s.Transport.Addr()) + if err := s.Transport.ListenAndAccept(); err != nil { return err } diff --git a/store.go b/store.go index 0b359a1..255dfd3 100644 --- a/store.go +++ b/store.go @@ -135,15 +135,7 @@ func (s *Store) readStream(key string) (int64, io.ReadCloser, error) { } func (s *Store) WriteDecypt(encKey []byte, key string, r io.Reader) (int64, error) { - pathKey := s.PathTransformFunc(key) - pathNameWithRoot := fmt.Sprintf("%s/%s", s.Root, pathKey.Pathname) - if err := os.MkdirAll(pathNameWithRoot, os.ModePerm); err != nil { - return 0, err - } - - fullPathWithRoot := fmt.Sprintf("%s/%s", s.Root, pathKey.FullPath()) - - f, err := os.Create(fullPathWithRoot) + f, err := s.openFileForWriting(key) if err != nil { return 0, err } @@ -156,16 +148,20 @@ func (s *Store) WriteDecypt(encKey []byte, key string, r io.Reader) (int64, erro return int64(n), nil } -func (s *Store) writeStream(key string, r io.Reader) (int64, error) { +func (s *Store) openFileForWriting(key string) (*os.File, error) { pathKey := s.PathTransformFunc(key) pathNameWithRoot := fmt.Sprintf("%s/%s", s.Root, pathKey.Pathname) if err := os.MkdirAll(pathNameWithRoot, os.ModePerm); err != nil { - return 0, err + return nil, err } fullPathWithRoot := fmt.Sprintf("%s/%s", s.Root, pathKey.FullPath()) - f, err := os.Create(fullPathWithRoot) + return os.Create(fullPathWithRoot) +} + +func (s *Store) writeStream(key string, r io.Reader) (int64, error) { + f, err := s.openFileForWriting(key) if err != nil { return 0, err }