Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: Add #[rustfmt::sort] and implement for enum variants and struct structs #6313

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1684,6 +1684,7 @@ fn rewrite_struct_lit<'a>(
v_shape,
mk_sp(body_lo, span.hi()),
one_line_width,
None,
)
.unknown_error()?
} else {
Expand Down
65 changes: 54 additions & 11 deletions src/items.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
// Formatting top-level items - functions, structs, enums, traits, impls.

use std::borrow::Cow;
use std::cmp::{Ordering, max, min};

use itertools::Itertools;
use regex::Regex;
use rustc_ast::visit;
use rustc_ast::{ast, ptr};
use rustc_span::{BytePos, DUMMY_SP, Span, symbol};
use std::borrow::Cow;
use std::cmp::{Ordering, max, min};

use crate::attr::filter_inline_attrs;
use crate::comment::{
Expand Down Expand Up @@ -519,13 +519,19 @@ impl<'a> FmtVisitor<'a> {
self.push_rewrite(static_parts.span, rewrite);
}

pub(crate) fn visit_struct(&mut self, struct_parts: &StructParts<'_>) {
pub(crate) fn visit_struct(&mut self, struct_parts: &StructParts<'_>, sort: bool) {
let is_tuple = match struct_parts.def {
ast::VariantData::Tuple(..) => true,
_ => false,
};
let rewrite = format_struct(&self.get_context(), struct_parts, self.block_indent, None)
.map(|s| if is_tuple { s + ";" } else { s });
let rewrite = format_struct(
&self.get_context(),
struct_parts,
self.block_indent,
None,
sort,
)
.map(|s| if is_tuple { s + ";" } else { s });
self.push_rewrite(struct_parts.span, rewrite);
}

Expand All @@ -536,6 +542,7 @@ impl<'a> FmtVisitor<'a> {
enum_def: &ast::EnumDef,
generics: &ast::Generics,
span: Span,
sort: bool,
) {
let enum_header =
format_header(&self.get_context(), "enum ", ident, vis, self.block_indent);
Expand Down Expand Up @@ -563,7 +570,7 @@ impl<'a> FmtVisitor<'a> {

self.last_pos = body_start;

match self.format_variant_list(enum_def, body_start, span.hi()) {
match self.format_variant_list(enum_def, body_start, span.hi(), sort) {
Some(ref s) if enum_def.variants.is_empty() => self.push_str(s),
rw => {
self.push_rewrite(mk_sp(body_start, span.hi()), rw);
Expand All @@ -578,6 +585,7 @@ impl<'a> FmtVisitor<'a> {
enum_def: &ast::EnumDef,
body_lo: BytePos,
body_hi: BytePos,
sort: bool,
) -> Option<String> {
if enum_def.variants.is_empty() {
let mut buffer = String::with_capacity(128);
Expand Down Expand Up @@ -615,7 +623,7 @@ impl<'a> FmtVisitor<'a> {
.unwrap_or(&0);

let itemize_list_with = |one_line_width: usize| {
itemize_list(
let iter = itemize_list(
self.snippet_provider,
enum_def.variants.iter(),
"}",
Expand All @@ -635,8 +643,16 @@ impl<'a> FmtVisitor<'a> {
body_lo,
body_hi,
false,
)
.collect()
);
if sort {
// sort the items by their name as this enum has the rustfmt::sort attr
iter.enumerate()
.sorted_by_key(|&(i, _)| enum_def.variants[i].ident.name.as_str())
.map(|(_, item)| item)
.collect()
} else {
iter.collect()
}
};
let mut items: Vec<_> = itemize_list_with(self.config.struct_variant_width());

Expand Down Expand Up @@ -695,6 +711,7 @@ impl<'a> FmtVisitor<'a> {
&StructParts::from_variant(field, &context),
self.block_indent,
Some(one_line_width),
false,
)?,
ast::VariantData::Unit(..) => rewrite_ident(&context, field.ident).to_owned(),
};
Expand Down Expand Up @@ -1143,14 +1160,15 @@ fn format_struct(
struct_parts: &StructParts<'_>,
offset: Indent,
one_line_width: Option<usize>,
sort: bool,
) -> Option<String> {
match struct_parts.def {
ast::VariantData::Unit(..) => format_unit_struct(context, struct_parts, offset),
ast::VariantData::Tuple(fields, _) => {
format_tuple_struct(context, struct_parts, fields, offset)
}
ast::VariantData::Struct { fields, .. } => {
format_struct_struct(context, struct_parts, fields, offset, one_line_width)
format_struct_struct(context, struct_parts, fields, offset, one_line_width, sort)
}
}
}
Expand Down Expand Up @@ -1429,6 +1447,7 @@ pub(crate) fn format_struct_struct(
fields: &[ast::FieldDef],
offset: Indent,
one_line_width: Option<usize>,
sort: bool,
) -> Option<String> {
let mut result = String::with_capacity(1024);
let span = struct_parts.span;
Expand Down Expand Up @@ -1497,12 +1516,36 @@ pub(crate) fn format_struct_struct(
let one_line_budget =
one_line_width.map_or(0, |one_line_width| min(one_line_width, one_line_budget));

let ranks: Option<Vec<_>> = if sort {
// get the sequence of indices that would sort the vec
let indices: Vec<usize> = fields
.iter()
.enumerate()
.sorted_by(|(_, field_a), (_, field_b)| {
field_a
.ident
.zip(field_b.ident)
.map(|(a, b)| a.name.as_str().cmp(b.name.as_str()))
.unwrap_or(Ordering::Equal)
})
.map(|(i, _)| i)
.collect();
// create a vec with ranks for the fields, allowing for use in Itertools.sorted_by_key
let mut ranks = vec![0; indices.len()];
for (rank, original_index) in indices.into_iter().enumerate() {
ranks[original_index] = rank;
}
Some(ranks)
} else {
None
};
let items_str = rewrite_with_alignment(
fields,
context,
Shape::indented(offset.block_indent(context.config), context.config).sub_width(1)?,
mk_sp(body_lo, span.hi()),
one_line_budget,
ranks.as_ref().map(|v| v.as_slice()),
)?;

if !items_str.contains('\n')
Expand Down
11 changes: 11 additions & 0 deletions src/skip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ impl SkipNameContext {

static RUSTFMT: &str = "rustfmt";
static SKIP: &str = "skip";
static SORT: &str = "sort";

/// Say if you're playing with `rustfmt`'s skip attribute
pub(crate) fn is_skip_attr(segments: &[ast::PathSegment]) -> bool {
Expand All @@ -103,6 +104,16 @@ pub(crate) fn is_skip_attr(segments: &[ast::PathSegment]) -> bool {
}
}

pub(crate) fn is_sort_attr(segments: &[ast::PathSegment]) -> bool {
if segments.len() < 2 || segments[0].ident.to_string() != RUSTFMT {
return false;
}
match segments.len() {
2 => segments[1].ident.to_string() == SORT,
_ => false,
}
}

fn get_skip_names(kind: &str, attrs: &[ast::Attribute]) -> Vec<String> {
let mut skip_names = vec![];
let path = format!("{RUSTFMT}::{SKIP}::{kind}");
Expand Down
34 changes: 34 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ pub(crate) fn skip_annotation() -> Symbol {
Symbol::intern("rustfmt::skip")
}

#[inline]
pub(crate) fn sort_annotation() -> Symbol {
Symbol::intern("rustfmt::sort")
}

pub(crate) fn rewrite_ident<'a>(context: &'a RewriteContext<'_>, ident: symbol::Ident) -> &'a str {
context.snippet(ident.span)
}
Expand Down Expand Up @@ -271,6 +276,35 @@ pub(crate) fn contains_skip(attrs: &[Attribute]) -> bool {
.any(|a| a.meta().map_or(false, |a| is_skip(&a)))
}

#[inline]
pub(crate) fn contains_sort(attrs: &[Attribute]) -> bool {
attrs
.iter()
.any(|a| a.meta().map_or(false, |a| is_sort(&a)))
}

#[inline]
fn is_sort(meta_item: &MetaItem) -> bool {
match meta_item.kind {
MetaItemKind::Word => {
let path_str = pprust::path_to_string(&meta_item.path);
path_str == sort_annotation().as_str()
}
MetaItemKind::List(ref l) => {
meta_item.has_name(sym::cfg_attr) && l.len() == 2 && crate::utils::is_sort_nested(&l[1])
}
_ => false,
}
}

#[inline]
fn is_sort_nested(meta_item: &NestedMetaItem) -> bool {
match meta_item {
NestedMetaItem::MetaItem(ref mi) => crate::utils::is_sort(mi),
NestedMetaItem::Lit(_) => false,
}
}

#[inline]
pub(crate) fn semicolon_for_expr(context: &RewriteContext<'_>, expr: &ast::Expr) -> bool {
// Never try to insert semicolons on expressions when we're inside
Expand Down
20 changes: 19 additions & 1 deletion src/vertical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ pub(crate) fn rewrite_with_alignment<T: AlignedItem>(
shape: Shape,
span: Span,
one_line_width: usize,
ranks: Option<&[usize]>,
) -> Option<String> {
let (spaces, group_index) = if context.config.struct_field_align_threshold() > 0 {
group_aligned_items(context, fields)
Expand Down Expand Up @@ -170,12 +171,20 @@ pub(crate) fn rewrite_with_alignment<T: AlignedItem>(
shape.indent,
one_line_width,
force_separator,
ranks.map(|v| &v[0..=group_index]),
)?;
if rest.is_empty() {
Some(result + spaces)
} else {
let rest_span = mk_sp(init_last_pos, span.hi());
let rest_str = rewrite_with_alignment(rest, context, shape, rest_span, one_line_width)?;
let rest_str = rewrite_with_alignment(
rest,
context,
shape,
rest_span,
one_line_width,
ranks.map(|v| &v[group_index + 1..]),
)?;
Some(format!(
"{}{}\n{}{}",
result,
Expand Down Expand Up @@ -211,6 +220,7 @@ fn rewrite_aligned_items_inner<T: AlignedItem>(
offset: Indent,
one_line_width: usize,
force_trailing_separator: bool,
ranks: Option<&[usize]>,
) -> Option<String> {
// 1 = ","
let item_shape = Shape::indented(offset, context.config).sub_width(1)?;
Expand Down Expand Up @@ -266,6 +276,14 @@ fn rewrite_aligned_items_inner<T: AlignedItem>(
.tactic(tactic)
.trailing_separator(separator_tactic)
.preserve_newline(true);
if let Some(ranks) = ranks {
items = ranks
.iter()
.zip(items.into_iter())
.sorted_by_key(|&(index, _)| *index)
.map(|(_, item)| item)
.collect();
}
write_list(&items, &fmt).ok()
}

Expand Down
20 changes: 14 additions & 6 deletions src/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@ use crate::modules::Module;
use crate::parse::session::ParseSess;
use crate::rewrite::{Rewrite, RewriteContext};
use crate::shape::{Indent, Shape};
use crate::skip::{SkipContext, is_skip_attr};
use crate::skip::{SkipContext, is_skip_attr, is_sort_attr};
use crate::source_map::{LineRangeUtils, SpanUtils};
use crate::spanned::Spanned;
use crate::stmt::Stmt;
use crate::utils::{
self, contains_skip, count_newlines, depr_skip_annotation, format_safety, inner_attributes,
last_line_width, mk_sp, ptr_vec_to_ref_vec, rewrite_ident, starts_with_newline, stmt_expr,
self, contains_skip, contains_sort, count_newlines, depr_skip_annotation, format_safety,
inner_attributes, last_line_width, mk_sp, ptr_vec_to_ref_vec, rewrite_ident,
starts_with_newline, stmt_expr,
};
use crate::{ErrorKind, FormatReport, FormattingError};

Expand Down Expand Up @@ -511,11 +512,18 @@ impl<'b, 'a: 'b> FmtVisitor<'a> {
self.push_rewrite(span, rw);
}
ast::ItemKind::Struct(..) | ast::ItemKind::Union(..) => {
self.visit_struct(&StructParts::from_item(item));
self.visit_struct(&StructParts::from_item(item), contains_sort(&item.attrs));
}
ast::ItemKind::Enum(ref def, ref generics) => {
self.format_missing_with_indent(source!(self, item.span).lo());
self.visit_enum(item.ident, &item.vis, def, generics, item.span);
self.visit_enum(
item.ident,
&item.vis,
def,
generics,
item.span,
contains_sort(&item.attrs),
);
self.last_pos = source!(self, item.span).hi();
}
ast::ItemKind::Mod(safety, ref mod_kind) => {
Expand Down Expand Up @@ -858,7 +866,7 @@ impl<'b, 'a: 'b> FmtVisitor<'a> {
if segments[0].ident.to_string() != "rustfmt" {
return false;
}
!is_skip_attr(segments)
!(is_skip_attr(segments) | is_sort_attr(segments))
}

fn walk_mod_items(&mut self, items: &[rustc_ast::ptr::P<ast::Item>]) {
Expand Down
18 changes: 18 additions & 0 deletions tests/source/enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,21 @@ pub enum E {
A { a: u32 } = 0x100,
B { field1: u32, field2: u8, field3: m::M } = 0x300 // comment
}

// #3422
#[rustfmt::sort]
enum SortE {

Y, // whitespace from above dropped

X, // whitespace from above dropped
E,
// something
D(),
C(),
/// Comment for B
B,
/// Comment for A
#[rustfmt::skip]
A,
}
21 changes: 21 additions & 0 deletions tests/source/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,3 +296,24 @@ struct Test {
// #2818
struct Paren((i32)) where i32: Trait;
struct Parens((i32, i32)) where i32: Trait;

// #3422
#[rustfmt::sort]
struct Foo {

#[skip]
b: u32,
a: u32, // a

bb: u32,
/// A
aa: u32,
}

#[rustfmt::sort]
struct Fooy {
a: u32, // a
b: u32,
/// C
c: u32,
}
Loading
Loading