Melhore o raciocínio multi-salto em LLMs aprendendo com o rico feedback humano

Melhore o raciocínio multi-salto em LLMs aprendendo com o rico feedback humano

Os recentes modelos de linguagem ampla (LLMs) permitiram um tremendo progresso na compreensão da linguagem natural. No entanto, eles tendem a gerar explicações confiantes, mas sem sentido, o que representa um obstáculo significativo para estabelecer a confiança dos usuários. Neste post, mostramos como incorporar o feedback humano nas cadeias de raciocínio incorretas para o raciocínio multi-salto para melhorar o desempenho nessas tarefas. Em vez de coletar as cadeias de raciocínio do zero, perguntando aos humanos, aprendemos com o feedback humano rico em cadeias de raciocínio geradas por modelos usando as habilidades de solicitação dos LLMs. Coletamos dois desses conjuntos de dados de feedback humano na forma de (correção, explicação, tipo de erro) para os conjuntos de dados StrategyQA e Sports Understanding e avaliamos vários algoritmos comuns para aprender com esse feedback. Nossos métodos propostos funcionam de forma competitiva em relação ao prompt de cadeia de pensamento usando o Flan-T5 básico, e o nosso é melhor em julgar a exatidão de sua própria resposta.

Visão geral da solução

Com o surgimento de grandes modelos de linguagem, o campo viu um tremendo progresso em vários benchmarks de processamento de linguagem natural (NLP). Entre eles, o progresso foi impressionante em tarefas relativamente mais simples, como contexto curto ou resposta a perguntas factuais, em comparação com tarefas mais difíceis que exigem raciocínio, como respostas a perguntas com vários saltos. O desempenho de certas tarefas usando LLMs pode ser semelhante à adivinhação aleatória em escalas menores, mas melhora significativamente em escalas maiores. Apesar disso, as habilidades de solicitação dos LLMs têm o potencial de fornecer alguns fatos relevantes necessários para responder à pergunta.

No entanto, esses modelos podem não gerar cadeias de raciocínio ou explicações corretas de forma confiável. Essas explicações confiantes, mas sem sentido, são ainda mais prevalentes quando os LLMs são treinados usando o Aprendizado por Reforço do Feedback Humano (RLHF), onde o hacking de recompensa pode ocorrer.

Motivados por isso, tentamos abordar a seguinte questão de pesquisa: podemos melhorar o raciocínio dos LLMs aprendendo com o feedback humano sobre as cadeias de raciocínio geradas por modelos? A figura a seguir fornece uma visão geral de nossa abordagem: primeiro solicitamos ao modelo que gere cadeias de raciocínio para questões com vários saltos, depois coletamos diversos comentários humanos sobre essas cadeias para diagnóstico e propomos algoritmos de treinamento para aprender com os dados coletados.

Melhore o raciocínio multi-hop em LLMs aprendendo com o rico feedback humano PlatoBlockchain Data Intelligence. Pesquisa vertical. Ai.

Coletamos feedback humano diversificado em dois conjuntos de dados de raciocínio multi-hop, StrategyQA e Sports Understanding do BigBench. Para cada pergunta e cadeia de raciocínio gerada pelo modelo, coletamos a cadeia de raciocínio correta, o tipo de erro na cadeia de raciocínio gerada pelo modelo e uma descrição (em linguagem natural) de por que esse erro é apresentado na cadeia de raciocínio fornecida. O conjunto de dados final contém feedback para 1,565 amostras de StrategyQA e 796 exemplos para Sports Understanding.

Propomos vários algoritmos de treinamento para aprender com o feedback coletado. Em primeiro lugar, propomos uma variante de autoconsistência no prompting de cadeia de pensamento, considerando uma variante ponderada que pode ser aprendida com o feedback. Em segundo lugar, propomos o refinamento iterativo, no qual refinamos iterativamente a cadeia de raciocínio gerada pelo modelo até que esteja correta. Demonstramos empiricamente nos dois conjuntos de dados que o ajuste fino de um LLM, ou seja, Flan-T5 usando os algoritmos propostos, tem um desempenho comparável à linha de base de aprendizado no contexto. Mais importante, mostramos que o modelo ajustado é melhor em julgar se sua própria resposta está correta em comparação com o modelo Flan-T5 básico.

A coleta de dados

