You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
kev/Drawer/AI/Utils/SpecialMessageDeserializer.cs

637 lines
24 KiB
C#

1 month ago
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using AI.Interface;
using AI.Models;
using AI.Models.Form;
using AI.Models.SpecialMessages;
using YamlDotNet.Core;
using YamlDotNet.Serialization;
using YamlDotNet.Serialization.NamingConventions;
namespace AI.Utils
{
/// <summary>
/// Type tags for the special (card-style) chat messages.
/// Each value is what <see cref="SpecialMessageDeserializer"/> expects in its <c>type</c> argument,
/// so the literals are kept explicit rather than derived from member names.
/// </summary>
public static class SpecialMessageTypes
{
    /// <summary>Form card; its payload names a form definition via <c>formId</c>.</summary>
    public const string Form = "Form";

    /// <summary>List of named parameter items.</summary>
    public const string ParameterSet = "ParameterSet";

    /// <summary>Tabular data preview.</summary>
    public const string Table = "Table";

    /// <summary>Mapping between required columns and matched columns.</summary>
    public const string ColumnMatch = "ColumnMatch";

    /// <summary>Workflow progress made of ordered steps.</summary>
    public const string WorkflowStatus = "WorkflowStatus";

    /// <summary>Knowledge-base lookup result.</summary>
    public const string KnowledgeBase = "KnowledgeBase";

    /// <summary>Scatter-point (XYZ) file loading card.</summary>
    public const string XyzLoadCard = "XyzLoadCard";

    /// <summary>Gridding parameter card.</summary>
    public const string GriddingParamCard = "GriddingParamCard";
}
/// <summary>
/// Restores an <see cref="ISpecialMessage"/> from its type tag (see <see cref="SpecialMessageTypes"/>) and YAML payload.
/// Form messages also need an <see cref="IFormRegistry"/> to resolve the <see cref="FormDefinition"/> by <c>formId</c>;
/// when the registry, the formId or the definition is missing, no form card is restored and <c>null</c> is returned.
/// </summary>
public static class SpecialMessageDeserializer
{
    /// <summary>
    /// Payload parsers keyed by message type. Matching is case-insensitive, consistent with
    /// <see cref="ShortDescriptionTable"/> and with the case-insensitive key lookup in <see cref="TryGet"/>.
    /// </summary>
    private static readonly IReadOnlyDictionary<string, Func<string, IFormRegistry?, ISpecialMessage?>> Parsers =
        new Dictionary<string, Func<string, IFormRegistry?, ISpecialMessage?>>(StringComparer.OrdinalIgnoreCase)
        {
            [SpecialMessageTypes.Form] = DeserializeForm,
            [SpecialMessageTypes.ParameterSet] = (yaml, _) => DeserializeParameterSet(yaml),
            [SpecialMessageTypes.Table] = (yaml, _) => DeserializeTable(yaml),
            [SpecialMessageTypes.ColumnMatch] = (yaml, _) => DeserializeColumnMatch(yaml),
            [SpecialMessageTypes.WorkflowStatus] = (yaml, _) => DeserializeWorkflowStatus(yaml),
            [SpecialMessageTypes.KnowledgeBase] = (yaml, _) => DeserializeKnowledgeBase(yaml),
            [SpecialMessageTypes.XyzLoadCard] = (yaml, _) => DeserializeXyzLoadCard(yaml),
            [SpecialMessageTypes.GriddingParamCard] = (yaml, _) => DeserializeGriddingParamCard(yaml),
        };

    /// <summary>
    /// Display name and the payload keys probed (in order) for a detail string, per message type.
    /// KnowledgeBase is handled separately in <see cref="GetShortDescriptionForLlm"/>.
    /// </summary>
    private static readonly IReadOnlyDictionary<string, (string TypeName, string[] DetailKeys)> ShortDescriptionTable =
        new Dictionary<string, (string, string[])>(StringComparer.OrdinalIgnoreCase)
        {
            [SpecialMessageTypes.Form] = ("表单", new[] { "title", "formId" }),
            [SpecialMessageTypes.ParameterSet] = ("参数集", new[] { "title" }),
            [SpecialMessageTypes.Table] = ("表格", new[] { "title" }),
            [SpecialMessageTypes.ColumnMatch] = ("列匹配", new[] { "title" }),
            [SpecialMessageTypes.WorkflowStatus] = ("工作流状态", new[] { "title" }),
            [SpecialMessageTypes.XyzLoadCard] = ("散点文件加载卡片", new[] { "filePath" }),
            [SpecialMessageTypes.GriddingParamCard] = ("网格化参数卡片", Array.Empty<string>()),
        };

