Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
using Microsoft.CodeAnalysis;
using MSBuildProjectCollection = Microsoft.Build.Evaluation.ProjectCollection;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.Simplification;
using Microsoft.CodeAnalysis.Text;
using Microsoft.TypeSpec.Generator.Primitives;
using Microsoft.TypeSpec.Generator.Providers;
using Microsoft.TypeSpec.Generator.Utilities;
Expand Down Expand Up @@ -88,8 +90,6 @@ private async Task UpdateProject(Document document)
var root = await document.GetSyntaxRootAsync();
Debug.Assert(root != null);

root = root.WithAdditionalAnnotations(Simplifier.Annotation);
document = document.WithSyntaxRoot(root);
_project = document.Project;
}

Expand Down Expand Up @@ -148,7 +148,24 @@ private async Task<Document> ProcessDocument(Document document, MemberRemoverRew
}
document = document.WithSyntaxRoot(root);

document = await Simplifier.ReduceAsync(document);
document = await ReduceQualifiedNamesAsync(document);
root = await document.GetSyntaxRootAsync();
if (root == null)
{
return document;
}

var simplifierSpans = GetSimplifierSpans(root);
if (simplifierSpans.Count > 0)
{
root = root.WithAdditionalAnnotations(Simplifier.Annotation);
document = document.WithSyntaxRoot(root);
document = await Simplifier.ReduceAsync(document, simplifierSpans);
}
else if (ContainsSimplifierAnnotations(root))
{
document = await Simplifier.ReduceAsync(document);
}

// Reformat if any custom rewriters have been applied
if (CodeModelGenerator.Instance.Rewriters.Count > 0)
Expand All @@ -158,6 +175,133 @@ private async Task<Document> ProcessDocument(Document document, MemberRemoverRew
return document;
}

private static async Task<Document> ReduceQualifiedNamesAsync(Document document)
{
var root = await document.GetSyntaxRootAsync();
var semanticModel = await document.GetSemanticModelAsync();
if (root == null || semanticModel == null)
{
return document;
}

var safeNodes = new HashSet<NameSyntax>();
foreach (var name in root.DescendantNodes().OfType<NameSyntax>())
{
if (name is not QualifiedNameSyntax and not AliasQualifiedNameSyntax ||
name.Parent is QualifiedNameSyntax ||
IsPartOfMemberAccessChain(name) ||
IsInUnsupportedQualifiedNameContext(name))
{
continue;
}

var originalSymbol = semanticModel.GetSymbolInfo(name).Symbol;
if (originalSymbol == null)
{
continue;
}

var replacement = GetRightmostName(name).WithTriviaFrom(name);
if (SpeculativelyBindsToSameSymbol(semanticModel, name, replacement, originalSymbol))
{
safeNodes.Add(name);
}
}

if (safeNodes.Count == 0)
{
return document;
}

var rewrittenRoot = root.ReplaceNodes(
safeNodes,
static (_, rewritten) => GetRightmostName(rewritten).WithTriviaFrom(rewritten));
return document.WithSyntaxRoot(rewrittenRoot);
}

private static bool SpeculativelyBindsToSameSymbol(
SemanticModel semanticModel,
NameSyntax originalName,
SimpleNameSyntax replacement,
ISymbol originalSymbol)
{
var speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo(
originalName.SpanStart,
replacement,
SpeculativeBindingOption.BindAsTypeOrNamespace).Symbol;
if (speculativeSymbol != null &&
SymbolEqualityComparer.Default.Equals(originalSymbol, speculativeSymbol))
{
return true;
}

if (originalName.Parent is MemberAccessExpressionSyntax memberAccess &&
memberAccess.Expression == originalName)
{
speculativeSymbol = semanticModel.GetSpeculativeSymbolInfo(
originalName.SpanStart,
replacement,
SpeculativeBindingOption.BindAsExpression).Symbol;
return speculativeSymbol != null &&
SymbolEqualityComparer.Default.Equals(originalSymbol, speculativeSymbol);
}

return false;
}

private static SimpleNameSyntax GetRightmostName(NameSyntax name) => name switch
{
QualifiedNameSyntax qualifiedName => qualifiedName.Right,
AliasQualifiedNameSyntax aliasQualifiedName => aliasQualifiedName.Name,
SimpleNameSyntax simpleName => simpleName,
_ => throw new InvalidOperationException($"Unexpected name syntax: {name.Kind()}")
};

private static bool IsPartOfMemberAccessChain(NameSyntax name) =>
name.Parent is MemberAccessExpressionSyntax ||
name.Ancestors().OfType<MemberAccessExpressionSyntax>().Any();

