Skip to content

Commit

Permalink
Merge pull request #10 from lmnzx/encryption
Browse files Browse the repository at this point in the history
Encryption
  • Loading branch information
lmnzx authored Apr 23, 2024
2 parents 77204f5 + 450beab commit ce8f849
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 73 deletions.
75 changes: 75 additions & 0 deletions crypto.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package main

import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"io"
)

func newEncryptionKey() []byte {
keyBuf := make([]byte, 32)
io.ReadFull(rand.Reader, keyBuf)
return keyBuf
}

func copyStream(stream cipher.Stream, blockSize int, src io.Reader, dst io.Writer) (int, error) {
var (
buf = make([]byte, 32*1024)
nw = blockSize
)
for {
n, err := src.Read(buf)
if n > 0 {
stream.XORKeyStream(buf, buf[:n])
nn, err := dst.Write(buf[:n])
if err != nil {
return 0, err
}
nw += nn
}
if err == io.EOF {
break
}
if err != nil {
return 0, err
}
}
return nw, nil
}

func copyDecrypt(key []byte, src io.Reader, dst io.Writer) (int, error) {
block, err := aes.NewCipher(key)
if err != nil {
return 0, err
}

// read the iv
iv := make([]byte, block.BlockSize())
if _, err := src.Read(iv); err != nil {
return 0, err
}

stream := cipher.NewCTR(block, iv)
return copyStream(stream, block.BlockSize(), src, dst)
}

func copyEncrypt(key []byte, src io.Reader, dst io.Writer) (int, error) {
block, err := aes.NewCipher(key)
if err != nil {
return 0, err
}

iv := make([]byte, block.BlockSize())
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return 0, err
}

// pprepent iv to the file
if _, err := dst.Write(iv); err != nil {
return 0, nil
}

stream := cipher.NewCTR(block, iv)
return copyStream(stream, block.BlockSize(), src, dst)
}
31 changes: 31 additions & 0 deletions crypto_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package main

import (
"bytes"
"testing"
)

func TestCopyEncryptDecrypt(t *testing.T) {
payload := "lemon not lime"
src := bytes.NewReader([]byte(payload))
dst := new(bytes.Buffer)
key := newEncryptionKey()
_, err := copyEncrypt(key, src, dst)
if err != nil {
t.Error(err)
}

out := new(bytes.Buffer)
nw, err := copyDecrypt(key, dst, out)
if err != nil {
t.Error(err)
}

if nw != 16+len(payload) {
t.Fail()
}

if out.String() != payload {
t.Error("encryption/decryption failed")
}
}
46 changes: 27 additions & 19 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package main