    /// <summary>Shared YAML parser used to read every payload as a generic mapping.</summary>
    private static readonly IDeserializer YamlDeserializer = new DeserializerBuilder()
        .WithNamingConvention(CamelCaseNamingConvention.Instance)
        .IgnoreUnmatchedProperties()
        .Build();

    /// <summary>
    /// Rebuilds a special message from its type tag and YAML payload.
    /// </summary>
    /// <param name="type">Message type tag; surrounding whitespace is ignored and matching is case-insensitive.</param>
    /// <param name="yamlPayload">YAML mapping describing the message.</param>
    /// <param name="formRegistry">Needed only for <see cref="SpecialMessageTypes.Form"/> payloads.</param>
    /// <returns>
    /// The restored message, or <c>null</c> when an argument is blank, the type is unknown,
    /// or parsing fails (failures are written to the debug output).
    /// </returns>
    public static ISpecialMessage? Deserialize(string type, string yamlPayload, IFormRegistry? formRegistry = null)
    {
        if (string.IsNullOrWhiteSpace(type) || string.IsNullOrWhiteSpace(yamlPayload))
        {
            return null;
        }

        if (!Parsers.TryGetValue(type.Trim(), out var parser))
        {
            return null;
        }

        try
        {
            return parser(yamlPayload, formRegistry);
        }
        catch (Exception ex)
        {
            // Best effort: one malformed payload yields null instead of propagating to the caller.
            System.Diagnostics.Debug.WriteLine($"[SpecialMessageDeserializer] 反序列化失败: type={type}, error={ex}");
            return null;
        }
    }

    /// <summary>
    /// Builds the short text sent to the LLM in place of a special message (e.g. a form card shown to the user),
    /// so the model keeps the conversation context. For KnowledgeBase messages the full <c>rawContent</c> is returned
    /// so later turns can see the retrieved knowledge.
    /// </summary>
    /// <param name="type">Message type tag.</param>
    /// <param name="payload">YAML payload of the message.</param>
    /// <returns>A bracketed description; never <c>null</c>.</returns>
    public static string GetShortDescriptionForLlm(string type, string payload)
    {
        if (string.IsNullOrWhiteSpace(type))
        {
            return "[已展示:未知]";
        }

        if (string.IsNullOrWhiteSpace(payload))
        {
            return $"[已展示:{type}]";
        }

        try
        {
            var dict = FromYaml(payload);
            if (dict == null)
            {
                return $"[已展示:{type}]";
            }

            var trimmedType = type.Trim();

            // Knowledge-base entries: forward the full content as reference for the following turns.
            if (string.Equals(trimmedType, SpecialMessageTypes.KnowledgeBase, StringComparison.OrdinalIgnoreCase))
            {
                var rawContent = GetString(dict, "rawContent");
                return string.IsNullOrWhiteSpace(rawContent)
                    ? "[已查询知识库,无匹配内容]"
                    : $"[来自知识库的参考]{System.Environment.NewLine}{rawContent}";
            }

            string typeName = type;
            string? detail = null;
            if (ShortDescriptionTable.TryGetValue(trimmedType, out var row))
            {
                typeName = row.TypeName;
                // First detail key that holds a non-blank value wins.
                detail = row.DetailKeys
                    .Select(key => GetString(dict, key))
                    .FirstOrDefault(value => !string.IsNullOrWhiteSpace(value));
            }

            return string.IsNullOrWhiteSpace(detail)
                ? $"[已展示:{typeName}]"
                : $"[已展示:{typeName}「{detail}」]";
        }
        catch (Exception ex)
        {
            // Best effort: fall back to the bare type, but leave a trace instead of swallowing silently.
            System.Diagnostics.Debug.WriteLine($"[SpecialMessageDeserializer] 生成短描述失败: type={type}, error={ex}");
            return $"[已展示:{type}]";
        }
    }