Nesta seção, descrevemos os detalhes do feedback que coletamos e o protocolo de anotação seguido durante a coleta de dados. Coletamos feedback para gerações de modelos com base em dois conjuntos de dados baseados em raciocínio: StrategyQA e Sports Understanding from BigBench. Usamos GPT-J para gerar a resposta para StrategyQA e Flan-T5 para gerar a resposta para o conjunto de dados Sports Understanding. Em cada caso, o modelo recebeu k exemplos no contexto contendo pergunta, resposta e explicação, seguidos pela pergunta de teste.

A figura a seguir mostra a interface que usamos. Os anotadores recebem a pergunta, a resposta gerada pelo modelo e a explicação dividida em etapas.

Melhore o raciocínio multi-hop em LLMs aprendendo com o rico feedback humano PlatoBlockchain Data Intelligence. Pesquisa vertical. Ai.

Para cada pergunta, coletamos o seguinte feedback:

  • Subquestões – Os anotadores decompõem a pergunta original em subperguntas mais simples necessárias para responder à pergunta original. Esta tarefa foi adicionada após um piloto onde descobrimos que adicionar esta tarefa ajuda a preparar os anotadores e melhorar a qualidade do restante das tarefas.
  • Correção – Os anotadores recebem uma caixa de texto de formato livre pré-preenchida com a resposta e a explicação geradas pelo modelo e são solicitadas a editá-la para obter a resposta e a explicação corretas.
  • Tipo de erro – Entre os tipos de erro mais comuns que encontramos nas gerações do modelo (erro factual, fatos ausentes, fatos irrelevantes e inconsistência lógica), os anotadores foram solicitados a escolher um ou mais dos tipos de erro que se aplicam à resposta e explicação fornecidas.
  • Descrição de erro – Os anotadores foram instruídos a não apenas classificar os erros, mas também fornecer uma justificativa abrangente para sua categorização, incluindo a identificação da etapa exata em que ocorreu o erro e como isso se aplica à resposta e à explicação fornecida.

Usamos Amazon SageMaker Ground Truth Plus em nossa coleta de dados. A coleta de dados ocorreu em várias rodadas. Primeiro, conduzimos dois pequenos pilotos de 30 e 200 exemplos, respectivamente, após os quais a equipe de anotadores recebeu feedback detalhado sobre a anotação. Em seguida, realizamos a coleta de dados em dois lotes para StrategyQA e em um lote para Sports Understanding, fornecendo feedback periódico durante todo o processo - um total de 10 anotadores trabalharam na tarefa durante um período de cerca de 1 mês.

Reunimos feedback sobre um total de 1,565 exemplos para StrategyQA e 796 exemplos para Sports Understanding. A tabela a seguir ilustra a porcentagem de exemplos sem erros na geração do modelo e a proporção de exemplos que continham um tipo de erro específico. Vale a pena notar que alguns exemplos podem ter mais de um tipo de erro.

Tipo de Erro Controle de qualidade da estratégia Compreensão Esportiva
nenhum 17.6% 31.28%
Erro factual 27.6% 38.1%
Fatos ausentes 50.4% 46.1%
Fatos Irrelevantes 14.6% 3.9%
Inconsistência Lógica 11.2% 5.2%

Algoritmos de aprendizagem

Para cada pergunta q, e resposta e explicação geradas pelo modelo m, coletamos o seguinte feedback: resposta correta e explicação c, tipo de erro presente em m (denotado por t) e descrição do erro d, conforme descrito na seção anterior.

