Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
masahide committed Sep 7, 2024
1 parent b0c69e8 commit 24276be
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 65 deletions.
58 changes: 33 additions & 25 deletions cmd/agent-bench/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"flag"
"fmt"
"log"
"net"
Expand All @@ -16,7 +17,7 @@ import (
type Specification struct {
PERSISTENT bool `default:"false"`
CONCURRENCY int `default:"10"`
RUN_COUNT int `default:"1000"`
RUN_COUNT int `default:"100"`
}

type sshAgent interface {
Expand All @@ -40,21 +41,47 @@ func (e *exAgent) Extension(string, []byte) ([]byte, error) {
return nil, nil
}

func getKey() *agent.Key {
var key *agent.Key

a, err := newAgent()
if err != nil {
log.Fatal(err)
}
keys, err := a.List()
if err != nil {
log.Fatalf("Failed to list keys: %v", err)
}
if len(keys) == 0 {
log.Fatalf("No keys found in SSH agent")
}
key = keys[0]
a.Close()
return key
}

func main() {
s := Specification{}
err := envconfig.Process("", &s)
if err != nil {
log.Fatal(err)
}
flag.BoolVar(&s.PERSISTENT, "persistent", s.PERSISTENT, "persistent mode")
flag.IntVar(&s.CONCURRENCY, "c", s.CONCURRENCY, "Number of concurrency processing")
flag.IntVar(&s.RUN_COUNT, "n", s.RUN_COUNT, "run count")
flag.Parse()
taskCh := make(chan struct{})
doneCh := make(chan []time.Duration, s.CONCURRENCY)

var wg sync.WaitGroup
key := getKey()
fmt.Printf("The key used for measurement:%s\n", key.String())
fmt.Printf("Start %d worker\n", s.CONCURRENCY)
for i := 0; i < s.CONCURRENCY; i++ {
wg.Add(1)
go func() {
defer wg.Done()
worker(taskCh, doneCh, s.PERSISTENT)
worker(key, taskCh, doneCh, s.PERSISTENT)
}()
}

Expand All @@ -74,7 +101,7 @@ func main() {
for times := range doneCh {
allExecutionTimes = append(allExecutionTimes, times...)
}

fmt.Printf("\ndone.\n")
totalTime := time.Duration(0)
var minTime, maxTime time.Duration
minTime = allExecutionTimes[0]
Expand Down Expand Up @@ -108,29 +135,10 @@ func main() {
fmt.Printf("99th Percentile Execution Time: %v\n", p99Time)
}

func worker(taskCh <-chan struct{}, doneCh chan<- []time.Duration, persistent bool) {
func worker(key *agent.Key, taskCh <-chan struct{}, doneCh chan<- []time.Duration, persistent bool) {
var executionTimes []time.Duration

var err error
var agentClient sshAgent
var key *agent.Key

a, err := newAgent()
if err != nil {
log.Fatal(err)
}
keys, err := a.List()
if err != nil {
log.Fatalf("Failed to list keys: %v", err)
}
if len(keys) == 0 {
log.Fatalf("No keys found in SSH agent")
}
key = keys[0]
a.Close()
agentClient = nil

log.Printf("key:%s", key.String())

for range taskCh {
start := time.Now()
if agentClient == nil {
Expand All @@ -151,7 +159,7 @@ func worker(taskCh <-chan struct{}, doneCh chan<- []time.Duration, persistent bo
}
duration := time.Since(start)
executionTimes = append(executionTimes, duration)
fmt.Print(".")
}

doneCh <- executionTimes
}
24 changes: 15 additions & 9 deletions cmd/wsl2-ssh-agent-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ func (mux *Multiplexer) readLoop(ctx context.Context) error {
log.Println("Error reading header:", err)
return err
}
log.Printf("mux readFull header:[%v]", header)

packetType := binary.LittleEndian.Uint32(header[:4])
channelID := binary.LittleEndian.Uint32(header[4:])
Expand All @@ -88,17 +89,21 @@ func (mux *Multiplexer) readLoop(ctx context.Context) error {
log.Println("Error reading payload:", err)
return err
}
//log.Printf("mux readFull payload type:%d, ch:%d, len:%d ", packetType, channelID, length)
log.Printf("mux readFull payload type:%d, ch:%d, len:%d ", packetType, channelID, length)

switch packetType {
case PacketTypeSend:
mux.channelsMu.Lock()
if ch, ok := mux.channels[channelID]; ok {
ch, ok := mux.channels[channelID]
mux.channelsMu.Unlock()
if ok {
ch <- payload
} else {
log.Printf("mux readFull error: channel %d not found", channelID)
}
mux.channelsMu.Unlock()
case PacketTypeClose:
mux.CloseChannel(channelID)
log.Printf("mux readFull close channel %d", channelID)
}
}
}
Expand Down Expand Up @@ -138,7 +143,6 @@ func (ps *pwshIOStream) handleConnection(ctx context.Context, conn net.Conn, cha
defer conn.Close()
ch := ps.OpenChannel(channelID)
go func() {
//reader := bufio.NewReader(conn)
defer func() {
ps.WriteChannel(Packet{PacketType: PacketTypeClose, ChannelID: channelID, Payload: []byte{}})
ps.CloseChannel(channelID)
Expand All @@ -149,12 +153,12 @@ func (ps *pwshIOStream) handleConnection(ctx context.Context, conn net.Conn, cha
n, err := conn.Read(payload)
if err != nil {
if err == io.EOF {
log.Printf("DomainSocket.read ch:%d io.EOF", channelID)
break
}
log.Println("Error reading from connection:", err)
break
}
//log.Printf("handleCoonection read: channelID:%d byte[%v]", channelID, payload[:n])
ps.WriteChannel(Packet{PacketType: packetType, ChannelID: channelID, Payload: payload[:n]})
packetType = PacketTypeSend
select {
Expand All @@ -165,15 +169,17 @@ func (ps *pwshIOStream) handleConnection(ctx context.Context, conn net.Conn, cha
}
}()

writer := bufio.NewWriter(conn)
domainSocketWriter := bufio.NewWriter(conn)
for msg := range ch {
_, err := writer.Write(msg)
_, err := domainSocketWriter.Write(msg)
if err != nil {
log.Println("Error writing to connection:", err)
break
}
writer.Flush()
domainSocketWriter.Flush()
log.Printf("DomainSocketWriter.Write ch:%d len:%d", channelID, len(msg))
}
log.Printf("Close DomainSocket ch:%d", channelID)
}

type pwshIOStream struct {
Expand Down Expand Up @@ -341,7 +347,7 @@ func (ps *pwshIOStream) listenLoop(ctx context.Context, listener net.Listener) e
log.Println("Error accepting connection:", err)
continue
}
//log.Printf("accept: %v", conn.LocalAddr())
log.Printf("domainSocket:%v accept ch:%d", conn.LocalAddr(), channelID)
go ps.handleConnection(ctx, conn, channelID)
channelID++
select {
Expand Down
68 changes: 37 additions & 31 deletions cmd/wsl2-ssh-agent-proxy/pwsh.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,26 @@ $WritePacketWorker = {
[System.IO.StreamWriter] $OutputStreamWriter
)

#[Console]::Error.WriteLine("WritePacketWorker started.")
[Console]::Error.WriteLine("WritePacketWorker started.")
while ($true) {
$null = $MainPacketQueueSignal.WaitOne()
#[Console]::Error.WriteLine("WritePacketWorker: Signal received, processing packet queue.")
[Console]::Error.WriteLine("WritePacketWorker: Signal received, processing packet queue.")
$Packet = $null
if ($PacketQueue.TryDequeue([ref]$Packet)) {
#[Console]::Error.WriteLine("WritePacketWorker: Packet dequeued. Length: $($Packet.Length), Channel ID: $($Packet.ChannelID), Type: $($Packet.Type).")
while ($PacketQueue.TryDequeue([ref]$Packet)) {
[Console]::Error.WriteLine("WritePacketWorker [ch$($Packet.ChannelID),type:$($Packet.Type)]: Packet dequeued. Length: $($Packet.Length)")
$Header = [BitConverter]::GetBytes($Packet.Type) +
[BitConverter]::GetBytes($Packet.ChannelID) +
[BitConverter]::GetBytes($Packet.Payload.Length)
$OutputStreamWriter.BaseStream.Write($Header, 0, $Header.Length)
$OutputStreamWriter.BaseStream.Write($Packet.Payload, 0, $Packet.Payload.Length)
$OutputStreamWriter.Flush()
#[Console]::Error.WriteLine("WritePacketWorker: Packet written to output stream.")
try {
$OutputStreamWriter.BaseStream.Write($Header, 0, $Header.Length)
$OutputStreamWriter.BaseStream.Write($Packet.Payload, 0, $Packet.Payload.Length)
$OutputStreamWriter.Flush()
}
catch {
[Console]::Error.WriteLine("WritePacketWorker [ch$($Packet.ChannelID),type:$($Packet.Type)]: Write error:[$error]")
continue
}
[Console]::Error.WriteLine("WritePacketWorker [ch$($Packet.ChannelID),type:$($Packet.Type)]: Packet written to output stream.")
}
}
}
Expand All @@ -43,31 +49,31 @@ $PacketWorkerScript = {
[void]SendResponse([hashtable]$Packet) {
$null = $this.WorkerInstance.MainPacketQueue.Enqueue($Packet)
$null = $this.WorkerInstance.MainPacketQueueSignal.Set()
#[Console]::Error.WriteLine("PacketWorker: Response sent for Channel ID: $($Packet.ChannelID).")
[Console]::Error.WriteLine("PacketWorker [ch:$($Packet.ChannelID) type:$($Packet.TypeNum)]: Response sent.")
}

[void]StopWorker([Int32]$ChannelID) {
$this.SendResponse(@{ Type = 2; Payload = [byte[]]::new(0); ChannelID = $ChannelID })
$null = $this.WorkerInstance.WorkerQueue.Enqueue($this.WorkerInstance)
#[Console]::Error.WriteLine("PacketWorker: Worker stopped for Channel ID: $ChannelID.")
[Console]::Error.WriteLine("PacketWorker [ch:$($ChannelID)]: Worker stopped.")
}

[void]Run() {
#[Console]::Error.WriteLine("PacketWorker started.")
[Console]::Error.WriteLine("PacketWorker started.")
while ($true) {
$null = $this.WorkerInstance.PacketQueueSignal.WaitOne()
$Packet = $null
if ($this.WorkerInstance.PacketQueue.TryDequeue([ref]$Packet)) {
#[Console]::Error.WriteLine("PacketWorker: Packet received. Channel ID: $($Packet.ChannelID).")
while ($this.WorkerInstance.PacketQueue.TryDequeue([ref]$Packet)) {
[Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: Packet received.")
try {
if (!$this.ProcessPacket($Packet)) {
#[Console]::Error.WriteLine("PacketWorker: Processing failed for Channel ID: $($Packet.ChannelID). Worker will stop.")
[Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: Processing failed. Worker will stop.")
$this.StopWorker($Packet.ChannelID)
continue
}
}
catch {
[Console]::Error.WriteLine("PacketWorker: Exception occurred while processing Channel ID: $($Packet.ChannelID). Error: $($_.Exception.Message). Worker will stop.")
[Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: Exception occurred while processing. Error: $($_.Exception.Message). Worker will stop.")
$this.StopWorker($Packet.ChannelID)
continue
}
Expand All @@ -76,39 +82,39 @@ $PacketWorkerScript = {
}

[bool]ProcessPacket([hashtable]$Packet) {
#[Console]::Error.WriteLine("PacketWorker: Processing packet. Type: $($Packet.TypeNum), Channel ID: $($Packet.ChannelID).")
[Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: Processing packet.")
if (0 -eq $Packet.TypeNum) {
if ($null -ne $this.NamedPipeStream) {
[Console]::Error.WriteLine("PacketWorker: Named pipe connection already closed. Channel ID: $($Packet.ChannelID).")
[Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: Named pipe connection already closed.")
$this.NamedPipeStream.Close()
$this.NamedPipeStream = $null
return $false
}
$this.NamedPipeStream = [System.IO.Pipes.NamedPipeClientStream]::new(".", "openssh-ssh-agent", [System.IO.Pipes.PipeDirection]::InOut)
$this.NamedPipeStream.Connect()
$this.WorkerInstance.ChannelID = $Packet.ChannelID
#[Console]::Error.WriteLine("PacketWorker: Named pipe connection established. Channel ID: $($Packet.ChannelID).")
[Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: Named pipe connection established.")
}
elseif (2 -eq $Packet.TypeNum) {
if ($null -eq $this.NamedPipeStream) {
#[Console]::Error.WriteLine("PacketWorker: No active named pipe connection to close. Channel ID: $($Packet.ChannelID).")
[Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: No active named pipe connection to close.")
return $false
}
$this.NamedPipeStream.Close()
$this.NamedPipeStream = $null
#[Console]::Error.WriteLine("PacketWorker: Named pipe connection closed. Channel ID: $($Packet.ChannelID).")
[Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: Named pipe connection closed.")
return $false
}
$this.NamedPipeStream.Write($Packet.Payload, 0, $Packet.Payload.Length)
$this.NamedPipeStream.Flush()
#[Console]::Error.WriteLine("PacketWorker: Data written to named pipe. Channel ID: $($Packet.ChannelID).")
[Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: Data written to named pipe.")

$Payload = [byte[]]::new(10240)
$n = $this.NamedPipeStream.Read($Payload, 0, $Payload.Length)
if ($n -gt 0) {
$Payload = $Payload[0..($n - 1)]
$this.SendResponse(@{ Type = 1; Payload = $Payload; ChannelID = $Packet.ChannelID })
#[Console]::Error.WriteLine("PacketWorker: Response read from named pipe and sent. Channel ID: $($Packet.ChannelID).")
[Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: Response read from named pipe and sent.")
}
return $true
}
Expand Down Expand Up @@ -139,32 +145,32 @@ class PacketReader {
}

[Hashtable] ReadPacket ([System.IO.Stream] $InputStreamReader) {
#[Console]::Error.WriteLine("PacketReader: Reading packet from input stream.")
[Console]::Error.WriteLine("PacketReader: Reading packet from input stream.")
$Header = [byte[]]::new(12)
$n = $InputStreamReader.Read($Header, 0, $Header.Length)
if ($n -eq 0) {
return @{Error = "PacketReader: Failed to read header (length zero)." }
}
#[Console]::Error.WriteLine("PacketReader: Header read successfully. Length: $n.")
$Res = @{
TypeNum = [BitConverter]::ToInt32($Header, 0)
ChannelID = [BitConverter]::ToInt32($Header, 4)
Length = [BitConverter]::ToInt32($Header, 8)
Error = $null
}
[Console]::Error.WriteLine("PacketReader [ch:$($Res.ChannelID) type:$($Res.TypeNum)]: Header read successfully. Length: $n.")
$Res.Payload = [byte[]]::new($Res.Length)
$n = $InputStreamReader.Read($Res.Payload, 0, $Res.Length)
if ($n -ne $Res.Length) {
$Res = @{Error = "PacketReader: Incomplete payload read. Expected: $($Res.Length), Actual: $n. Channel ID: $($Res.ChannelID), Type: $($Res.TypeNum)." }
$Res = @{Error = "PacketReader [ch:$($Res.ChannelID) type:$($Res.TypeNum)]: Incomplete payload read. Expected: $($Res.Length), Actual: $n." }
}
#[Console]::Error.WriteLine("PacketReader: Packet read completed. Channel ID: $($Res.ChannelID), Type: $($Res.TypeNum), Length: $($Res.Length).")
[Console]::Error.WriteLine("PacketReader [ch:$($Res.ChannelID) type:$($Res.TypeNum)]: Packet read completed. Length: $($Res.Length).")
return $Res
}

[void] Run() {
[Console]::Error.WriteLine("PacketReader started.")
while ($true) {
#[Console]::Error.WriteLine("PacketReader: Waiting for packets.")
[Console]::Error.WriteLine("PacketReader: Waiting for packets.")
$Packet = $null
try {
$Packet = $this.ReadPacket($this.InputStreamReader)
Expand All @@ -178,7 +184,7 @@ class PacketReader {
Start-Sleep -Seconds 1.0
continue
}
#[Console]::Error.WriteLine("PacketReader: Packet received. Type: $($Packet.TypeNum), Channel ID: $($Packet.ChannelID), Length: $($Packet.Length).")
[Console]::Error.WriteLine("PacketReader [ch:$($Packet.ChannelID) type:$($Packet.TypeNum)]: Packet received. Length: $($Packet.Length).")
$WorkerInstance = $null
if ($this.Channels.ContainsKey($Packet.ChannelID)) {
$WorkerInstance = $this.Channels[$Packet.ChannelID]
Expand All @@ -187,7 +193,7 @@ class PacketReader {
if ($this.WorkerQueue.TryDequeue([ref]$WorkerInstance)) {
$this.Channels.Remove($WorkerInstance.ChannelID)
$WorkerInstance.ChannelID = $Packet.ChannelID
#[Console]::Error.WriteLine("PacketReader: Reusing existing worker for Channel ID: $($Packet.ChannelID).")
[Console]::Error.WriteLine("PacketReader [ch:$($Packet.ChannelID) type:$($Packet.TypeNum)]: Reusing existing worker.")
}
else {
$WorkerInstance = @{
Expand All @@ -200,13 +206,13 @@ class PacketReader {
}
$null = [PowerShell]::Create().AddScript($this.PacketWorkerScript).
AddArgument($WorkerInstance).BeginInvoke()
#[Console]::Error.WriteLine("PacketReader: New worker initialized for Channel ID: $($Packet.ChannelID).")
[Console]::Error.WriteLine("PacketReader [ch:$($Packet.ChannelID) type:$($Packet.TypeNum)]: New worker initialized.")
}
$this.Channels[$WorkerInstance.ChannelID] = $WorkerInstance
}
$WorkerInstance.PacketQueue.Enqueue($Packet)
$WorkerInstance.PacketQueueSignal.Set()
#[Console]::Error.WriteLine("PacketReader: Packet dispatched to worker. Channel ID: $($Packet.ChannelID).")
[Console]::Error.WriteLine("PacketReader [ch:$($Packet.ChannelID) type:$($Packet.TypeNum)]: Packet dispatched to worker.")
}
}
}
Expand Down

0 comments on commit 24276be

Please sign in to comment.