    /// <summary>
    /// Restores a form card. Returns <c>null</c> (no card) when there is no registry, no <c>formId</c>,
    /// or the registry has no definition for it. Field values are copied onto matching fields (by id, case-insensitive).
    /// </summary>
    private static ISpecialMessage? DeserializeForm(string yaml, IFormRegistry? formRegistry)
    {
        var dict = FromYaml(yaml);
        if (dict == null)
        {
            return null;
        }

        string formId = GetString(dict, "formId") ?? string.Empty;
        FormDefinition? definition = null;
        if (formRegistry != null && !string.IsNullOrEmpty(formId))
        {
            definition = formRegistry.GetForm(formId);
        }

        if (definition == null)
        {
            return null;
        }

        var formMsg = new FormRequestMessage(definition);
        if (TryGet(dict, "id", out var idObj))
        {
            formMsg.Id = idObj?.ToString() ?? formMsg.Id;
        }

        if (TryGet(dict, "submitLabel", out var submitObj))
        {
            formMsg.SubmitLabel = submitObj?.ToString() ?? formMsg.SubmitLabel;
        }

        foreach (var fieldDict in GetMappingList(dict, "fields"))
        {
            string fid = GetString(fieldDict, "id") ?? string.Empty;
            var entry = formMsg.FieldsWithValues.FirstOrDefault(f => string.Equals(f.Id, fid, StringComparison.OrdinalIgnoreCase));
            if (entry == null)
            {
                // Field no longer exists in the current definition; ignore its stored value.
                continue;
            }

            entry.CurrentValue = GetString(fieldDict, "currentValue") ?? string.Empty;

            // Only replace the selection when the payload actually carries a selectedValues list.
            if (TryGetSequence(fieldDict, "selectedValues", out var selectedItems))
            {
                entry.SelectedValues.Clear();
                foreach (var selected in ToStringList(selectedItems))
                {
                    entry.SelectedValues.Add(selected);
                }
            }
        }

        return formMsg;
    }

    /// <summary>Restores a parameter-set message (id, title and its items).</summary>
    private static ISpecialMessage? DeserializeParameterSet(string yaml)
    {
        var dict = FromYaml(yaml);
        if (dict == null)
        {
            return null;
        }

        var param = new ParameterSetMessage();
        if (TryGet(dict, "id", out var idObj))
        {
            param.Id = idObj?.ToString() ?? param.Id;
        }

        param.Title = GetString(dict, "title") ?? param.Title;
        foreach (var item in ParseParameterSetItems(dict))
        {
            param.Items.Add(item);
        }

        return param;
    }

    /// <summary>Restores a table message (id, title, row counts, column names and rows).</summary>
    private static ISpecialMessage? DeserializeTable(string yaml)
    {
        var dict = FromYaml(yaml);
        if (dict == null)
        {
            return null;
        }

        var table = new AI.Models.SpecialMessages.TableDataMessage();
        if (TryGet(dict, "id", out var idObj))
        {
            table.Id = idObj?.ToString() ?? table.Id;
        }

        PopulateTable(table, dict, fallbackTitle: null);
        return table;
    }

    /// <summary>Restores a column-match message: required/preview columns and, when present, the mappings.</summary>
    private static ISpecialMessage? DeserializeColumnMatch(string yaml)
    {
        var dict = FromYaml(yaml);
        if (dict == null)
        {
            return null;
        }

        var col = new ColumnMatchMessage();
        if (TryGet(dict, "id", out var idObj))
        {
            col.Id = idObj?.ToString() ?? col.Id;
        }

        col.Title = GetString(dict, "title") ?? col.Title;
        col.SetColumns(GetStringList(dict, "requiredColumns"), GetStringList(dict, "previewColumns"));

        if (TryGetSequence(dict, "mappings", out var mappingItems))
        {
            var mappings = mappingItems
                .OfType<Dictionary<object, object>>()
                .Select(md => new KeyValuePair<string, string>(
                    GetString(md, "requiredColumn") ?? string.Empty,
                    GetString(md, "matchedColumn") ?? string.Empty))
                .ToList();
            col.SetMappings(mappings);
        }

        return col;
    }

    /// <summary>Restores a knowledge-base message; missing fields become empty strings (version defaults to "1").</summary>
    private static ISpecialMessage? DeserializeKnowledgeBase(string yaml)
    {
        var dict = FromYaml(yaml);
        if (dict == null)
        {
            return null;
        }

        var kb = new KnowledgeBaseMessage();
        if (TryGet(dict, "id", out var idObj))
        {
            kb.Id = idObj?.ToString() ?? kb.Id;
        }

        kb.KnowledgeId = GetString(dict, "knowledgeId") ?? string.Empty;
        kb.Version = GetString(dict, "version") ?? "1";
        kb.TopicKey = GetString(dict, "topicKey") ?? string.Empty;
        kb.Query = GetString(dict, "query") ?? string.Empty;
        kb.RawContent = GetString(dict, "rawContent") ?? string.Empty;
        return kb;
    }

