Améliorez le raisonnement multi-sauts dans les LLM en apprenant à partir de commentaires humains riches

Améliorez le raisonnement multi-sauts dans les LLM en apprenant à partir de commentaires humains riches

Les grands modèles de langage (LLM) récents ont permis d'énormes progrès dans la compréhension du langage naturel. Cependant, ils sont susceptibles de générer des explications confiantes mais absurdes, ce qui constitue un obstacle important à l'établissement d'une relation de confiance avec les utilisateurs. Dans cet article, nous montrons comment incorporer les commentaires humains sur les chaînes de raisonnement incorrectes pour le raisonnement multi-sauts afin d'améliorer les performances sur ces tâches. Au lieu de collecter les chaînes de raisonnement à partir de zéro en demandant aux humains, nous apprenons à la place de riches commentaires humains sur les chaînes de raisonnement générées par le modèle en utilisant les capacités d'incitation des LLM. Nous recueillons deux de ces ensembles de données de rétroaction humaine sous la forme de (correction, explication, type d'erreur) pour les ensembles de données StrategyQA et Sports Understanding, et évaluons plusieurs algorithmes courants pour apprendre de ces commentaires. Nos méthodes proposées fonctionnent de manière compétitive par rapport à l'incitation à la chaîne de pensée en utilisant la base Flan-T5, et la nôtre est meilleure pour juger de l'exactitude de sa propre réponse.

Vue d'ensemble de la solution

Avec l'apparition de grands modèles de langage, le domaine a connu d'énormes progrès sur diverses références de traitement du langage naturel (TAL). Parmi eux, les progrès ont été frappants sur des tâches relativement plus simples telles que le contexte court ou la réponse à des questions factuelles, par rapport à des tâches plus difficiles qui nécessitent un raisonnement telles que la réponse à des questions à sauts multiples. La performance de certaines tâches utilisant des LLM peut être similaire à une estimation aléatoire à plus petite échelle, mais s'améliore considérablement à plus grande échelle. Malgré cela, les capacités d'incitation des LLM ont le potentiel de fournir certains faits pertinents nécessaires pour répondre à la question.

Cependant, ces modèles peuvent ne pas générer de manière fiable des chaînes de raisonnement ou des explications correctes. Ces explications confiantes mais absurdes sont encore plus répandues lorsque les LLM sont formés à l'aide de l'apprentissage par renforcement à partir de la rétroaction humaine (RLHF), où le piratage des récompenses peut se produire.

Motivés par cela, nous essayons d'aborder la question de recherche suivante : peut-on améliorer le raisonnement des LLM en apprenant à partir de la rétroaction humaine sur les chaînes de raisonnement générées par des modèles ? La figure suivante donne un aperçu de notre approche : nous incitons d'abord le modèle à générer des chaînes de raisonnement pour les questions à sauts multiples, puis nous collectons divers retours humains sur ces chaînes pour le diagnostic et proposons des algorithmes d'entraînement pour apprendre des données collectées.

Améliorez le raisonnement multi-sauts dans les LLM en apprenant des riches retours humains de PlatoBlockchain Data Intelligence. Recherche verticale. Aï.

Nous recueillons divers commentaires humains sur deux ensembles de données de raisonnement multi-sauts, StrategyQA et Sports Understanding de BigBench. Pour chaque question et chaîne de raisonnement générée par le modèle, nous recueillons la chaîne de raisonnement correcte, le type d'erreur dans la chaîne de raisonnement générée par le modèle et une description (en langage naturel) de la raison pour laquelle cette erreur est présentée dans la chaîne de raisonnement fournie. L'ensemble de données final contient des commentaires pour 1,565 796 échantillons de StrategyQA et XNUMX exemples pour Sports Understanding.

Nous proposons plusieurs algorithmes d'entraînement pour apprendre des retours d'expérience collectés. Tout d'abord, nous proposons une variante de l'auto-cohérence dans l'incitation à la chaîne de pensée en considérant une variante pondérée de celle-ci qui peut être apprise à partir du feedback. Deuxièmement, nous proposons un raffinement itératif, où nous affinons de manière itérative la chaîne de raisonnement générée par le modèle jusqu'à ce qu'elle soit correcte. Nous démontrons empiriquement sur les deux ensembles de données que le réglage fin d'un LLM, à savoir Flan-T5 à l'aide des algorithmes proposés, fonctionne de manière comparable à la base d'apprentissage en contexte. Plus important encore, nous montrons que le modèle affiné est meilleur pour juger si sa propre réponse est correcte par rapport au modèle de base Flan-T5.