private static bool IsInUnsupportedQualifiedNameContext(NameSyntax name) =>
name.Ancestors().Any(static ancestor =>
ancestor is UsingDirectiveSyntax ||
ancestor is CrefSyntax);

private static bool ContainsSimplifierAnnotations(SyntaxNode root) =>
root.HasAnnotation(Simplifier.Annotation) ||
root.DescendantNodesAndTokens(descendIntoTrivia: true).Any(static nodeOrToken =>
nodeOrToken.HasAnnotation(Simplifier.Annotation));

private static IReadOnlyList<TextSpan> GetSimplifierSpans(SyntaxNode root)
{
List<TextSpan> spans = new();
foreach (var member in root.DescendantNodes().OfType<MemberDeclarationSyntax>())
{
if (ContainsReducibleSyntax(member))
{
spans.Add(member.FullSpan);
}
}

spans.AddRange(root
.DescendantNodesAndTokens(descendIntoTrivia: true)
.Where(static nodeOrToken => nodeOrToken.HasAnnotation(Simplifier.Annotation))
.Select(static nodeOrToken => nodeOrToken.FullSpan));

return spans;
}

private static bool ContainsReducibleSyntax(SyntaxNode root) =>
root.DescendantNodes(
descendIntoChildren: node => node == root || node is not MemberDeclarationSyntax,
descendIntoTrivia: true).Any(static node =>
node is ThisExpressionSyntax ||
node is ParenthesizedExpressionSyntax ||
node is CrefSyntax ||
node is QualifiedNameSyntax ||
node is MemberAccessExpressionSyntax ||
node is AssignmentExpressionSyntax { RawKind: (int)SyntaxKind.SimpleAssignmentExpression } ||
node is AliasQualifiedNameSyntax { Alias.Identifier.ValueText: "global" });

public static bool IsGeneratedDocument(Document document) => document.Folders.Contains(GeneratedFolder);
public static bool IsCustomDocument(Document document) => !IsGeneratedDocument(document);
public static bool IsGeneratedTestDocument(Document document) => document.Folders.Contains(GeneratedTestFolder);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using Microsoft.Build.Construction;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.TypeSpec.Generator.Primitives;
using Microsoft.TypeSpec.Generator.Tests.Common;
using NUnit.Framework;
using System;
Expand Down Expand Up @@ -97,6 +98,72 @@ await MockHelpers.LoadMockGeneratorAsync(
Assert.NotNull(fooMethod, "Foo method should be found in the SimpleType");
}

[Test]
public async Task GetGeneratedFilesAsync_SimplifiesFrameworkNamesWhenTypeHasSystemMember()
{
MockHelpers.LoadMockGenerator(
outputPath: _projectDir,
configuration: "{\"package-name\": \"TestNamespace\"}",
additionalMetadataReferences:
[
MetadataReference.CreateFromFile(typeof(EditorBrowsableAttribute).Assembly.Location)
]);

GeneratedCodeWorkspace.Initialize();
var workspace = await GeneratedCodeWorkspace.Create(false);
await workspace.AddGeneratedFile(new CodeFile(
"""
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// <auto-generated/>

#nullable disable

using System;
using System.ComponentModel;

namespace TestNamespace
{
public readonly partial struct TestRole : IEquatable<TestRole>
{
private readonly string _value;
private const string SystemValue = "system";

public TestRole(string value)
{
_value = value;
}

public static TestRole System { get; } = new TestRole(SystemValue);

[global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Never)]
public override bool Equals(object obj) => obj is TestRole other && Equals(other);

public bool Equals(TestRole other) => string.Equals(_value, other._value, global::System.StringComparison.InvariantCultureIgnoreCase);

[global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Never)]
public override int GetHashCode() => _value != null ? global::System.StringComparer.InvariantCultureIgnoreCase.GetHashCode(_value) : 0;
}
}
""",
"TestRole.cs"));

string? generatedText = null;
await foreach (var generatedFile in workspace.GetGeneratedFilesAsync())
{
generatedText = generatedFile.Text;
}

Assert.That(generatedText, Is.Not.Null);
Assert.That(generatedText, Does.Contain("[EditorBrowsable(EditorBrowsableState.Never)]"));
Assert.That(generatedText, Does.Contain("StringComparison.InvariantCultureIgnoreCase"));
Assert.That(generatedText, Does.Contain("StringComparer.InvariantCultureIgnoreCase"));
Assert.That(generatedText, Does.Not.Contain("System.ComponentModel.EditorBrowsableAttribute"));
Assert.That(generatedText, Does.Not.Contain("System.StringComparison"));
Assert.That(generatedText, Does.Not.Contain("System.StringComparer"));
}

[Test]
public async Task AddPackageReferencesFromProject_AddsReferencesFromCsproj()
{
Expand Down
Loading