    /// <summary>
    /// Restores a workflow-status message. <c>Steps</c> is replaced only when the payload contains a <c>steps</c> list;
    /// a step's status is kept at its model default unless the payload names a defined <see cref="WorkflowStepStatus"/>.
    /// </summary>
    private static ISpecialMessage? DeserializeWorkflowStatus(string yaml)
    {
        var dict = FromYaml(yaml);
        if (dict == null)
        {
            return null;
        }

        var wf = new WorkflowStatusMessage();
        if (TryGet(dict, "id", out var idObj))
        {
            wf.Id = idObj?.ToString() ?? wf.Id;
        }

        wf.Title = GetString(dict, "title") ?? wf.Title;

        if (TryGetSequence(dict, "steps", out var stepItems))
        {
            var steps = new ObservableCollection<WorkflowStepModel>();
            foreach (var sd in stepItems.OfType<Dictionary<object, object>>())
            {
                var step = new WorkflowStepModel
                {
                    Id = GetString(sd, "id") ?? string.Empty,
                    DisplayName = GetString(sd, "displayName") ?? string.Empty,
                    Order = GetInt(sd, "order") ?? 0,
                    OutputResult = GetString(sd, "outputResult"),
                    Thought = GetString(sd, "thought")
                };

                if (TryGetEnum<WorkflowStepStatus>(sd, "status", out var status))
                {
                    step.Status = status;
                }

                steps.Add(step);
            }

            wf.Steps = steps;
        }

        return wf;
    }

    /// <summary>
    /// Restores an XYZ (scatter file) load card: file path, phase, optional table preview and the
    /// column-match fields together with the form definition rebuilt from them.
    /// </summary>
    private static ISpecialMessage? DeserializeXyzLoadCard(string yaml)
    {
        var dict = FromYaml(yaml);
        if (dict == null)
        {
            return null;
        }

        var card = new AI.Models.SpecialMessages.XyzLoadCardMessage();
        if (TryGet(dict, "id", out var idObj))
        {
            card.Id = idObj?.ToString() ?? card.Id;
        }

        card.FilePath = GetString(dict, "filePath") ?? string.Empty;
        card.MatchButtonLabel = GetString(dict, "matchButtonLabel") ?? "确认匹配";
        // Accepts the phase as a number or an enum name; unknown values fall back to the default (0) phase.
        card.Phase = GetEnum(dict, "phase", default(AI.Models.SpecialMessages.XyzLoadPhase));

        // Table preview.
        if (TryGet(dict, "tablePreview", out var tableObj) && tableObj is Dictionary<object, object> tableDict)
        {
            var table = new AI.Models.SpecialMessages.TableDataMessage();
            PopulateTable(table, tableDict, "数据预览");
            card.TablePreview = table;
        }

        // Column-header match fields: each one becomes both a definition field and a value entry.
        var definitionFields = new List<AI.Models.Form.FormField>();
        foreach (var fieldDict in GetMappingList(dict, "columnMatchFields"))
        {
            var fid = GetString(fieldDict, "id") ?? string.Empty;
            var label = GetString(fieldDict, "label") ?? fid;
            var options = GetStringList(fieldDict, "options");

            definitionFields.Add(new AI.Models.Form.FormField
            {
                Id = fid,
                Label = label,
                Type = AI.Models.Form.FormFieldType.Choice,
                Options = options,
                Required = true,
            });

            card.ColumnMatchFields.Add(new AI.Models.Form.FormFieldEntry
            {
                Id = fid,
                Label = label,
                Type = AI.Models.Form.FormFieldType.Choice,
                // Defensive copy: the definition field and the entry must not share one mutable list.
                Options = new List<string>(options),
                Required = true,
                CurrentValue = GetString(fieldDict, "currentValue") ?? string.Empty,
            });
        }

        if (definitionFields.Count > 0)
        {
            card.ColumnMatchDefinition = new AI.Models.Form.FormDefinition
            {
                Id = "gridding-match-columns",
                Title = GetString(dict, "columnMatchDefinitionTitle") ?? "列头匹配",
                SubmitTarget = GetString(dict, "columnMatchDefinitionSubmitTarget") ?? "GriddingModuleMatchColumns",
                SubmitLabel = "确认匹配",
                Fields = definitionFields,
            };
        }

        return card;
    }

