Skip to content

Commit

Permalink
finally faster :)
Browse files Browse the repository at this point in the history
  • Loading branch information
elaboy committed Aug 29, 2024
1 parent 807ac11 commit 15971a6
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 30 deletions.
2 changes: 1 addition & 1 deletion mzLib/MassSpectrometry/IRetentionTimeAlignable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ public interface IRetentionTimeAlignable
public float RetentionTime { get; set; }
public float ChronologerHI { get; set; }
public string BaseSequence { get; set; }
public string FullSequence { get; set; }
public string FullSequence { get; }
}
8 changes: 2 additions & 6 deletions mzLib/Proteomics/PSM/PsmFromTsv.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,12 @@ public class PsmFromTsv : SpectrumMatchFromTsv, IRetentionTimeAlignable
//For Aligner Interface
public string FileName { get => FileNameWithoutExtension; set => FileNameWithoutExtension = value; }
float IRetentionTimeAlignable.RetentionTime { get => (float?)RetentionTime.Value ?? -1; set => RetentionTime = value; }
float IRetentionTimeAlignable.ChronologerHI { get => (float)ChronologerHIDouble; set => SetChronologerHI(value); }
public float ChronologerHI { get ; set; }

public float SetChronologerHI(double? chronologerHI) => (float)chronologerHI.Value;
public string BaseSequence { get => BaseSeq; set => SetBaseSequence(); }
public string SetBaseSequence() => BaseSeq;
string IRetentionTimeAlignable.FullSequence { get => FullSequence; set => SetFullSequence(); }
public string SetFullSequence() => FullSequence;
string FullSequence => base.FullSequence;




public PsmFromTsv(string line, char[] split, Dictionary<string, int> parsedHeader)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
using System.CodeDom.Compiler;
using System.Collections.Generic;
using System.Collections.Generic;
using System.Data;
using System.Linq;
using System.Reflection.Metadata.Ecma335;
using System.Threading.Tasks;
using TorchSharp;
using TorchSharp.Modules;

namespace Proteomics.RetentionTimePrediction.Chronologer
{
Expand Down Expand Up @@ -46,6 +48,7 @@ public static class ChronologerEstimator

public static float[] PredictRetentionTime(string[] baseSequences, string[] fullSequences, bool gpu)
{
// TODO try catch here
torch.InitializeDeviceType(DeviceType.CUDA);
// use CUDA when it is available
var device = torch.cuda_is_available()
Expand All @@ -68,7 +71,7 @@ public static float[] PredictRetentionTime(string[] baseSequences, string[] full
// tensorize all
torch.Tensor[] tensorsArray = new torch.Tensor[baseSequences.Length];
bool[] compatibleTracker = new bool[baseSequences.Length];

Parallel.For(0, baseSequences.Length, (i, state) =>
{
var baseSeq = baseSequences[i];
Expand All @@ -87,29 +90,38 @@ public static float[] PredictRetentionTime(string[] baseSequences, string[] full
}
});
// vstack and split
var stackedTensors = torch.vstack(tensorsArray).split(100);
var stackedTensors = torch.vstack(tensorsArray);

float[] preds = new float[stackedTensors.Length];
float[] preds = new float[stackedTensors.size(0)];

// make batches
for (int batch = 0; batch < stackedTensors.Length; batch++)
using var dataset = torch.utils.data.TensorDataset(stackedTensors);
using var dataLoader = torch.utils.data.DataLoader(dataset, 2048);

var predictionHolder = new List<float>();
foreach (var batch in dataLoader)
{
var output = ChronologerModel.Predict(stackedTensors[batch].to(device));
//move output to cpu
var outputArray = output.to(DeviceType.CPU).data<float>().ToArray();
Parallel.For(0, outputArray.Length , outputIndex =>
{
preds[batch] = outputArray[outputIndex];
});
var output = ChronologerModel.Predict(torch.vstack(batch).to(device));
predictionHolder.AddRange(output.to(DeviceType.CPU).data<float>());
output.Dispose();
}

Parallel.For(0, preds.Length, outputIndex =>
{
preds[outputIndex] = predictionHolder[outputIndex];
});

// set the prediction to -1 where the same index in compatibleTracker is false; otherwise keep the model's prediction
//predictions = preds.SelectMany(x => x).ToArray();
// return vstacked tensors as a matrix => float?[]

for (var predictionsIndex = 0; predictionsIndex < predictions.Length; predictionsIndex++)
{
if (compatibleTracker[predictionsIndex]) continue;
if (compatibleTracker[predictionsIndex])
{
predictions[predictionsIndex] = predictionHolder[predictionsIndex];
continue;
}

predictions[predictionsIndex] = -1;
}
Expand All @@ -131,7 +143,7 @@ public static float[] PredictRetentionTime(string[] baseSequences, string[] full
predictions[i] = -1;
});
}

return predictions;
}
/// <summary>
Expand Down
39 changes: 31 additions & 8 deletions mzLib/RTLib/RTLibCommandLine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using Microsoft.ML;
using Nett;
using Proteomics.RetentionTimePrediction.Chronologer;
using Readers;
using TorchSharp;