Collecte de données

Dans cette section, nous décrivons les détails des commentaires que nous avons recueillis et le protocole d'annotation suivi lors de la collecte des données. Nous avons recueilli des commentaires pour les générations de modèles sur la base de deux ensembles de données basés sur le raisonnement : StrategyQA et Sports Understanding de BigBench. Nous avons utilisé GPT-J pour générer la réponse pour StrategyQA et Flan-T5 pour générer la réponse pour l'ensemble de données Sports Understanding. Dans chaque cas, le modèle a été invité avec k exemples en contexte contenant une question, une réponse et une explication, suivis de la question de test.

La figure suivante montre l'interface que nous avons utilisée. Les annotateurs reçoivent la question, la réponse générée par le modèle et l'explication divisée en étapes.

Améliorez le raisonnement multi-sauts dans les LLM en apprenant des riches retours humains de PlatoBlockchain Data Intelligence. Recherche verticale. Aï.

Pour chaque question, nous avons recueilli les commentaires suivants :

  • Sous-questions – Les annotateurs décomposent la question originale en sous-questions plus simples requises pour répondre à la question originale. Cette tâche a été ajoutée après un pilote où nous avons constaté que l'ajout de cette tâche aide à préparer les annotateurs et à améliorer la qualité du reste des tâches.
  • Correction – Les annotateurs reçoivent une zone de texte de forme libre pré-remplie avec la réponse et l'explication générées par le modèle, et sont invités à la modifier pour obtenir la réponse et l'explication correctes.
  • Type d'erreur – Parmi les types d'erreurs les plus courants que nous avons trouvés dans les générations de modèles (erreur factuelle, faits manquants, faits non pertinents et incohérence logique), les annotateurs ont été invités à choisir un ou plusieurs des types d'erreur qui s'appliquent à la réponse et à l'explication données.
  • Erreur de description – Les annotateurs ont été chargés non seulement de classer les erreurs, mais également de fournir une justification complète de leur catégorisation, notamment en indiquant l'étape exacte où l'erreur s'est produite et comment elle s'applique à la réponse et à l'explication fournies.

Nous avons utilisé Amazon SageMaker Vérité au sol Plus dans notre collecte de données. La collecte de données s'est déroulée en plusieurs cycles. Nous avons d'abord mené deux petits pilotes de 30 exemples et 200 exemples, respectivement, après quoi l'équipe d'annotateurs a reçu des commentaires détaillés sur l'annotation. Nous avons ensuite effectué la collecte de données sur deux lots pour StrategyQA et sur un lot pour Sports Understanding, en donnant des commentaires périodiques tout au long - un total de 10 annotateurs ont travaillé sur la tâche sur une période de près d'un mois.

Nous avons recueilli des commentaires sur un total de 1,565 796 exemples pour StrategyQA et XNUMX exemples pour Sports Understanding. Le tableau suivant illustre le pourcentage d'exemples sans erreur lors de la génération du modèle et la proportion d'exemples contenant un type d'erreur spécifique. Il convient de noter que certains exemples peuvent avoir plus d'un type d'erreur.

Type d'erreur StratégieQA Compréhension sportive
Aucun 17.6% 31.28%
Erreur factuelle 27.6% 38.1%
Faits manquants 50.4% 46.1%
Faits non pertinents 14.6% 3.9%
Incohérence logique 11.2% 5.2%

Algorithmes d'apprentissage

Pour chaque question q, et réponse et explication générées par le modèle m, nous avons recueilli les commentaires suivants : réponse correcte et explication c, type d'erreur présente dans m (dénoté par t) et la description de l'erreur d, comme décrit dans la section précédente.