    /// <summary>Restores a gridding parameter card: phase, status text, button label and parameter items.</summary>
    private static ISpecialMessage? DeserializeGriddingParamCard(string yaml)
    {
        var dict = FromYaml(yaml);
        if (dict == null)
        {
            return null;
        }

        var card = new AI.Models.SpecialMessages.GriddingParamCardMessage();
        if (TryGet(dict, "id", out var idObj))
        {
            card.Id = idObj?.ToString() ?? card.Id;
        }

        // Accepts the phase as a number or an enum name; unknown values fall back to the default (0) phase.
        card.Phase = GetEnum(dict, "phase", default(AI.Models.SpecialMessages.GriddingParamCardPhase));
        card.StatusMessage = GetString(dict, "statusMessage") ?? string.Empty;
        card.GenerateButtonLabel = GetString(dict, "generateButtonLabel") ?? "生成";

        foreach (var item in ParseParameterSetItems(dict))
        {
            card.Items.Add(item);
        }

        return card;
    }

    /// <summary>
    /// Parses the <c>items</c> list shared by ParameterSet and GriddingParamCard payloads.
    /// Non-mapping entries are skipped; an unknown or missing field type becomes <c>Text</c>.
    /// </summary>
    private static List<AI.Models.SpecialMessages.ParameterSetItem> ParseParameterSetItems(Dictionary<object, object> dict)
    {
        var result = new List<AI.Models.SpecialMessages.ParameterSetItem>();
        foreach (var itemDict in GetMappingList(dict, "items"))
        {
            result.Add(new AI.Models.SpecialMessages.ParameterSetItem
            {
                Name = GetString(itemDict, "name") ?? string.Empty,
                ValueText = GetString(itemDict, "valueText") ?? string.Empty,
                Description = GetString(itemDict, "description") ?? string.Empty,
                FieldType = GetEnum(itemDict, "fieldType", AI.Models.SpecialMessages.ParameterSetFieldType.Text),
                Options = GetStringList(itemDict, "options"),
            });
        }

        return result;
    }

    /// <summary>
    /// Fills <paramref name="table"/> with title, row counts, column names and rows read from <paramref name="source"/>.
    /// </summary>
    /// <param name="table">Table to populate; its id is not touched.</param>
    /// <param name="source">Mapping holding the table fields.</param>
    /// <param name="fallbackTitle">Title used when the payload has none; <c>null</c> keeps the table's current title.</param>
    private static void PopulateTable(
        AI.Models.SpecialMessages.TableDataMessage table,
        Dictionary<object, object> source,
        string? fallbackTitle)
    {
        table.Title = GetString(source, "title") ?? fallbackTitle ?? table.Title;
        table.TotalRowCount = GetInt(source, "totalRowCount") ?? 0;
        table.MaxPreviewRows = GetInt(source, "maxPreviewRows") ?? 50;

        var columnNames = GetStringList(source, "columnNames");

        var rows = new List<IEnumerable<string>>();
        if (TryGetSequence(source, "rows", out var rowItems))
        {
            foreach (var cells in rowItems.OfType<IEnumerable<object>>())
            {
                // Null cells become empty strings so every row keeps its column positions.
                rows.Add(cells.Select(c => c?.ToString() ?? string.Empty).ToList());
            }
        }

        // Use the stored total when it is positive, otherwise the number of rows actually restored.
        table.SetData(columnNames, rows, table.TotalRowCount > 0 ? table.TotalRowCount : rows.Count);
    }

    /// <summary>
    /// Parses <paramref name="yaml"/> into a generic mapping. Returns <c>null</c> (and writes to the debug output)
    /// when YamlDotNet raises a <see cref="YamlException"/>.
    /// </summary>
    private static Dictionary<object, object>? FromYaml(string yaml)
    {
        try
        {
            return YamlDeserializer.Deserialize<Dictionary<object, object>>(yaml);
        }
        catch (YamlException ex)
        {
            System.Diagnostics.Debug.WriteLine($"[SpecialMessageDeserializer] YAML 解析失败: {ex.Message}");
            return null;
        }
    }

