Skip to content

Commit

Permalink
Added potential fixes to model overrider TBD at a later date.
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelalonsojr committed Oct 5, 2023
1 parent 0ef356a commit 7277a89
Showing 1 changed file with 50 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
using UnityEngine;
using Unity.Sentis;
using System.IO;
using Unity.Sentis.ONNX;
using Unity.MLAgents;
using Unity.MLAgents.Policies;
#if UNITY_EDITOR
Expand Down Expand Up @@ -47,7 +46,6 @@ public class ModelOverrider : MonoBehaviour
// Cached loaded ModelAssets, with the behavior name as the key.
Dictionary<string, ModelAsset> m_CachedModels = new Dictionary<string, ModelAsset>();


// Max episodes to run. Only used if > 0
// Will default to 1 if override models are specified, otherwise 0.
int m_MaxEpisodes;
Expand Down Expand Up @@ -120,6 +118,7 @@ void GetAssetPathFromCommandLine()
{
return;
}

var maxEpisodes = 0;
var timeoutSeconds = 0;

Expand Down Expand Up @@ -148,6 +147,7 @@ void GetAssetPathFromCommandLine()
EditorApplication.isPlaying = false;
#endif
}

m_OverrideExtensions.Add(overrideExtension);
}
else if (args[i] == k_CommandLineQuitAfterEpisodesFlag && i < args.Length - 1)
Expand Down Expand Up @@ -276,11 +276,23 @@ public ModelAsset GetModelForBehaviorName(string behaviorName)
if (rawModel == null)
{
Debug.Log($"Couldn't load model file(s) for {behaviorName} in {m_BehaviorNameOverrideDirectory} (full path: {Path.GetFullPath(m_BehaviorNameOverrideDirectory)}");

// Cache the null so we don't repeatedly try to load a missing file
m_CachedModels[behaviorName] = null;
return null;
}

// TODO enable this when we have a decision on supporting loading/converting an ONNX model directly into a ModelAsset
// ModelAsset asset;
// if (isOnnx)
// {
// var modelName = Path.Combine(m_BehaviorNameOverrideDirectory, $"{behaviorName}.onnx");
// asset = LoadOnnxModel(modelName);
// }
// else
// {
// asset = LoadSentisModel(rawModel);
// }
// var asset = isOnnx ? LoadOnnxModel(rawModel) : LoadSentisModel(rawModel);
var asset = LoadSentisModel(rawModel);
asset.name = assetName;
Expand All @@ -296,6 +308,41 @@ ModelAsset LoadSentisModel(byte[] rawModel)
return asset;
}

// TODO enable this when we have a decision on supporting loading/converting an ONNX model directly into a ModelAsset
// ModelAsset LoadOnnxModel(string modelName)
// {
// Debug.Log($"Loading model for override: {modelName}");
// var converter = new ONNXModelConverter(true);
// var directoryName = Path.GetDirectoryName(modelName);
// var model = converter.Convert(modelName, directoryName);
// var asset = ScriptableObject.CreateInstance<ModelAsset>();
// var assetData = ScriptableObject.CreateInstance<ModelAssetData>();
// var descStream = new MemoryStream();
// ModelWriter.SaveModelDesc(descStream, model);
// assetData.value = descStream.ToArray();
// assetData.name = "Data";
// assetData.hideFlags = HideFlags.HideInHierarchy;
// descStream.Close();
// descStream.Dispose();
// asset.modelAssetData = assetData;
// var weightStreams = new List<MemoryStream>();
// ModelWriter.SaveModelWeights(weightStreams, model);
//
// asset.modelWeightsChunks = new ModelAssetWeightsData[weightStreams.Count];
// for (int i = 0; i < weightStreams.Count; i++)
// {
// var stream = weightStreams[i];
// asset.modelWeightsChunks[i] = ScriptableObject.CreateInstance<ModelAssetWeightsData>();
// asset.modelWeightsChunks[i].value = stream.ToArray();
// asset.modelWeightsChunks[i].name = "Data";
// asset.modelWeightsChunks[i].hideFlags = HideFlags.HideInHierarchy;
// stream.Close();
// stream.Dispose();
// }
//
// return asset;
// }

// TODO this should probably be deprecated since Sentis does not support direct conversion from byte arrays
// ModelAsset LoadOnnxModel(byte[] rawModel)
// {
Expand All @@ -317,7 +364,6 @@ ModelAsset LoadSentisModel(byte[] rawModel)
// return asset;
// }


/// <summary>
/// Load the ModelAsset file from the specified path, and give it to the attached agent.
/// </summary>
Expand Down Expand Up @@ -369,12 +415,12 @@ void OverrideModel()
{
Debug.LogWarning(overrideError);
}

Application.Quit(1);
#if UNITY_EDITOR
EditorApplication.isPlaying = false;
#endif
}

}
}
}

0 comments on commit 7277a89

Please sign in to comment.