Nous avons utilisé les méthodes suivantes :

  • Apprentissage multitâche – Une ligne de base simple pour apprendre des divers commentaires disponibles est de traiter chacun d'eux comme une tâche distincte. Plus concrètement, nous peaufinons Flan-T5 (text to text) avec pour objectif maximisent p(c|q) + p(t|q, m) + p(d|q, m). Pour chaque terme de l'objectif, nous utilisons une instruction distincte appropriée à la tâche (par exemple, « Prédire l'erreur dans la réponse donnée »). Nous convertissons également la variable catégorielle t dans une phrase en langage naturel. Lors de l'inférence, nous utilisons l'instruction pour le terme p(c|q) ("Prédire la bonne réponse pour la question donnée") pour générer la réponse à la question du test.
  • Autocohérence pondérée – Motivés par le succès de l'auto-cohérence dans l'incitation à la chaîne de pensée, nous en proposons une variante pondérée. Au lieu de traiter chaque explication échantillonnée du modèle comme correcte et de considérer le vote global, nous examinons d'abord si l'explication est correcte, puis agrégeons en conséquence. Nous affinons d'abord Flan-T5 avec le même objectif que dans l'apprentissage multitâche. Pendant l'inférence, étant donné une question de test q, nous échantillonnons plusieurs réponses possibles avec l'instruction pour p(c|q)): a1, a2, .., an. Pour chaque réponse échantillonnée ai, on utilise l'instruction pour le terme p(t|q,m) ("Prédire l'erreur dans la réponse donnée") pour identifier si elle contient une erreur ti = argmax p(t|q, a_i). Chaque réponse ai se voit attribuer un poids de 1 s'il est correct, sinon un poids inférieur à 1 lui est attribué (hyperparamètre réglable). La réponse finale est obtenue en considérant un vote pondéré sur toutes les réponses a1 à an.
  • Raffinement itératif – Dans les méthodes proposées précédemment, le modèle génère directement la bonne réponse c conditionné à la question q. Nous proposons ici d'affiner la réponse générée par le modèle m pour obtenir la bonne réponse à une question donnée. Plus précisément, nous affinons d'abord Flan-T5 (texte à texte avec l'objectif) avec maximiser p(t; c|q, m), Où ; désigne la concaténation (type d'erreur t suivi de la bonne réponse c). Une façon de voir cet objectif est que le modèle est d'abord formé pour identifier l'erreur dans une génération donnée m, puis de supprimer cette erreur pour obtenir la bonne réponse c. Pendant l'inférence, nous pouvons utiliser le modèle de manière itérative jusqu'à ce qu'il génère la bonne réponse, étant donné une question de test q, nous obtenons d'abord la génération initiale du modèle m (en utilisant Flan-T5 pré-formé). Nous générons ensuite itérativement le type d'erreur ti et la bonne réponse potentielle ci jusqu'à ti = pas d'erreur (en pratique, on fixe un nombre maximum d'itérations à un hyperparamètre), auquel cas la bonne réponse finale sera ci-1 (obtenu à partir de p(ti ; ci | q, ci-1)).

Résultats

Pour les deux ensembles de données, nous comparons tous les algorithmes d'apprentissage proposés avec la base d'apprentissage en contexte. Tous les modèles sont évalués sur l'ensemble de développement de StrategyQA et Sports Understanding. Le tableau suivant montre les résultats.

Method StratégieQA Compréhension sportive
Flan-T5 Apprentissage en contexte de la chaîne de pensée en 4 étapes 67.39 ± 2.6% 58.5%
Apprentissage multitâche 66.22 ± 0.7% 54.3 ± 2.1%
Autocohérence pondérée 61.13 ± 1.5% 51.3 ± 1.9%
Raffinement itératif 61.85 ± 3.3% 57.0 ± 2.5%

Comme observé, certaines méthodes fonctionnent de manière comparable à la base d'apprentissage en contexte (multitâche pour StrategyQA et raffinement itératif pour Sports Understanding), ce qui démontre le potentiel de recueillir des commentaires continus des humains sur les sorties du modèle et de les utiliser pour améliorer les modèles de langage. Ceci est différent des travaux récents tels que RLHF, où la rétroaction est limitée à catégorique et généralement binaire.

Comme le montre le tableau suivant, nous étudions comment des modèles adaptés avec une rétroaction humaine sur les erreurs de raisonnement peuvent aider à améliorer l'étalonnage ou la prise de conscience d'explications erronées en toute confiance. Ceci est évalué en demandant au modèle de prédire si sa génération contient des erreurs.

Method Correction d'erreur StratégieQA
Flan-T5 Apprentissage en contexte de la chaîne de pensée en 4 étapes Non 30.17%
Modèle multitâche affiné Oui 73.98%

Plus en détail, nous invitons le modèle de langage avec sa propre chaîne de réponse et de raisonnement générée (pour laquelle nous avons recueilli des commentaires), puis nous l'invitons à nouveau à prédire l'erreur de génération. Nous utilisons l'instruction appropriée pour la tâche ("Identifiez l'erreur dans la réponse"). Le modèle est noté correctement s'il prédit "aucune erreur" ou "correct" dans la génération si les annotateurs ont étiqueté l'exemple comme n'ayant aucune erreur, ou s'il prédit l'un des types d'erreur dans la génération (avec "incorrect" ou " faux ») lorsque les annotateurs l'ont étiqueté comme ayant une erreur. Notez que nous n'évaluons pas la capacité du modèle à identifier correctement le type d'erreur, mais plutôt si une erreur est présente. L'évaluation est effectuée sur un ensemble de 173 exemples supplémentaires de l'ensemble de développement StrategyQA qui ont été collectés, qui ne sont pas vus lors du réglage fin. Parmi ceux-ci, quatre exemples sont réservés à l'invite du modèle de langage (première ligne du tableau précédent).