Usamos os seguintes métodos:

  • Aprendizagem multitarefa – Uma linha de base simples para aprender com os diversos comentários disponíveis é tratar cada um deles como uma tarefa separada. Mais concretamente, ajustamos o Flan-T5 (texto para texto) com o objetivo maximizar p(c|q) + p(t|q, m) + p(d|q, m). Para cada termo no objetivo, usamos uma instrução separada apropriada para a tarefa (por exemplo, “Prever erro na resposta dada”). Também convertemos a variável categórica t em uma frase de linguagem natural. Durante a inferência, usamos a instrução para o termo p(c|q) (“Prever a resposta correta para a pergunta dada”) para gerar a resposta para a pergunta do teste.
  • Autoconsistência ponderada – Motivados pelo sucesso da autoconsistência na sugestão de cadeia de pensamento, propomos uma variante ponderada dela. Em vez de tratar cada explicação amostrada do modelo como correta e considerar o voto agregado, primeiro consideramos se a explicação está correta e depois agregamos de acordo. Primeiro ajustamos o Flan-T5 com o mesmo objetivo do aprendizado multitarefa. Durante a inferência, dada uma pergunta de teste q, amostramos várias respostas possíveis com a instrução para p(c|q)): a1, a2, .., an. Para cada resposta amostrada ai, usamos a instrução para o termo p(t|q,m) (“Prever erro na resposta dada”) para identificar se contém erro ti = argmáx p(t|q, a_i). cada resposta ai é atribuído um peso de 1 se estiver correto, caso contrário, é atribuído um peso menor que 1 (hiperparâmetro ajustável). A resposta final é obtida considerando uma votação ponderada sobre todas as respostas a1 para an.
  • refinamento iterativo – Nos métodos propostos anteriormente, o modelo gera diretamente a resposta correta c condicionado na pergunta q. Aqui propomos refinar a resposta gerada pelo modelo m para obter a resposta correta para uma determinada pergunta. Mais especificamente, primeiro ajustamos o Flan-T5 (texto para texto com a objetiva) com maximizar p(t; c|q, m), Onde ; denota a concatenação (tipo de erro t seguido da resposta correta c). Uma maneira de visualizar esse objetivo é que o modelo seja primeiro treinado para identificar o erro em determinada geração me, em seguida, remover esse erro para obter a resposta correta c. Durante a inferência, podemos usar o modelo iterativamente até que ele gere a resposta correta – dada uma pergunta de teste q, primeiro obtemos a geração inicial do modelo m (usando Flan-T5 pré-treinado). Em seguida, geramos iterativamente o tipo de erro ti e possível resposta correta ci até ti = sem erro (na prática, definimos um número máximo de iterações para um hiperparâmetro), caso em que a resposta correta final será ci-1 (obtido de p(ti; ci | q, ci-1)).

Resultados

Para ambos os conjuntos de dados, comparamos todos os algoritmos de aprendizado propostos com a linha de base de aprendizado no contexto. Todos os modelos são avaliados no conjunto de desenvolvimento de StrategyQA e Sports Understanding. A tabela a seguir mostra os resultados.

Forma Controle de qualidade da estratégia Compreensão Esportiva
Aprendizagem em contexto de cadeia de pensamento de 5 tiros Flan-T4 67.39 ± 2.6% 58.5%
Aprendizagem multitarefa 66.22 ± 0.7% 54.3 ± 2.1%
Autoconsistência ponderada 61.13 ± 1.5% 51.3 ± 1.9%
Refinamento Iterativo 61.85 ± 3.3% 57.0 ± 2.5%

Conforme observado, alguns métodos têm desempenho comparável à linha de base de aprendizado no contexto (multitarefa para StrategyQA e refinamento iterativo para Compreensão de esportes), o que demonstra o potencial de coletar feedback contínuo de humanos sobre saídas de modelo e usá-lo para melhorar os modelos de linguagem. Isso é diferente de trabalhos recentes como RLHF, onde o feedback é limitado a categórico e geralmente binário.

Conforme mostrado na tabela a seguir, investigamos como modelos adaptados com feedback humano sobre erros de raciocínio podem ajudar a melhorar a calibração ou a consciência de explicações erradas com confiança. Isso é avaliado solicitando ao modelo que preveja se sua geração contém algum erro.

Forma Correção de Erros Controle de qualidade da estratégia
Aprendizagem em contexto de cadeia de pensamento de 5 tiros Flan-T4 Não 30.17%
Modelo ajustado para multitarefa Sim 73.98%

Mais detalhadamente, solicitamos ao modelo de linguagem sua própria resposta gerada e cadeia de raciocínio (para a qual coletamos feedback) e, em seguida, solicitamos novamente que preveja o erro na geração. Usamos a instrução apropriada para a tarefa (“Identifique o erro na resposta”). O modelo é pontuado corretamente se prever “nenhum erro” ou “correto” na geração, se os anotadores rotularem o exemplo como sem erro ou se prever qualquer um dos tipos de erro na geração (junto com “incorreto” ou “ errado”) quando os anotadores rotularam como tendo um erro. Observe que não avaliamos a capacidade do modelo de identificar corretamente o tipo de erro, mas sim a presença de um erro. A avaliação é feita em um conjunto de 173 exemplos adicionais do conjunto de desenvolvimento StrategyQA que foram coletados, que não são vistos durante o ajuste fino. Quatro exemplos desses são reservados para solicitar o modelo de linguagem (primeira linha na tabela anterior).

