Skip to content

Commit

Permalink
Hotfix for resolving scripting shells through DI.
Browse files Browse the repository at this point in the history
  • Loading branch information
bitbound committed Jul 20, 2023
1 parent b3af151 commit 72c393d
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 156 deletions.
2 changes: 1 addition & 1 deletion Agent/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ private static void RegisterServices(IServiceCollection services)
services.AddSingleton<IWakeOnLanService, WakeOnLanService>();
services.AddHostedService(services => services.GetRequiredService<ICpuUtilizationSampler>());
services.AddScoped<IChatClientService, ChatClientService>();
services.AddTransient<IPSCore, PSCore>();
services.AddTransient<IPsCoreShell, PsCoreShell>();
services.AddTransient<IExternalScriptingShell, ExternalScriptingShell>();
services.AddScoped<IConfigService, ConfigService>();
services.AddScoped<IUninstaller, Uninstaller>();
Expand Down
2 changes: 1 addition & 1 deletion Agent/Services/AgentHubConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ private void RegisterMessageHandlers()
{
try
{
var session = PSCore.GetCurrent(senderConnectionId);
var session = PsCoreShell.GetCurrent(senderConnectionId);
var completion = session.GetCompletions(inputText, currentIndex, forward);
var completionModel = completion.ToPwshCompletion();
await _hubConnection.InvokeAsync("ReturnPowerShellCompletions", completionModel, intent, senderConnectionId).ConfigureAwait(false);
Expand Down
242 changes: 122 additions & 120 deletions Agent/Services/ExternalScriptingShell.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,40 @@
using Microsoft.Extensions.Logging;
using Remotely.Shared.Enums;
using Remotely.Shared.Models;
using Remotely.Shared.Utilities;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Remotely.Agent.Services
{
public interface IExternalScriptingShell
{
ScriptResult WriteInput(string input, TimeSpan timeout);
Process ShellProcess { get; }
Task Init(ScriptingShell shell, string shellProcessName, string lineEnding, string connectionId);
Task<ScriptResult> WriteInput(string input, TimeSpan timeout);
}

public class ExternalScriptingShell : IExternalScriptingShell
{
private static readonly ConcurrentDictionary<string, ExternalScriptingShell> _sessions = new();
private static readonly ConcurrentDictionary<string, IExternalScriptingShell> _sessions = new();
private readonly IConfigService _configService;
private readonly ILogger<ExternalScriptingShell> _logger;
private readonly ManualResetEvent _outputDone = new(false);
private readonly SemaphoreSlim _writeLock = new(1, 1);
private string _errorOut = string.Empty;
private string _lastInputID = string.Empty;
private string _lineEnding;
private System.Timers.Timer _processIdleTimeout = new(TimeSpan.FromMinutes(10))
{
AutoReset = false
};

private string _senderConnectionId;
private ScriptingShell _shell;
private string _standardOut = string.Empty;

public ExternalScriptingShell(
IConfigService configService,
Expand All @@ -34,47 +44,31 @@ public class ExternalScriptingShell : IExternalScriptingShell
_configService = configService;
_logger = logger;
}
public Process ShellProcess { get; set; }

private string ErrorOut { get; set; }

private string LastInputID { get; set; }

private ManualResetEvent OutputDone { get; } = new(false);

private System.Timers.Timer ProcessIdleTimeout { get; set; }

private string SenderConnectionId { get; set; }

private Process ShellProcess { get; set; }

private string StandardOut { get; set; }

private Stopwatch Stopwatch { get; set; }

// TODO: Turn into cache and factory.
public static ExternalScriptingShell GetCurrent(ScriptingShell shell, string senderConnectionId)
public static async Task<IExternalScriptingShell> GetCurrent(ScriptingShell shell, string senderConnectionId)
{
if (_sessions.TryGetValue($"{shell}-{senderConnectionId}", out var session) &&
session.ShellProcess?.HasExited != true)
{
session.ProcessIdleTimeout.Stop();
session.ProcessIdleTimeout.Start();
return session;
}
else
{
session = Program.Services.GetRequiredService<ExternalScriptingShell>();
session = Program.Services.GetRequiredService<IExternalScriptingShell>();

switch (shell)
{
case ScriptingShell.WinPS:
session.Init(shell, "powershell.exe", "\r\n", senderConnectionId);
await session.Init(shell, "powershell.exe", "\r\n", senderConnectionId);
break;
case ScriptingShell.Bash:
session.Init(shell, "bash", "\n", senderConnectionId);
await session.Init(shell, "bash", "\n", senderConnectionId);
break;
case ScriptingShell.CMD:
session.Init(shell, "cmd.exe", "\r\n", senderConnectionId);
await session.Init(shell, "cmd.exe", "\r\n", senderConnectionId);
break;
default:
throw new ArgumentException($"Unknown external scripting shell type: {shell}");
Expand All @@ -84,135 +78,143 @@ public static ExternalScriptingShell GetCurrent(ScriptingShell shell, string sen
}
}

public ScriptResult WriteInput(string input, TimeSpan timeout)
public async Task Init(ScriptingShell shell, string shellProcessName, string lineEnding, string connectionId)
{
_shell = shell;
_lineEnding = lineEnding;
_senderConnectionId = connectionId;

var psi = new ProcessStartInfo(shellProcessName)
{
WindowStyle = ProcessWindowStyle.Hidden,
Verb = "RunAs",
UseShellExecute = false,
RedirectStandardError = true,
RedirectStandardInput = true,
RedirectStandardOutput = true
};

var connectionInfo = _configService.GetConnectionInfo();
psi.Environment.Add("DeviceId", connectionInfo.DeviceID);
psi.Environment.Add("ServerUrl", connectionInfo.Host);

ShellProcess = new Process
{
StartInfo = psi
};
ShellProcess.ErrorDataReceived += ShellProcess_ErrorDataReceived;
ShellProcess.OutputDataReceived += ShellProcess_OutputDataReceived;

ShellProcess.Start();

ShellProcess.BeginErrorReadLine();
ShellProcess.BeginOutputReadLine();

_processIdleTimeout = new System.Timers.Timer(TimeSpan.FromMinutes(10))
{
AutoReset = false
};
_processIdleTimeout.Elapsed += ProcessIdleTimeout_Elapsed;
_processIdleTimeout.Start();

if (shell == ScriptingShell.WinPS)
{
await WriteInput("$VerbosePreference = \"Continue\";", TimeSpan.FromSeconds(5));
await WriteInput("$DebugPreference = \"Continue\";", TimeSpan.FromSeconds(5));
await WriteInput("$InformationPreference = \"Continue\";", TimeSpan.FromSeconds(5));
await WriteInput("$WarningPreference = \"Continue\";", TimeSpan.FromSeconds(5));
}
}

public async Task<ScriptResult> WriteInput(string input, TimeSpan timeout)
{
await _writeLock.WaitAsync();
var sw = Stopwatch.StartNew();

try
{
StandardOut = "";
ErrorOut = "";
Stopwatch = Stopwatch.StartNew();
lock (ShellProcess)
{
LastInputID = Guid.NewGuid().ToString();
OutputDone.Reset();
ShellProcess.StandardInput.Write(input + _lineEnding);
ShellProcess.StandardInput.Write("echo " + LastInputID + _lineEnding);

var result = Task.WhenAny(
Task.Run(() =>
{
return ShellProcess.WaitForExit((int)timeout.TotalMilliseconds);
}),
Task.Run(() =>
{
return OutputDone.WaitOne();
})).ConfigureAwait(false).GetAwaiter().GetResult();

if (!result.Result)
_processIdleTimeout.Stop();
_processIdleTimeout.Start();
_outputDone.Reset();

_standardOut = "";
_errorOut = "";
_lastInputID = Guid.NewGuid().ToString();

ShellProcess.StandardInput.Write(input + _lineEnding);
ShellProcess.StandardInput.Write("echo " + _lastInputID + _lineEnding);

var result = await Task.WhenAny(
Task.Run(() =>
{
return ShellProcess.WaitForExit((int)timeout.TotalMilliseconds);
}),
Task.Run(() =>
{
return GeneratePartialResult(input);
}
return _outputDone.WaitOne();
})).ConfigureAwait(false).GetAwaiter().GetResult();

if (!result)
{
return GeneratePartialResult(input, sw.Elapsed);
}
return GenerateCompletedResult(input);

return GenerateCompletedResult(input, sw.Elapsed);
}
catch (Exception ex)
{
_logger.LogError(ex, "Error while writing input to scripting shell.");
ErrorOut += Environment.NewLine + ex.Message;
_errorOut += Environment.NewLine + ex.Message;

// Something's wrong. Let the next command start a new session.
RemoveSession();
}
finally
{
_writeLock.Release();
}

return GeneratePartialResult(input);
return GeneratePartialResult(input, sw.Elapsed);
}

private ScriptResult GenerateCompletedResult(string input)
private ScriptResult GenerateCompletedResult(string input, TimeSpan runtime)
{
return new ScriptResult()
{
Shell = _shell,
RunTime = Stopwatch.Elapsed,
RunTime = runtime,
ScriptInput = input,
SenderConnectionID = SenderConnectionId,
SenderConnectionID = _senderConnectionId,
DeviceID = _configService.GetConnectionInfo().DeviceID,
StandardOutput = StandardOut.Split(Environment.NewLine),
ErrorOutput = ErrorOut.Split(Environment.NewLine),
HadErrors = !string.IsNullOrWhiteSpace(ErrorOut) ||
StandardOutput = _standardOut.Split(Environment.NewLine),
ErrorOutput = _errorOut.Split(Environment.NewLine),
HadErrors = !string.IsNullOrWhiteSpace(_errorOut) ||
(ShellProcess.HasExited && ShellProcess.ExitCode != 0)
};
}

private ScriptResult GeneratePartialResult(string input)
private ScriptResult GeneratePartialResult(string input, TimeSpan runtime)
{
var partialResult = new ScriptResult()
{
Shell = _shell,
RunTime = Stopwatch.Elapsed,
RunTime = runtime,
ScriptInput = input,
SenderConnectionID = SenderConnectionId,
SenderConnectionID = _senderConnectionId,
DeviceID = _configService.GetConnectionInfo().DeviceID,
StandardOutput = StandardOut.Split(Environment.NewLine),
StandardOutput = _standardOut.Split(Environment.NewLine),
ErrorOutput = (new[] { "WARNING: The command execution timed out and was forced to return before finishing. " +
"The results may be partial, and the terminal process has been reset. " +
"Please note that interactive commands aren't supported."})
.Concat(ErrorOut.Split(Environment.NewLine))
.Concat(_errorOut.Split(Environment.NewLine))
.ToArray(),
HadErrors = !string.IsNullOrWhiteSpace(ErrorOut) ||
HadErrors = !string.IsNullOrWhiteSpace(_errorOut) ||
(ShellProcess.HasExited && ShellProcess.ExitCode != 0)
};
ProcessIdleTimeout_Elapsed(this, null);
return partialResult;
}

private void Init(ScriptingShell shell, string shellProcessName, string lineEnding, string connectionId)
{
_shell = shell;
_lineEnding = lineEnding;
SenderConnectionId = connectionId;

var psi = new ProcessStartInfo(shellProcessName)
{
WindowStyle = ProcessWindowStyle.Hidden,
Verb = "RunAs",
UseShellExecute = false,
RedirectStandardError = true,
RedirectStandardInput = true,
RedirectStandardOutput = true
};

var connectionInfo = _configService.GetConnectionInfo();
psi.Environment.Add("DeviceId", connectionInfo.DeviceID);
psi.Environment.Add("ServerUrl", connectionInfo.Host);

ShellProcess = new Process
{
StartInfo = psi
};
ShellProcess.ErrorDataReceived += ShellProcess_ErrorDataReceived;
ShellProcess.OutputDataReceived += ShellProcess_OutputDataReceived;

ShellProcess.Start();

ShellProcess.BeginErrorReadLine();
ShellProcess.BeginOutputReadLine();

ProcessIdleTimeout = new System.Timers.Timer(TimeSpan.FromMinutes(10).TotalMilliseconds)
{
AutoReset = false
};
ProcessIdleTimeout.Elapsed += ProcessIdleTimeout_Elapsed;
ProcessIdleTimeout.Start();

if (shell == ScriptingShell.WinPS)
{
WriteInput("$VerbosePreference = \"Continue\";", TimeSpan.FromSeconds(5));
WriteInput("$DebugPreference = \"Continue\";", TimeSpan.FromSeconds(5));
WriteInput("$InformationPreference = \"Continue\";", TimeSpan.FromSeconds(5));
WriteInput("$WarningPreference = \"Continue\";", TimeSpan.FromSeconds(5));
}
}
private void ProcessIdleTimeout_Elapsed(object sender, System.Timers.ElapsedEventArgs e)
{
RemoveSession();
Expand All @@ -221,26 +223,26 @@ private void ProcessIdleTimeout_Elapsed(object sender, System.Timers.ElapsedEven
private void RemoveSession()
{
ShellProcess?.Kill();
_sessions.TryRemove(SenderConnectionId, out _);
_sessions.TryRemove(_senderConnectionId, out _);
}

private void ShellProcess_ErrorDataReceived(object sender, DataReceivedEventArgs e)
{
if (e?.Data != null)
{
ErrorOut += e.Data + Environment.NewLine;
_errorOut += e.Data + Environment.NewLine;
}
}

private void ShellProcess_OutputDataReceived(object sender, DataReceivedEventArgs e)
{
if (e?.Data?.Contains(LastInputID) == true)
if (e?.Data?.Contains(_lastInputID) == true)
{
OutputDone.Set();
_outputDone.Set();
}
else
{
StandardOut += e.Data + Environment.NewLine;
_standardOut += e.Data + Environment.NewLine;
}

}
Expand Down

0 comments on commit 72c393d

Please sign in to comment.