    /// <summary>Returns the value under <paramref name="key"/> as a string, or <c>null</c> when absent or null.</summary>
    private static string? GetString(Dictionary<object, object> dict, string key)
    {
        if (!TryGet(dict, key, out var v))
        {
            return null;
        }

        return v?.ToString();
    }

    /// <summary>
    /// Returns the value under <paramref name="key"/> as an <see cref="int"/>, or <c>null</c> when it is absent,
    /// not an integer, or outside the <see cref="int"/> range (instead of being silently truncated).
    /// Text is parsed with the invariant culture.
    /// </summary>
    private static int? GetInt(Dictionary<object, object> dict, string key)
    {
        if (!TryGet(dict, key, out var v) || v is null)
        {
            return null;
        }

        if (v is int i)
        {
            return i;
        }

        if (v is long l)
        {
            if (l < int.MinValue || l > int.MaxValue)
            {
                return null;
            }

            return (int)l;
        }

        if (int.TryParse(
                v.ToString(),
                System.Globalization.NumberStyles.Integer,
                System.Globalization.CultureInfo.InvariantCulture,
                out var n))
        {
            return n;
        }

        return null;
    }

    /// <summary>
    /// Reads an enum stored either by name (case-insensitive) or by number. Succeeds only for values that are
    /// actually defined on <typeparamref name="TEnum"/>, so corrupted numbers never produce undefined enum values.
    /// </summary>
    private static bool TryGetEnum<TEnum>(Dictionary<object, object> dict, string key, out TEnum value)
        where TEnum : struct, Enum
    {
        var text = GetString(dict, key);
        if (!string.IsNullOrWhiteSpace(text)
            && Enum.TryParse(text.Trim(), true, out value)
            && Enum.IsDefined(typeof(TEnum), value))
        {
            return true;
        }

        value = default;
        return false;
    }

    /// <summary>Like <see cref="TryGetEnum{TEnum}"/>, returning <paramref name="fallback"/> when no defined value is found.</summary>
    private static TEnum GetEnum<TEnum>(Dictionary<object, object> dict, string key, TEnum fallback)
        where TEnum : struct, Enum
        => TryGetEnum<TEnum>(dict, key, out var parsed) ? parsed : fallback;

    /// <summary>
    /// Returns <c>true</c> and the list when <paramref name="key"/> holds a YAML sequence; otherwise <c>false</c>
    /// with an empty sequence.
    /// </summary>
    private static bool TryGetSequence(Dictionary<object, object> dict, string key, out IEnumerable<object> sequence)
    {
        if (TryGet(dict, key, out var value) && value is IEnumerable<object> items)
        {
            sequence = items;
            return true;
        }

        sequence = Array.Empty<object>();
        return false;
    }

    /// <summary>Converts sequence items to strings, skipping null items.</summary>
    private static List<string> ToStringList(IEnumerable<object> items)
    {
        var result = new List<string>();
        foreach (var item in items)
        {
            if (item?.ToString() is string text)
            {
                result.Add(text);
            }
        }

        return result;
    }

    /// <summary>Returns the string items of the sequence under <paramref name="key"/>; empty when absent or not a list.</summary>
    private static List<string> GetStringList(Dictionary<object, object> dict, string key)
        => TryGetSequence(dict, key, out var items) ? ToStringList(items) : new List<string>();

    /// <summary>Returns the mapping items of the sequence under <paramref name="key"/>, skipping non-mapping entries.</summary>
    private static IEnumerable<Dictionary<object, object>> GetMappingList(Dictionary<object, object> dict, string key)
        => TryGetSequence(dict, key, out var items)
            ? items.OfType<Dictionary<object, object>>()
            : Enumerable.Empty<Dictionary<object, object>>();

    /// <summary>
    /// Looks up <paramref name="key"/>, first by exact match and then case-insensitively (ordinal)
    /// against the string form of each key.
    /// </summary>
    private static bool TryGet(Dictionary<object, object> dict, string key, out object? value)
    {
        // Fast path: an exact-case hit avoids scanning the whole mapping.
        if (dict.TryGetValue(key, out var exact))
        {
            value = exact;
            return true;
        }

        foreach (var pair in dict)
        {
            if (string.Equals(pair.Key?.ToString(), key, StringComparison.OrdinalIgnoreCase))
            {
                value = pair.Value;
                return true;
            }
        }

        value = null;
        return false;
    }
}
}