Notez que nous ne montrons pas le résultat de la ligne de base 0-shot car le modèle est incapable de générer des réponses utiles. Nous observons que l'utilisation de la rétroaction humaine pour la correction des erreurs sur les chaînes de raisonnement peut améliorer la prédiction du modèle quant à savoir s'il fait des erreurs ou non, ce qui peut améliorer la prise de conscience ou le calibrage des mauvaises explications.

Conclusion

Dans cet article, nous avons montré comment organiser des ensembles de données de rétroaction humaine avec des corrections d'erreurs fines, ce qui est une autre façon d'améliorer les capacités de raisonnement des LLM. Les résultats expérimentaux corroborent le fait que la rétroaction humaine sur les erreurs de raisonnement peut améliorer les performances et l'étalonnage sur des questions difficiles à sauts multiples.

Si vous recherchez des commentaires humains pour améliorer vos grands modèles de langage, visitez Étiquetage des données Amazon SageMaker et la console Ground Truth Plus.


À propos des auteurs

Améliorez le raisonnement multi-sauts dans les LLM en apprenant des riches retours humains de PlatoBlockchain Data Intelligence. Recherche verticale. Aï.Erran Li est responsable des sciences appliquées chez humain-in-the-loop services, AWS AI, Amazon. Ses intérêts de recherche sont l'apprentissage profond 3D et l'apprentissage de la vision et de la représentation du langage. Auparavant, il était scientifique principal chez Alexa AI, responsable de l'apprentissage automatique chez Scale AI et scientifique en chef chez Pony.ai. Auparavant, il faisait partie de l'équipe de perception d'Uber ATG et de l'équipe de la plateforme d'apprentissage automatique d'Uber, travaillant sur l'apprentissage automatique pour la conduite autonome, les systèmes d'apprentissage automatique et les initiatives stratégiques de l'IA. Il a commencé sa carrière aux Bell Labs et a été professeur adjoint à l'Université de Columbia. Il a co-enseigné des tutoriels à ICML'17 et ICCV'19, et co-organisé plusieurs ateliers à NeurIPS, ICML, CVPR, ICCV sur l'apprentissage automatique pour la conduite autonome, la vision 3D et la robotique, les systèmes d'apprentissage automatique et l'apprentissage automatique contradictoire. Il est titulaire d'un doctorat en informatique de l'Université Cornell. Il est membre ACM et membre IEEE.

Améliorez le raisonnement multi-sauts dans les LLM en apprenant des riches retours humains de PlatoBlockchain Data Intelligence. Recherche verticale. Aï.Nitish Joshi était stagiaire en sciences appliquées chez AWS AI, Amazon. Il est doctorant en informatique au Courant Institute of Mathematical Sciences de l'Université de New York, sous la direction du professeur He He. Il travaille sur l'apprentissage automatique et le traitement du langage naturel, et il était affilié au groupe de recherche Machine Learning for Language (ML2). Il était largement intéressé par la compréhension robuste du langage : à la fois dans la construction de modèles robustes aux changements de distribution (par exemple grâce à l'augmentation des données humaines dans la boucle) et également dans la conception de meilleures façons d'évaluer/mesurer la robustesse des modèles. Il s'intéresse également aux développements récents de l'apprentissage en contexte et à la compréhension de son fonctionnement.

Améliorez le raisonnement multi-sauts dans les LLM en apprenant des riches retours humains de PlatoBlockchain Data Intelligence. Recherche verticale. Aï.Kumar Chellapilla est directeur général et directeur chez Amazon Web Services et dirige le développement de services ML/AI tels que les systèmes human-in-loop, AI DevOps, Geospatial ML et le développement ADAS/Autonomous Vehicle. Avant AWS, Kumar était directeur de l'ingénierie chez Uber ATG et Lyft Level 5 et dirigeait des équipes utilisant l'apprentissage automatique pour développer des capacités d'auto-conduite telles que la perception et la cartographie. Il a également travaillé sur l'application de techniques d'apprentissage automatique pour améliorer les produits de recherche, de recommandations et de publicité sur LinkedIn, Twitter, Bing et Microsoft Research.

Horodatage:

Plus de Apprentissage automatique AWS