Observe que não mostramos o resultado da linha de base 0-shot porque o modelo não é capaz de gerar respostas úteis. Observamos que o uso de feedback humano para correção de erros em cadeias de raciocínio pode melhorar a previsão do modelo de erros ou não, o que pode melhorar a consciência ou calibração das explicações erradas.

Conclusão

Nesta postagem, mostramos como selecionar conjuntos de dados de feedback humano com correções de erros refinadas, que é uma forma alternativa de melhorar as habilidades de raciocínio dos LLMs. Os resultados experimentais corroboram que o feedback humano sobre erros de raciocínio pode melhorar o desempenho e a calibração em questões desafiadoras de vários saltos.

Se você está procurando feedback humano para melhorar seus modelos de linguagem grandes, visite Rotulagem de dados do Amazon SageMaker e o console do Ground Truth Plus.


Sobre os autores

Melhore o raciocínio multi-hop em LLMs aprendendo com o rico feedback humano PlatoBlockchain Data Intelligence. Pesquisa vertical. Ai.Erran Li é o gerente de ciência aplicada nos serviços humain-in-the-loop, AWS AI, Amazon. Seus interesses de pesquisa são aprendizado profundo em 3D e aprendizado de representação de visão e linguagem. Anteriormente, ele foi cientista sênior da Alexa AI, chefe de aprendizado de máquina da Scale AI e cientista-chefe da Pony.ai. Antes disso, ele estava com a equipe de percepção da Uber ATG e a equipe de plataforma de aprendizado de máquina da Uber trabalhando em aprendizado de máquina para direção autônoma, sistemas de aprendizado de máquina e iniciativas estratégicas de IA. Ele começou sua carreira no Bell Labs e foi professor adjunto na Universidade de Columbia. Ele co-lecionou tutoriais no ICML'17 e ICCV'19, e co-organizou vários workshops no NeurIPS, ICML, CVPR, ICCV sobre aprendizado de máquina para direção autônoma, visão 3D e robótica, sistemas de aprendizado de máquina e aprendizado de máquina adversário. Ele tem um PhD em ciência da computação na Cornell University. É Fellow da ACM e Fellow do IEEE.

Melhore o raciocínio multi-hop em LLMs aprendendo com o rico feedback humano PlatoBlockchain Data Intelligence. Pesquisa vertical. Ai.Nitish Joshi foi estagiário de ciência aplicada na AWS AI, Amazon. Ele é aluno de doutorado em ciência da computação no Courant Institute of Mathematical Sciences da Universidade de Nova York, orientado pelo Prof. He He. Ele trabalha com aprendizado de máquina e processamento de linguagem natural e foi afiliado ao grupo de pesquisa Machine Learning for Language (ML2). Ele estava amplamente interessado em compreensão de linguagem robusta: tanto na construção de modelos que são robustos a mudanças de distribuição (por exemplo, através do aumento de dados humano-in-the-loop) quanto em projetar melhores maneiras de avaliar/medir a robustez dos modelos. Ele também está curioso sobre os desenvolvimentos recentes na aprendizagem em contexto e na compreensão de como ela funciona.

Melhore o raciocínio multi-hop em LLMs aprendendo com o rico feedback humano PlatoBlockchain Data Intelligence. Pesquisa vertical. Ai.Kumar Chellapilla é gerente geral e diretor da Amazon Web Services e lidera o desenvolvimento de serviços de ML/AI, como sistemas human-in-loop, AI DevOps, Geospatial ML e desenvolvimento de ADAS/Autonomous Vehicle. Antes da AWS, Kumar foi diretor de engenharia da Uber ATG e Lyft Level 5 e liderou equipes usando aprendizado de máquina para desenvolver recursos de direção autônoma, como percepção e mapeamento. Ele também trabalhou na aplicação de técnicas de aprendizado de máquina para melhorar a pesquisa, recomendações e produtos de publicidade no LinkedIn, Twitter, Bing e Microsoft Research.

Carimbo de hora:

Mais de Aprendizado de máquina da AWS