diff --git a/main.go b/main.go index c3ab4e5..fb22342 100644 --- a/main.go +++ b/main.go @@ -3,7 +3,9 @@ package main import ( // "bytes" "bytes" - "fmt" + // "fmt" + // "io" + // "log" "time" "github.com/lmnzx/lemonfs/p2p" @@ -43,13 +45,12 @@ func main() { time.Sleep(1 * time.Second) - for i := 0; i < 10; i++ { - data := bytes.NewReader([]byte("big data")) - s2.Store(fmt.Sprintf("privatedata_key_%d", i), data) - time.Sleep(1 * time.Millisecond) - } + data := bytes.NewReader([]byte("big data")) + s2.Store("privatedata_ttt", data) + + select {} - // r, err := s1.Get("privatekekw") + // r, err := s1.Get("privatedata") // if err != nil { // log.Fatal(err) // } @@ -60,6 +61,4 @@ func main() { // } // // fmt.Println(string(b)) - - select {} } diff --git a/server.go b/server.go index 74d27df..2294714 100644 --- a/server.go +++ b/server.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "encoding/binary" "encoding/gob" "fmt" "io" @@ -59,10 +60,13 @@ type MessageGetFile struct { func (s *FileServer) Get(key string) (io.Reader, error) { if s.store.Has(key) { - return s.store.Read(key) + fmt.Printf("[%s] serving (%s) from local disk\n", s.Transport.Addr(), key) + + _, r, err := s.store.Read(key) + return r, err } - fmt.Printf("key: %s not found locally, fetching from network...\n", key) + fmt.Printf("[%s] doest have file (%s), fetching from network...\n", s.Transport.Addr(), key) msg := Message{ Payload: MessageGetFile{ @@ -77,21 +81,30 @@ func (s *FileServer) Get(key string) (io.Reader, error) { time.Sleep(1 * time.Millisecond) for _, peer := range s.peers { - fileBuffer := new(bytes.Buffer) - n, err := io.CopyN(fileBuffer, peer, 8) + var fileSize int64 + + binary.Read(peer, binary.LittleEndian, &fileSize) + + n, err := s.store.Write(key, io.LimitReader(peer, fileSize)) if err != nil { return nil, err } - fmt.Println("revc bytes over the network: ", n) - fmt.Println(fileBuffer.String()) - } + fmt.Printf( + "[%s] received (%d) bytes over the network from (%s)\n", + s.Transport.Addr(), + n, + peer.RemoteAddr(), + ) - select {} - - return nil, nil + peer.CloseStream() + } + _, r, err := s.store.Read(key) + return r, err } +func (s *FileServer) Remove() {} + // Store the file to the disk and broadcast to the peers func (s *FileServer) Store(key string, r io.Reader) error { var ( @@ -117,29 +130,42 @@ func (s *FileServer) Store(key string, r io.Reader) error { time.Sleep(1 * time.Millisecond) - for _, peer := range s.peers { - peer.Send([]byte{p2p.IncomingStream}) - n, err := io.Copy(peer, fileBuffer) - if err != nil { - return err - } - - fmt.Printf("recv and written %d bytes to the disk\n", n) - } - - return nil -} - -func (s *FileServer) stream(fileBuffer *bytes.Buffer) error { + // for _, peer := range s.peers { + // peer.Send([]byte{p2p.IncomingStream}) + // n, err := io.Copy(peer, fileBuffer) + // if err != nil { + // return err + // } + // + // fmt.Printf("recv and written %d bytes to the disk\n", n) + // } peers := []io.Writer{} for _, peer := range s.peers { peers = append(peers, peer) } - mw := io.MultiWriter(peers...) - return gob.NewEncoder(mw).Encode(fileBuffer) + mw.Write([]byte{p2p.IncomingStream}) + + mw.Write(fileBuffer.Bytes()) + // if err := gob.NewEncoder(mw).Encode(fileBuffer); err != nil { + // return err + // } + // FIXME + + fmt.Printf("[%s] received and written bytes to disk\n", s.Transport.Addr()) + return nil } +// func (s *FileServer) stream(fileBuffer *bytes.Buffer) error { +// peers := []io.Writer{} +// for _, peer := range s.peers { +// peers = append(peers, peer) +// } +// +// mw := io.MultiWriter(peers...) +// return gob.NewEncoder(mw).Encode(fileBuffer) +// } + func (s *FileServer) broadcast(msg *Message) error { buf := new(bytes.Buffer) if err := gob.NewEncoder(buf).Encode(msg); err != nil { @@ -211,22 +237,30 @@ func (s *FileServer) handleMessageGetFile(from string, msg MessageGetFile) error } fmt.Printf("serving file (%s) over the network\n", msg.Key) - r, err := s.store.Read(msg.Key) + fileSize, r, err := s.store.Read(msg.Key) if err != nil { return err } + if rc, ok := r.(io.ReadCloser); ok { + defer rc.Close() + } + peer, ok := s.peers[from] if !ok { return fmt.Errorf("peer (%s) could not be found in the peer list", from) } + peer.Send([]byte{p2p.IncomingStream}) + + binary.Write(peer, binary.LittleEndian, fileSize) + n, err := io.Copy(peer, r) if err != nil { return err } - fmt.Printf("read %d byes over the network to %s\n", n, from) + fmt.Printf("read %d bytes over the network to %s\n", n, from) return nil } diff --git a/store.go b/store.go index eb5bc56..67a5d7e 100644 --- a/store.go +++ b/store.go @@ -1,7 +1,6 @@ package main import ( - "bytes" "crypto/sha1" "encoding/hex" "errors" @@ -114,24 +113,25 @@ func (s *Store) Write(key string, r io.Reader) (int64, error) { return s.writeStream(key, r) } -func (s *Store) Read(key string) (io.Reader, error) { - f, err := s.readStream(key) - if err != nil { - return nil, err - } - - defer f.Close() - - buf := new(bytes.Buffer) - _, err = io.Copy(buf, f) - - return buf, err +func (s *Store) Read(key string) (int64, io.Reader, error) { + return s.readStream(key) } -func (s *Store) readStream(key string) (io.ReadCloser, error) { +func (s *Store) readStream(key string) (int64, io.ReadCloser, error) { pathKey := s.PathTransformFunc(key) fullPathWithRoot := fmt.Sprintf("%s/%s", s.Root, pathKey.FullPath()) - return os.Open(fullPathWithRoot) + + file, err := os.Open(fullPathWithRoot) + if err != nil { + return 0, nil, err + } + + fi, err := file.Stat() + if err != nil { + return 0, nil, err + } + + return fi.Size(), file, nil } func (s *Store) writeStream(key string, r io.Reader) (int64, error) { diff --git a/store_test.go b/store_test.go index f222a9c..0195b27 100644 --- a/store_test.go +++ b/store_test.go @@ -32,7 +32,7 @@ func TestStore(t *testing.T) { data := []byte("a lot of data") // Write test - if err := s.Write(key, bytes.NewReader(data)); err != nil { + if _, err := s.Write(key, bytes.NewReader(data)); err != nil { t.Error(err) } fmt.Println("write test passed ✅") @@ -44,7 +44,7 @@ func TestStore(t *testing.T) { fmt.Println("has_exists test passed ✅") // Read test - r, err := s.Read(key) + _, r, err := s.Read(key) if err != nil { t.Error(err) }