import (
// "bytes"
"bytes"
// "fmt"
// "io"
// "log"
"fmt"
"io"
"log"
"time"

"github.com/lmnzx/lemonfs/p2p"
Expand All @@ -21,7 +20,8 @@ func makeServer(listenAddr string, nodes ...string) *FileServer {
t := p2p.NewTCPTransport(tcpTransportOpts)

fileServerOpts := FileServerOpts{
StorageRoot: listenAddr + "_network",
EncKey: newEncryptionKey(),
StorageRoot: listenAddr + "_lemonfs",
PathTransformFunc: CASPathTransformFunc,
Transport: t,
BootstrapNodes: nodes,
Expand All @@ -45,20 +45,28 @@ func main() {

time.Sleep(1 * time.Second)

data := bytes.NewReader([]byte("big data"))
s2.Store("privatedata_ttt", data)
key := "privatedata_ttt"

select {}
data := bytes.NewReader([]byte("big data askdjflkasjdfl askdfjlaksjdfklajsdflkajdf"))
s1.Store(key, data)

if err := s1.store.Delete(key); err != nil {
log.Fatal(err)
}
fmt.Println("File deleted from s1")

// r, err := s1.Get("privatedata")
// if err != nil {
// log.Fatal(err)
// }
//
// b, err := io.ReadAll(r)
// if err != nil {
// log.Fatal(err)
// }
//
// fmt.Println(string(b))
time.Sleep(10 * time.Second)

r, err := s1.Get(key)
if err != nil {
log.Fatal(err)
}

b, err := io.ReadAll(r)
if err != nil {
log.Fatal(err)
}

fmt.Println("------>", string(b))
select {}
}
2 changes: 1 addition & 1 deletion p2p/tcp_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (t *TCPTransport) Addr() string {
}

// Dail implements the Transport interface
func (t *TCPTransport) Dail(addr string) error {
func (t *TCPTransport) Dial(addr string) error {
conn, err := net.Dial("tcp", addr)
if err != nil {
return err
Expand Down
2 changes: 1 addition & 1 deletion p2p/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type Peer interface {
// Can be TCP, UDP, websockets, ...
type Transport interface {
Addr() string
Dail(string) error
Dial(string) error
ListenAndAccept() error
Consume() <-chan RPC
Close() error
Expand Down
76 changes: 33 additions & 43 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@ import (
"sync"
"time"

// "time"

"github.com/lmnzx/lemonfs/p2p"
)

type FileServerOpts struct {
EncKey []byte
StorageRoot string
PathTransformFunc PathTransformFunc
Transport p2p.Transport
Expand Down Expand Up @@ -78,14 +77,14 @@ func (s *FileServer) Get(key string) (io.Reader, error) {
return nil, err
}

time.Sleep(1 * time.Millisecond)
time.Sleep(100 * time.Millisecond)

for _, peer := range s.peers {
var fileSize int64

binary.Read(peer, binary.LittleEndian, &fileSize)

n, err := s.store.Write(key, io.LimitReader(peer, fileSize))
n, err := s.store.WriteDecypt(s.EncKey, key, io.LimitReader(peer, fileSize))
if err != nil {
return nil, err
}
Expand All @@ -103,7 +102,8 @@ func (s *FileServer) Get(key string) (io.Reader, error) {
return r, err
}

func (s *FileServer) Remove() {}
// TODO
// func (s *FileServer) Remove() {}

// Store the file to the disk and broadcast to the peers
func (s *FileServer) Store(key string, r io.Reader) error {
Expand All @@ -120,52 +120,33 @@ func (s *FileServer) Store(key string, r io.Reader) error {
msg := Message{
Payload: MessageStoreFile{
Key: key,
Size: size,
Size: size + 16,
},
}

if err := s.broadcast(&msg); err != nil {
return err
}

time.Sleep(1 * time.Millisecond)

// for _, peer := range s.peers {
// peer.Send([]byte{p2p.IncomingStream})
// n, err := io.Copy(peer, fileBuffer)
// if err != nil {
// return err
// }
//
// fmt.Printf("recv and written %d bytes to the disk\n", n)
// }
time.Sleep(100 * time.Millisecond)

peers := []io.Writer{}
for _, peer := range s.peers {
peers = append(peers, peer)
}
mw := io.MultiWriter(peers...)
mw.Write([]byte{p2p.IncomingStream})

mw.Write(fileBuffer.Bytes())
// if err := gob.NewEncoder(mw).Encode(fileBuffer); err != nil {
// return err
// }
// FIXME
n, err := copyEncrypt(s.EncKey, fileBuffer, mw)
if err != nil {
return err
}

fmt.Printf("[%s] received and written (%d) bytes to disk\n", s.Transport.Addr(), n)

fmt.Printf("[%s] received and written bytes to disk\n", s.Transport.Addr())
return nil
}

// func (s *FileServer) stream(fileBuffer *bytes.Buffer) error {
// peers := []io.Writer{}
// for _, peer := range s.peers {
// peers = append(peers, peer)
// }
//
// mw := io.MultiWriter(peers...)
// return gob.NewEncoder(mw).Encode(fileBuffer)
// }

func (s *FileServer) broadcast(msg *Message) error {
buf := new(bytes.Buffer)
if err := gob.NewEncoder(buf).Encode(msg); err != nil {
Expand Down Expand Up @@ -199,7 +180,7 @@ func (s *FileServer) OnPeer(p p2p.Peer) error {

func (s *FileServer) loop() {
defer func() {
log.Println("file server stopped due to user quit action")
log.Println("file server stopped due to error or user quit action")
s.Transport.Close()
}()

Expand All @@ -210,7 +191,6 @@ func (s *FileServer) loop() {
if err := gob.NewDecoder(bytes.NewReader(rpc.Payload)).Decode(&msg); err != nil {
log.Println("decoding error: ", err)
}

if err := s.handleMessage(rpc.From.String(), &msg); err != nil {
log.Println("handle message error: ", err)
}
Expand All @@ -233,34 +213,40 @@ func (s *FileServer) handleMessage(from string, msg *Message) error {

func (s *FileServer) handleMessageGetFile(from string, msg MessageGetFile) error {
if !s.store.Has(msg.Key) {
return fmt.Errorf("file (%s) does not exist on disk\n", msg.Key)
return fmt.Errorf(
"[%s] need to serve file (%s) but it does not exist on disk",
s.Transport.Addr(),
msg.Key,
)
}

fmt.Printf("serving file (%s) over the network\n", msg.Key)
fmt.Printf("[%s] serving file (%s) over the network\n", s.Transport.Addr(), msg.Key)

fileSize, r, err := s.store.Read(msg.Key)
if err != nil {
return err
}

if rc, ok := r.(io.ReadCloser); ok {
fmt.Println("closing readCloser")
defer rc.Close()
}

peer, ok := s.peers[from]
if !ok {
return fmt.Errorf("peer (%s) could not be found in the peer list", from)
return fmt.Errorf("peer %s not in map", from)
}

// First send the "incomingStream" byte to the peer and then we can send
// the file size as an int64.
peer.Send([]byte{p2p.IncomingStream})

binary.Write(peer, binary.LittleEndian, fileSize)

n, err := io.Copy(peer, r)
if err != nil {
return err
}

fmt.Printf("read %d bytes over the network to %s\n", n, from)
fmt.Printf("[%s] written (%d) bytes over the network to %s\n", s.Transport.Addr(), n, from)

return nil
}
Expand All @@ -276,7 +262,7 @@ func (s *FileServer) handleMessageStoreFile(from string, msg MessageStoreFile) e
return err
}

log.Printf("[%s] written (%d) bytes to disk\n", s.Transport.Addr(), n)
fmt.Printf("[%s] written %d bytes to disk\n", s.Transport.Addr(), n)

peer.CloseStream()

Expand All @@ -288,17 +274,21 @@ func (s *FileServer) bootstrapNetwork() error {
if len(addr) == 0 {
continue
}

go func(addr string) {
fmt.Printf("[%s] attemping to connect with remote %s\n", s.Transport.Addr(), addr)
if err := s.Transport.Dail(addr); err != nil {
if err := s.Transport.Dial(addr); err != nil {
log.Println("dial error: ", err)
}
}(addr)
}

return nil
}

func (s *FileServer) Start() error {
fmt.Printf("[%s] starting fileserver...\n", s.Transport.Addr())

if err := s.Transport.ListenAndAccept(); err != nil {
return err
}
Expand Down
Loading

0 comments on commit ce8f849

Please sign in to comment.