namespace RTLib;
Expand Down Expand Up @@ -55,6 +56,15 @@ private static (List<string> filePaths, string outputPath) CommandLineParser(str

}

/// <summary>
/// Plain data holder implementing <see cref="IRetentionTimeAlignable"/>.
/// LoadFileResults builds one instance per filtered PSM record by copying
/// the five values below out of the parsed record.
/// NOTE(review): presumably introduced to keep less per-PSM state in memory than the
/// full parsed record — confirm intent with the author.
/// </summary>
public class LightPsm : IRetentionTimeAlignable
{
/// <summary>Source file name, copied from the parsed record's FileName.</summary>
public string FileName { get; set; }
/// <summary>Retention time, stored as float (cast from the record's nullable RetentionTime value).</summary>
public float RetentionTime { get; set; }
/// <summary>Chronologer HI value, copied from the parsed record's ChronologerHI.</summary>
public float ChronologerHI { get; set; }
/// <summary>Base sequence, copied from the parsed record's BaseSequence.</summary>
public string BaseSequence { get; set; }
/// <summary>Full sequence, copied from the parsed record's FullSequence.</summary>
public string FullSequence { get; set; }
}

public class RtLib
{
private List<string> ResultsPath { get; }
Expand All @@ -72,7 +82,7 @@ public RtLib(List<string> resultsPath, string outputPath, bool useChronologer)
OutputPath = outputPath;
Results = new Dictionary<string, List<float>>();

Task<List<IRetentionTimeAlignable>>[] dataLoader = new Task<List<IRetentionTimeAlignable>>[ResultsPath.Count];
Task<List<LightPsm>>[] dataLoader = new Task<List<LightPsm>>[ResultsPath.Count];

for (int i = 0; i < ResultsPath.Count; i++)
{
Expand Down Expand Up @@ -128,6 +138,7 @@ public RtLib(List<string> resultsPath, string outputPath, bool useChronologer)

Results = aligner.GetResults();
aligner.Dispose();
dataLoader[i].Dispose();
}
else
{
Expand All @@ -141,28 +152,40 @@ public RtLib(List<string> resultsPath, string outputPath, bool useChronologer)
Results.Add(fullSequence.Key, new List<float>());
Results[fullSequence.Key].AddRange(fullSequence.Select(x => x.RetentionTime));
}
dataLoader[i].Dispose();
}
Debug.WriteLine($"file: {i} of {Results.Count}");
Debug.WriteLine($"file: {i} of {ResultsPath.Count}");
//dataLoader[i].Result.Clear();
}
Write();
}

public List<IRetentionTimeAlignable> LoadFileResults(string path)
public List<LightPsm> LoadFileResults(string path)
{
var file = new Readers.PsmFromTsvFile(path);
file.LoadResults();
return file.Results

var results = file.Results
.Where(item => item.AmbiguityLevel == "1")
.Cast<IRetentionTimeAlignable>()
.ToList();

List<LightPsm> lightPsms = new List<LightPsm>();
foreach (var item in results)
{
lightPsms.Add(new LightPsm(){BaseSequence = item.BaseSequence,
ChronologerHI = item.ChronologerHI,
FileName = item.FileName,
FullSequence = item.FullSequence,
RetentionTime = (float)item.RetentionTime.Value});
}

return lightPsms;

}

// Wraps a call to LoadFileResults(path) in a Task created via the Task<T> constructor and returns it.
// NOTE(review): a Task created with `new Task<T>(...)` is not scheduled until Start() is called on it;
// this method does not call Start(), so the caller is presumably responsible for starting it — confirm.
public Task<List<IRetentionTimeAlignable>> LoadFileResultsAsync(string path)
public Task<List<LightPsm>> LoadFileResultsAsync(string path)
{
var results = new Task<List<IRetentionTimeAlignable>>(() => LoadFileResults(path));
var results = new Task<List<LightPsm>>(() => LoadFileResults(path));
return results;
}

Expand Down

0 comments on commit 15971a6

Please sign in to comment.